diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index d775ef5a..648cf92b 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -58,4 +58,4 @@ jobs: - name: Run Flake8 on changed files if: needs.changed-files.outputs.files != '' run: | - flake8 --ignore=E203,W503 --max-line-length 120 ${{ needs.changed-files.outputs.files }} + flake8 --ignore=E203,W503,E402 --max-line-length 120 ${{ needs.changed-files.outputs.files }} diff --git a/.gitignore b/.gitignore index 015f8185..5923ce20 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,9 @@ .idea/ .vscode/ +*.CSV +*.csv + # Embeddings cache pipeline_cache # Byte-compiled / optimized / DLL files diff --git a/chatdku/config.py b/chatdku/config.py index 225bf7f9..a4f440a9 100644 --- a/chatdku/config.py +++ b/chatdku/config.py @@ -55,7 +55,7 @@ def _initialize_defaults(self): db_host = _env("DB_HOST") db_port = _env("DB_PORT") db_name = _env("DB_NAME") - psql_uri = f"postgresql://{db_user}:{quote_plus(db_password)}@{db_host}:{db_port}/{db_name}" + psql_uri = f"postgresql://{db_user}:{quote_plus(db_password or '')}@{db_host}:{db_port}/{db_name}" self._store.update( { @@ -117,7 +117,7 @@ def _initialize_defaults(self): # External data "prereq_csv_path": "/datapool/chatdku_external_data/DK_SR_PREREQ_CRSE_CHATDKU.csv", "classdata_csv_path": "/datapool/chatdku_external_data/cleaned_classdata.csv", - "major_requirements_dir": "/datapool/chatdku_external_data/doc_testing/output/ug_bulletin_2023-2024", + "major_req_dir": "/datapool/chatdku_external_data/doc_testing/output/ug_bulletin_2023-2024", } ) # refresh read-only view diff --git a/chatdku/core/agent.py b/chatdku/core/agent.py index cf6d6f06..5ea043cd 100755 --- a/chatdku/core/agent.py +++ b/chatdku/core/agent.py @@ -1,6 +1,13 @@ #!/usr/bin/env python3 +import argparse +import os +import sys import traceback +import pyfiglet +# Must be set before `import dspy` — prevents litellm from fetching the remote +# model pricing database at startup (cuts ~40s off cold-start time). +os.environ.setdefault("LITELLM_LOCAL_MODEL_COST_MAP", "True") # noqa: E402,E401 import dspy from openinference.semconv.trace import OpenInferenceSpanKindValues, SpanAttributes from opentelemetry.trace import Status, StatusCode, use_span @@ -10,15 +17,16 @@ from chatdku.core.dspy_classes.executor import Executor from chatdku.core.dspy_classes.plan import Planner from chatdku.core.dspy_classes.synthesizer import Synthesizer -from chatdku.core.tools.course_schedule import CourseScheduleLookupOuter -from chatdku.core.tools.get_prerequisites import PrerequisiteLookupOuter +from chatdku.core.tools.course_recommender import CourseRecommender +from chatdku.core.tools.course_schedule import CourseScheduleLookup +from chatdku.core.tools.get_prerequisites import PrerequisiteLookup from chatdku.core.tools.llama_index_tools import ( KeywordRetrieverOuter, VectorRetrieverOuter, ) -from chatdku.core.tools.major_requirements import MajorRequirementsLookupOuter -from chatdku.core.tools.syllabi_tool.query_curriculum_db import QueryCurriculumOuter -from chatdku.core.utils import format_trajectory, load_conversation, span_start +from chatdku.core.tools.major_requirements import MajorRequirementsLookup +from chatdku.core.tools.syllabi.syllabi_tool import SyllabusLookupOuter +from chatdku.core.utils import load_conversation, span_start from chatdku.setup import setup, use_phoenix # When `--dev` is passed to the script, enable additional debug prints in this module. @@ -99,18 +107,6 @@ def _forward_gen( ) with use_span(span): - # Putting this in `self.__init__()` might not work due to that you might - # want DSPy to change prompt dynamically. - - # These limits are for compressing both tool and conversation memory. - # Uses the executor's token limits as the executor has the largest context needs. - limits = self.executor.get_token_limits( - plan="", - current_user_message=current_user_message, - conversation_history=self.conversation_memory.history_str(), - trajectory=format_trajectory({}), - ) - # Clear internal memory for each user message self.internal_memory.clear() @@ -125,7 +121,6 @@ def _forward_gen( self.conversation_memory( role="assistant", content=prev_response, - max_history_size=limits["conversation_history"], ) plan_result = self.planner( @@ -157,7 +152,6 @@ def _forward_gen( self.conversation_memory( role="user", content=current_user_message, - max_history_size=limits["conversation_history"], ) if not self.streaming: @@ -188,7 +182,8 @@ def forward( return i -def main(): +def build_agent(streaming: bool = True, max_iterations: int = 10) -> "Agent": + """Configure DSPy and return a ready-to-use Agent instance.""" setup() use_phoenix() @@ -210,17 +205,13 @@ def main(): enable_thinking=False, ) dspy.configure(lm=lm) - # To disable cache: + # To disable cache: # dspy.configure_cache( # enable_disk_cache=False, # enable_memory_cache=False # ) - import time - - # role = "student" - # access_type = "student" # hard code it for now, need parameter pass from user role user_id = "Chat_DKU" search_mode = 0 tools = [ @@ -240,29 +231,62 @@ def main(): search_mode=search_mode, files=[], ), - # DocRetrieverOuter( - # retriever_top_k=25, - # use_reranker=True, - # reranker_top_n=5, - # access_type=access_type, - # role=role, - # user_id=user_id, - # search_mode=search_mode, - # files=[], - # ), - MajorRequirementsLookupOuter(config.major_requirements_dir), - QueryCurriculumOuter(), - PrerequisiteLookupOuter(prereq_csv_path=config.prereq_csv_path), - CourseScheduleLookupOuter(classdata_csv_path=config.classdata_csv_path), + SyllabusLookupOuter(), + MajorRequirementsLookup, + PrerequisiteLookup, + CourseRecommender, + CourseScheduleLookup, ] - agent = Agent( - max_iterations=3, - streaming=True, + return Agent( + max_iterations=max_iterations, + streaming=streaming, get_intermediate=False, tools=tools, ) + +def run_query(query: str, agent: "Agent | None" = None) -> str: + """Run a single query and return the full response as a string. + + Suitable for programmatic use from Python: + from chatdku.core.agent import run_query + print(run_query("What are the CS major requirements?")) + """ + if agent is None: + agent = build_agent(streaming=False) + result = agent(current_user_message=query) + response = result.response + if isinstance(response, str): + return response + # Streaming generator — collect. + return "".join(response) + + +def main(): + parser = argparse.ArgumentParser(description="ChatDKU agent.") + parser.add_argument( + "query", + nargs="*", + help="Query to run once and exit. If omitted, starts interactive mode.", + ) + args = parser.parse_args() + + if args.query: + query = " ".join(args.query) + print(run_query(query)) + return + + _main_interactive() + + +def _main_interactive(): + import time + + agent = build_agent(streaming=True) + + pyfiglet.figlet_format("ChatDKU", font="slant") + while True: try: print("*" * 10) @@ -290,5 +314,5 @@ def main(): main() except Exception: print(traceback.format_exc()) - - input() + if sys.stdin.isatty(): + input() diff --git a/chatdku/core/ascii_logo b/chatdku/core/ascii_logo new file mode 100644 index 00000000..33dfe203 --- /dev/null +++ b/chatdku/core/ascii_logo @@ -0,0 +1,13 @@ +                              +                              +              **              +            .xxxx             +           -xxxxxx            +          =###*x###.          +         =###+  *###-         +        +###=    *###=        +      ---+x=      +x====      +     ------.      .======     +    --------.    =========.   +   ----------.  ===========.  + .------------.-============. diff --git a/chatdku/core/dspy_classes/conversation_memory.py b/chatdku/core/dspy_classes/conversation_memory.py index ae9e5148..ab694f59 100644 --- a/chatdku/core/dspy_classes/conversation_memory.py +++ b/chatdku/core/dspy_classes/conversation_memory.py @@ -1,6 +1,7 @@ -from typing import Optional +import json import dspy +from litellm.exceptions import ContextWindowExceededError from openinference.instrumentation import safe_json_dumps from openinference.semconv.trace import ( OpenInferenceMimeTypeValues, @@ -8,21 +9,12 @@ SpanAttributes, ) from opentelemetry.trace import Status, StatusCode -from pydantic import BaseModel, ConfigDict - -from chatdku.core.dspy_common import get_template -from chatdku.core.utils import ( - span_ctx_start, - strs_fit_max_tokens_reverse, - token_limit_ratio_to_count, - truncate_tokens_all, -) + +from chatdku.core.utils import span_ctx_start -class ConversationMemoryEntry(BaseModel): - model_config = ConfigDict(extra="forbid") - role: str - content: str +MAX_HISTORY_ENTRIES = 6 +TRUNCATE_BATCH_SIZE = 2 class CompressConversationMemorySignature(dspy.Signature): @@ -56,70 +48,32 @@ class ConversationMemory(dspy.Module): def __init__(self): super().__init__() self.compressor = dspy.Predict(CompressConversationMemorySignature) - self.history: list[ConversationMemoryEntry] = [] + self.history: list[dict] = [] self.summary: str = "" - self.token_ratios: dict[str, float] = { - "history_to_discard": 2 / 4, - "previous_summary": 1 / 4, - } - - def history_str(self, left: int = 0, right: Optional[int] = None): - if right is None: - right = len(self.history) - - return "\n".join( - [ - i.model_dump_json(indent=4) - for i in self.history[left:right] - if not isinstance(i, dict) - ] - ) - def get_token_limits(self, **kwargs) -> dict[str, int]: - return token_limit_ratio_to_count( - self.token_ratios, len(get_template(self.compressor, **kwargs)) - ) + def history_str(self) -> str: + return "\n".join(json.dumps(entry) for entry in self.history) - def forward(self, role: str, content: str, max_history_size: int = 1000): + def forward(self, role: str, content: str): with span_ctx_start( "Conversation Memory", OpenInferenceSpanKindValues.CHAIN ) as span: - new_entry = ConversationMemoryEntry(role=role, content=content) + new_entry = {role: content} span.set_attributes( { - SpanAttributes.INPUT_VALUE: safe_json_dumps(new_entry.model_dump()), + SpanAttributes.INPUT_VALUE: safe_json_dumps(new_entry), SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value, } ) self.history.append(new_entry) - min_index = strs_fit_max_tokens_reverse( - [i.model_dump_json() for i in self.history if not isinstance(i, dict)], - "\n", - max_history_size, - ) - if min_index > 0: - compressor_inputs = dict( - history_to_discard=self.history_str(0, min_index), - previous_summary=self.summary, - ) - compressor_inputs = truncate_tokens_all( - compressor_inputs, self.get_token_limits(**compressor_inputs) - ) - self.summary = self.compressor(**compressor_inputs).current_summary - self.history = self.history[min_index:] + if len(self.history) > MAX_HISTORY_ENTRIES: + self._compress_oldest_with_retry(TRUNCATE_BATCH_SIZE) span.set_attributes( { SpanAttributes.OUTPUT_VALUE: safe_json_dumps( - dict( - history=[ - i.model_dump() - for i in self.history - if not isinstance(i, dict) - ], - summary=self.summary, - ) + dict(history=self.history, summary=self.summary) ), SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value, } @@ -127,5 +81,36 @@ def forward(self, role: str, content: str, max_history_size: int = 1000): span.set_status(Status(StatusCode.OK)) def register_history(self, role: str, content: str): - new_entry = ConversationMemoryEntry(role=role, content=content) - self.history.append(new_entry) + self.history.append({role: content}) + + def _compress_oldest_with_retry(self, batch_size: int): + """Summarize the oldest `batch_size` entries into the running summary. + + On context window overflow, shrinks the batch one entry at a time and + retries (up to 3 attempts), mirroring the pattern in Executor. + """ + to_discard = self.history[:batch_size] + remaining = self.history[batch_size:] + + for _ in range(3): + try: + self.summary = self._summarize(to_discard, self.summary) + self.history = remaining + return + except ContextWindowExceededError: + if len(to_discard) <= 1: + raise ValueError( + "The conversation history exceeded the context window even with a single entry." + ) + self.summary = self._summarize([to_discard[0]], self.summary) + to_discard = to_discard[1:] + + raise ValueError( + "The context window was exceeded even after 3 attempts to truncate the conversation history." + ) + + def _summarize(self, entries: list[dict], previous_summary: str) -> str: + return self.compressor( + history_to_discard="\n".join(json.dumps(e) for e in entries), + previous_summary=previous_summary, + ).current_summary diff --git a/chatdku/core/dspy_classes/executor.py b/chatdku/core/dspy_classes/executor.py index 871c4d49..6d862fe7 100644 --- a/chatdku/core/dspy_classes/executor.py +++ b/chatdku/core/dspy_classes/executor.py @@ -1,3 +1,4 @@ +from datetime import date from typing import Any, Literal import dspy @@ -22,17 +23,23 @@ ) -class _ExecutorSignatureBase(dspy.Signature): - """ - You are an Executor Agent for Duke Kunshan University (DKU) gathering +class ExecutorSignatureBase(dspy.Signature): + """You are an Executor Agent for Duke Kunshan University (DKU) gathering information to answer a user's question. - Given the plan and the trajectory so far, do the following in order: + Given the plan and the tool results collected so far (trajectory), do the following in order: 1. Assess progress: - - What information from the plan has been successfully gathered. - - What information is still missing or insufficient. - - Whether the missing information can be obtained with available tools. + 1. What information from the plan has been successfully gathered. + 2. What information is still missing or insufficient. + 3. What NEW investigation areas have been REVEALED by the tool results so far + that were not in the original agenda — for example, a retrieved policy document + mentions a mandatory course, or schedule data reveals an unmet prerequisite chain. + + You MUST pursue the full current agenda — including any extensions discovered + during earlier steps — not just the original plan. If the assessment reveals + new requirements (e.g., a policy document names a mandatory course), investigate + those before finishing. 2. Decide whether to continue or finish: - Choose "finish" in the next_tool_name field if you have gathered enough @@ -61,8 +68,11 @@ class _ExecutorSignatureBase(dspy.Signature): etc. """ - plan: str = dspy.InputField( - desc="The plan describing what information to gather.", + current_agenda: str = dspy.InputField( + desc=( + "The current investigation agenda: the original plan plus any extensions " + "discovered during execution. This is the full set of things to pursue." + ), format=lambda x: x, ) current_user_message: str = dspy.InputField() @@ -70,9 +80,25 @@ class _ExecutorSignatureBase(dspy.Signature): desc="Tool calls and their results collected so far. May be empty on the first iteration.", format=lambda x: x, ) - conversation_history: str = CONVERSATION_HISTORY_FIELD conversation_summary: str = CONVERSATION_SUMMARY_FIELD + conversation_history: str = CONVERSATION_HISTORY_FIELD chatbot_role: str = ROLE_PROMPT + current_date: date = dspy.InputField() + assessment: str = dspy.OutputField( + desc=( + "Brief analysis: (1) what information has been gathered so far, " + "(2) what is still missing from the plan, " + "(3) whether the missing information can be obtained with available tools." + ), + format=lambda x: x, + ) + agenda_extensions: str = dspy.OutputField( + desc=( + "New investigation areas revealed by the tool results that are NOT yet " + "in the current agenda. Describe each as a short action phrase. " + "Leave empty if nothing new was discovered." + ), + ) assessment: str = dspy.OutputField( desc=( @@ -146,8 +172,8 @@ class SummarizerSignature(dspy.Signature): # Keys per iteration in the trajectory dict. -# assessment, thought, tool_name, tool_args, observation -_KEYS_PER_ITERATION = 5 +# tool_name, tool_args, observation +_KEYS_PER_ITERATION = 4 class Executor(dspy.Module): @@ -158,17 +184,22 @@ def __init__(self, tools, max_iterations=5): # Build the Executor signature dynamically with tool descriptions in the instructions. instr = ( - [f"{_ExecutorSignatureBase.instructions}\n"] - if _ExecutorSignatureBase.instructions + [f"{ExecutorSignatureBase.instructions}\n"] + if ExecutorSignatureBase.instructions else [] ) outputs = ", ".join( - [f"`{k}`" for k in _ExecutorSignatureBase.output_fields.keys()] + [f"`{k}`" for k in ExecutorSignatureBase.output_fields.keys()] ) + tools["finish"] = Tool( func=lambda: "Completed.", name="finish", - desc=f"Marks the task as complete. That is, signals that all information for producing the outputs, i.e. {outputs}, are now available to be extracted.", + desc=( + "Marks the task as complete." + "That is, signals that all information for producing " + f"the outputs, i.e. {outputs}, are now available to be extracted." + ), args={}, ) @@ -181,8 +212,8 @@ def __init__(self, tools, max_iterations=5): exec_signature = ( dspy.Signature( { - **_ExecutorSignatureBase.input_fields, - **_ExecutorSignatureBase.output_fields, + **ExecutorSignatureBase.input_fields, + **ExecutorSignatureBase.output_fields, }, "\n".join(instr), ) @@ -199,14 +230,6 @@ def __init__(self, tools, max_iterations=5): self.executor = dspy.Predict(exec_signature) self.distiller = dspy.Predict(DistillSignature) - self.executor_token_ratios: dict[str, float] = { - "plan": 2 / 13, - "current_user_message": 1 / 13, - "conversation_history": 1 / 13, - "conversation_summary": 1 / 13, - "chatbot_role": 2 / 13, - "trajectory": 5 / 13, - } self.distill_token_ratios: dict[str, float] = { "current_user_message": 2 / 10, "plan": 2 / 10, @@ -217,72 +240,72 @@ def __init__(self, tools, max_iterations=5): self.trajectory_summary = "" self.max_iterations = max_iterations - def get_token_limits(self, **kwargs) -> dict[str, int]: - """Return token limits using the executor's ratios.""" - template_len = len(get_template(self.executor, **kwargs)) - return token_limit_ratio_to_count(self.executor_token_ratios, template_len) - def forward( self, plan: str, current_user_message: str, conversation_memory: ConversationMemory, ) -> dspy.Prediction: - shared_inputs = dict( - current_user_message=current_user_message, - conversation_history=conversation_memory.history_str(), - conversation_summary=conversation_memory.summary, - ) + # current_agenda starts as the original plan and grows as the Executor + # discovers new investigation areas from tool results. + current_agenda = plan trajectory = {} with span_ctx_start("Executor", SpanKind.AGENT) as span: - span.set_attribute("agent.name", "Executor") - span.set_attribute( - "input.value", - safe_json_dumps({"plan": plan, **shared_inputs}), - ) - for idx in range(self.max_iterations): - exec_inputs = { - "plan": plan, - **shared_inputs, - "chatbot_role": role_str, - } - exec_inputs = truncate_tokens_all( - exec_inputs, - self._executor_token_limits(**exec_inputs), + executor_inputs = dict( + current_agenda=current_agenda, + current_user_message=current_user_message, + conversation_history=conversation_memory.history_str(), + conversation_summary=conversation_memory.summary, + current_date=str(date.today()), + chatbot_role=role_str, ) + span.set_attribute("agent.name", "Executor") + span.set_attribute("input.value", safe_json_dumps(executor_inputs)) + try: - result = self._call_with_potential_trajectory_truncation( - self.executor, trajectory, **exec_inputs + executor_result = ( + self._call_with_potential_trajectory_truncation( # noqa E501 + self.executor, trajectory, **executor_inputs + ) ) except ValueError: break - trajectory[f"assessment_{idx}"] = result.assessment - - if result.next_tool_name == "finish": + if executor_result.next_tool_name == "finish": break - trajectory[f"thought_{idx}"] = result.next_thought - trajectory[f"tool_name_{idx}"] = result.next_tool_name - trajectory[f"tool_args_{idx}"] = result.next_tool_args + # NOTE: By Temuulen - I don't think we need to record assessment + # The agent can just assess everyturn and the assessment can act like + # a thought process guideline + extensions = getattr(executor_result, "agenda_extensions", "").strip() + if extensions: + current_agenda = ( + f"{current_agenda}\n\n" + f"[Additional areas to investigate, discovered at step {idx + 1}]:\n" + f"{extensions}" + ) + trajectory[f"thought_{idx}"] = executor_result.next_thought + trajectory[f"tool_name_{idx}"] = executor_result.next_tool_name + trajectory[f"tool_args_{idx}"] = executor_result.next_tool_args try: trajectory[f"observation_{idx}"] = self.tools[ - result.next_tool_name - ](**result.next_tool_args) + executor_result.next_tool_name + ](**executor_result.next_tool_args) except Exception as err: trajectory[f"observation_{idx}"] = ( - f"Execution error in {result.next_tool_name}: {_fmt_exc(err)}" + f"Execution error in {executor_result.next_tool_name}: {_fmt_exc(err)}" ) - # DISTILL + # DISTILL — pass the final (extended) agenda so the distiller knows + # everything that was investigated, including any on-the-fly extensions. formatted_traj = format_trajectory(trajectory) distill_inputs = dict( current_user_message=current_user_message, - plan=plan, + plan=current_agenda, trajectory=formatted_traj, trajectory_summary=self.trajectory_summary, ) @@ -299,12 +322,6 @@ def forward( summary=self.trajectory_summary, ) - # Token limit helpers - - def _executor_token_limits(self, **kwargs) -> dict[str, int]: - template_len = len(get_template(self.executor, **kwargs)) - return token_limit_ratio_to_count(self.executor_token_ratios, template_len) - def _distill_token_limits(self, **kwargs) -> dict[str, int]: template_len = len(get_template(self.distiller, **kwargs)) return token_limit_ratio_to_count(self.distill_token_ratios, template_len) diff --git a/chatdku/core/dspy_classes/plan.py b/chatdku/core/dspy_classes/plan.py index cd9611b2..b157608c 100644 --- a/chatdku/core/dspy_classes/plan.py +++ b/chatdku/core/dspy_classes/plan.py @@ -56,16 +56,27 @@ class PlannerSignature(dspy.Signature): If any of these are missing from the current message and the conversation history, choose action_type "send_message" and ask for the missing information. - Once you have all three pieces of information, your plan should include: - a. Look up the requirements for the student's major. - b. Look up the university-wide common-core requirements. - c. Identify courses that still need to be completed. - d. Verify prerequisites for each recommended course. + Once you have all three pieces of information, your plan should: + a. FIRST retrieve year-specific academic policies for the student's + class year — e.g. query "Year 1 fall semester mandatory courses", + "freshman requirements DKU 101 writing", or + "Class of 20XX graduation requirements". Use VectorRetriever or + KeywordRetriever. The Executor will extend its agenda based on any + mandatory courses or policy constraints it discovers. + b. Call CourseRecommender with the student's major and completed_courses + to get the baseline eligibility and schedule availability report. + This single tool handles requirements lookup, schedule availability, + and prerequisite checking in one step — prefer it over calling + MajorRequirementsLookup, CourseScheduleLookup, and PrerequisiteLookup + individually. + c. Optionally supplement with VectorRetriever or QueryCurriculum if the + student asks for more detail on specific courses (syllabus, instructor, + course description, etc.). """ current_user_message: str = dspy.InputField() - conversation_history: str = CONVERSATION_HISTORY_FIELD conversation_summary: str = CONVERSATION_SUMMARY_FIELD + conversation_history: str = CONVERSATION_HISTORY_FIELD chatbot_role: str = ROLE_PROMPT available_tools: str = dspy.InputField( desc="Descriptions of the tools available to the Executor.", @@ -141,6 +152,29 @@ class PlannerSignature(dspy.Signature): "3. What courses have you already completed or are currently taking?" ), ).with_inputs("current_user_message"), + dspy.Example( + current_user_message=( + "I'm a Data Science major, Class of 2027. " + "I've completed MATH 105, STATS 201, COMPSCI 101, and ECON 101. " + "What courses should I take next semester?" + ), + action_type="plan", + action=( + "The student has provided all required information: major (Data Science), " + "year (Class of 2027), completed courses (MATH 105, STATS 201, COMPSCI 101, ECON 101). " + "Class of 2027 means they matriculated in Fall 2023, so they are currently in Year 2 " + "(assuming Fall 2026 is next semester) or Year 3 depending on current date. " + "Step 1: Retrieve year-specific academic policies — search for 'Year 2 requirements' " + "or 'sophomore mandatory courses DKU' to identify any mandatory courses the student " + "must take based on their class year (e.g. DKU 101, writing requirement for Year 1; " + "GCHINA 101 for Year 1 Spring; GLOCHALL 201 for Year 2). " + "Step 2: Call CourseRecommender with major='data science' and " + "completed_courses=['MATH 105', 'STATS 201', 'COMPSCI 101', 'ECON 101'] to get the " + "baseline eligibility and schedule availability report. " + "The Executor should extend its agenda if the policy search reveals mandatory courses " + "not yet covered by CourseRecommender." + ), + ).with_inputs("current_user_message"), ] diff --git a/chatdku/core/tools/course_recommender.py b/chatdku/core/tools/course_recommender.py new file mode 100644 index 00000000..30250ef0 --- /dev/null +++ b/chatdku/core/tools/course_recommender.py @@ -0,0 +1,524 @@ +""" +course_recommender.py + +Deterministic tool that combines major requirements, schedule availability, +and prerequisite data to produce a structured next-semester course recommendation +for a given student. + +This replaces the need for 20+ individual executor tool-call iterations by doing +all the data-joining logic in Python, returning a single structured report. +""" + +from __future__ import annotations + +import re +from pathlib import Path + +import pandas as pd +from openinference.instrumentation import safe_json_dumps +from openinference.semconv.trace import ( + OpenInferenceMimeTypeValues, + OpenInferenceSpanKindValues, + SpanAttributes, +) +from opentelemetry.trace import Status, StatusCode + +from chatdku.core.tools.major_requirements import _best_match, _list_stems +from chatdku.core.utils import span_ctx_start +from chatdku.config import config + +# --------------------------------------------------------------------------- +# Course code parsing +# --------------------------------------------------------------------------- + +# Matches DKU course codes like COMPSCI 201, STATS 202A, MATH 105. +# Handles subject codes of 2-10 uppercase letters followed by a 3-digit +# catalog number with an optional trailing letter (e.g. 101A). +_COURSE_CODE_RE = re.compile(r"\b([A-Z]{2,10})\s+(\d{3}[A-Z]?)\b") + +# Known DKU subject codes — used to filter false positives from the markdown. +_KNOWN_SUBJECTS = { + "DKU", + "GERMAN", + "INDSTU", + "JAPANESE", + "KOREAN", + "MUSIC", + "SPANISH", + "ARHU", + "ARTS", + "BEHAVSCI", + "BIOL", + "CHEM", + "CHINESE", + "COMPDSGN", + "COMPSCI", + "CULANTH", + "CULMOVE", + "CULSOC", + "EAP", + "ECON", + "ENVIR", + "ETHLDR", + "GCHINA", + "GCULS", + "GLHLTH", + "GLOCHALL", + "HIST", + "HUM", + "INFOSCI", + "INSTGOV", + "LIT", + "MATH", + "MATSCI", + "MEDIA", + "MEDIART", + "NEUROSCI", + "PHIL", + "PHYS", + "PHYSEDU", + "POLECON", + "POLSCI", + "PPE", + "PSYCH", + "PUBPOL", + "SOCIOL", + "SOSC", + "STATS", + "USTUD", + "WOC", + "RELIG", + "MINITERM", +} + + +def parse_course_codes(md_text: str) -> list[str]: + """Extract all DKU course codes from a Markdown requirements document. + + Returns a deduplicated list of strings like ["COMPSCI 201", "STATS 202"]. + Only returns codes whose subject prefix is a known DKU subject code, to + filter out false positives (e.g. headings that accidentally match the regex). + """ + found = [] + for subject, catalog in _COURSE_CODE_RE.findall(md_text): + if subject in _KNOWN_SUBJECTS: + found.append(f"{subject} {catalog}") + # Deduplicate while preserving order. + seen: set[str] = set() + result = [] + for code in found: + if code not in seen: + seen.add(code) + result.append(code) + return result + + +# --------------------------------------------------------------------------- +# Prerequisite satisfaction +# --------------------------------------------------------------------------- + + +def _load_prereq_df(prereq_csv_path: Path) -> pd.DataFrame: + return pd.read_csv(prereq_csv_path, encoding="utf-16le") + + +def _get_prereq_text(course: str, prereq_df: pd.DataFrame) -> str | None: + """Return the raw prerequisite description for *course*, or None if absent.""" + parts = re.sub(r"[\s\-]", "_", course.strip()).split("_") + subject = parts[0].upper() + catalog = "".join(parts[1:]) + + mask = (prereq_df.iloc[:, 2].astype(str).str.strip() == subject) & ( + prereq_df.iloc[:, 3].astype(str).str.strip() == catalog + ) + if not mask.any(): + return None + + matched = prereq_df.loc[mask].copy() + matched["_eff_date"] = pd.to_datetime( + matched.iloc[:, 1].astype(str).str.strip(), format="%m/%d/%Y", errors="coerce" + ) + latest = matched.sort_values("_eff_date", ascending=False).iloc[0] + descr = latest.iloc[13] + if pd.notna(descr) and str(descr).strip(): + return str(descr).strip() + return None + + +def prerequisites_met( + course: str, + completed_set: set[str], + prereq_df: pd.DataFrame, +) -> tuple[bool, str]: + """Check whether a student's completed courses satisfy *course*'s prerequisites. + + Returns: + (True, "") — no prerequisites or all satisfied + (False, "") — prerequisites not met + (True, "") — best-effort: possible OR-path satisfied + + Strategy (best-effort on free-form text): + 1. Extract all course codes mentioned in the prereq text. + 2. If the text contains "or": eligible if ANY mentioned code is completed. + 3. Otherwise (AND / simple): eligible if ALL mentioned codes are completed. + 4. If no codes are found in the prereq text, assume no structured prerequisite + and return eligible (the raw text is included for the Synthesizer). + """ + text = _get_prereq_text(course, prereq_df) + if text is None: + return True, "" + + # Strip anti-requisite section so its course codes aren't treated as prerequisites. + prereq_text = re.split(r"[Aa]nti[\s\-]?[Rr]equisite", text)[0].strip() + + # Extract all course codes mentioned in the prerequisite portion only. + codes_in_prereq = [ + f"{s} {c}" + for s, c in _COURSE_CODE_RE.findall(prereq_text) + if s in _KNOWN_SUBJECTS + ] + + if not codes_in_prereq: + # No structured codes — can't verify; pass through with a note. + return True, f"(Unstructured prerequisite — verify manually: {text})" + + has_or = " or " in prereq_text.lower() + + if has_or: + satisfied = any(c in completed_set for c in codes_in_prereq) + if satisfied: + return True, "" + missing = [c for c in codes_in_prereq if c not in completed_set] + return ( + False, + f"Requires one of: {', '.join(codes_in_prereq)} (missing: {', '.join(missing)})", + ) + else: + missing = [c for c in codes_in_prereq if c not in completed_set] + if not missing: + return True, "" + return ( + False, + f"Requires: {', '.join(codes_in_prereq)} (missing: {', '.join(missing)})", + ) + + +# --------------------------------------------------------------------------- +# Schedule lookup (batch) +# --------------------------------------------------------------------------- + + +def _get_offered_courses( + course_codes: list[str], classdata_csv_path: Path +) -> dict[str, list[dict]]: + """Return a mapping of course_code → list of schedule rows for offered courses. + + Courses not found in the schedule CSV are omitted from the result. + """ + try: + df = pd.read_csv(classdata_csv_path) + except FileNotFoundError: + return {} + + result: dict[str, list[dict]] = {} + for code in course_codes: + # Parse subject and catalog from code like "COMPSCI 201" + parts = code.strip().split() + if len(parts) != 2: + continue + subject, catalog = parts[0].upper(), parts[1].upper() + mask = (df["Subject"].astype(str).str.strip().str.upper() == subject) & ( + df["Catalog"].astype(str).str.strip().str.upper() == catalog + ) + rows = df.loc[mask].to_dict(orient="records") + if rows: + result[code] = rows + return result + + +_DAY_COLS = [("Mon", "M"), ("Tues", "Tu"), ("Wed", "W"), ("Thurs", "Th"), ("Fri", "F")] + + +def _format_schedule_rows(rows: list[dict]) -> str: + """Produce a compact, human-readable schedule summary for a course. + + Handles the real cleaned_classdata.csv column layout: + - Day columns: Mon / Tues / Wed / Thurs / Fri (value "Y" or "N") + - Time columns: Mtg Start / Mtg End + """ + # Deduplicate by section to avoid listing lab/recitation rows as separate entries. + seen_sections: set[str] = set() + parts = [] + for row in rows: + section = str(row.get("Section", "")).strip() + if section in seen_sections: + continue + seen_sections.add(section) + + session = str(row.get("Session", "")).strip() + start = str(row.get("Mtg Start", "")).strip().rstrip("0").rstrip(":") + end = str(row.get("Mtg End", "")).strip().rstrip("0").rstrip(":") + instructor = str(row.get("Instructor", "")).strip() + status = str(row.get("Class Status", "")).strip() + + # Build day string from individual boolean columns. + days = "".join( + abbr + for col, abbr in _DAY_COLS + if str(row.get(col, "N")).strip().upper() == "Y" + ) + + line_parts = [] + if section: + line_parts.append(f"§{section}") + if session: + line_parts.append(f"({session})") + if days: + time_str = f"{start}–{end}" if start and end else (start or end) + line_parts.append(f"{days} {time_str}" if time_str else days) + if instructor: + line_parts.append(f"Instr: {instructor}") + if status and status.lower() != "active": + line_parts.append(f"[{status}]") + parts.append(", ".join(line_parts)) + return " | ".join(parts) + + +# --------------------------------------------------------------------------- +# Tool factory +# --------------------------------------------------------------------------- + + +def CourseRecommender( + major: str, + completed_courses: list[str], +) -> str: + """ + Generate a structured next-semester course recommendation for a DKU student. + + Given the student's major and the courses they have already completed, + this tool: + 1. Looks up the graduation requirements for the student's major. + 2. Looks up the university-wide common-core requirements. + 3. Identifies which required courses still need to be completed. + 4. Checks which remaining courses are offered next semester. + 5. Checks whether the student meets prerequisites for each available course. + 6. Returns a grouped report: recommended, eligible-but-not-offered, + prerequisites-not-met, and no-schedule-data categories. + + Args: + major (str): The student's major and optional track, e.g. "data science" + or "computation and design computer science". + completed_courses (list[str]): Courses the student has already completed + or is currently taking, e.g. ["COMPSCI 101", "MATH 105", "STATS 201"]. + + Returns: + A Markdown-formatted recommendation report. + """ + req_dir = Path(config.major_req_dir) + classdata_csv_path = Path(config.classdata_csv_path) + prereq_csv_path = Path(config.prereq_csv_path) + with span_ctx_start("CourseRecommender", OpenInferenceSpanKindValues.TOOL) as span: + span.set_attributes( + { + SpanAttributes.INPUT_VALUE: safe_json_dumps( + { + "major": major, + "completed_courses": completed_courses, + } + ), + SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value, + } + ) + + try: + result = _run_recommendation( + major=major, + completed_courses=completed_courses, + req_dir=req_dir, + classdata_csv_path=classdata_csv_path, + prereq_csv_path=prereq_csv_path, + ) + span.set_attributes( + { + SpanAttributes.OUTPUT_VALUE: result[:500], + SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.TEXT.value, + } + ) + span.set_status(Status(StatusCode.OK)) + return result + + except Exception as e: + span.set_attributes( + { + SpanAttributes.OUTPUT_VALUE: safe_json_dumps({"error": str(e)}), + SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value, + } + ) + span.set_status(Status(StatusCode.ERROR), str(e)) + raise e + + +# --------------------------------------------------------------------------- +# Core recommendation logic +# --------------------------------------------------------------------------- + + +def _run_recommendation( + major: str, + completed_courses: list[str], + req_dir: Path, + classdata_csv_path: Path, + prereq_csv_path: Path, +) -> str: + if not req_dir.is_dir(): + raise FileNotFoundError(f"Requirements directory not found: {req_dir}") + + stems = _list_stems(req_dir) + + # --- 1. Load major requirements --- + matched_major = _best_match(major, stems) + if matched_major is None: + return ( + f"No matching major found for '{major}'. " + "Please check the major name and try again." + ) + major_md = (req_dir / f"{matched_major}.md").read_text(encoding="utf-8") + major_courses = parse_course_codes(major_md) + + # --- 2. Load common-core requirements --- + common_core_md = "" + common_core_stem = _best_match("requirements for all majors", stems) + common_core_courses: list[str] = [] + if common_core_stem: + common_core_md = (req_dir / f"{common_core_stem}.md").read_text( + encoding="utf-8" + ) + common_core_courses = parse_course_codes(common_core_md) + + # --- 3. Compute remaining required courses --- + completed_set = {c.strip().upper() for c in completed_courses} + # Normalize completed_courses to "SUBJECT NNN" format for comparison. + # Users might type "CS 101" or "compsci 101" — handle by uppercasing. + all_required = list( + dict.fromkeys(major_courses + common_core_courses) + ) # preserve order, deduplicate + remaining = [c for c in all_required if c.upper() not in completed_set] + + if not remaining: + return ( + f"## Course Recommendation for {matched_major}\n\n" + "You have completed all required courses for this major. " + "Consider taking electives or checking with your advisor about graduation requirements." + ) + + # --- 4. Check schedule availability --- + offered = _get_offered_courses(remaining, classdata_csv_path) + + # --- 5. Check prerequisites for offered courses --- + try: + prereq_df = _load_prereq_df(prereq_csv_path) + prereq_available = True + except Exception: + prereq_df = None + prereq_available = False + + eligible_and_offered: list[tuple[str, str]] = [] # (course, schedule_summary) + eligible_not_offered: list[str] = [] + not_eligible: list[tuple[str, str]] = [] # (course, reason) + no_schedule_data: list[str] = [] + + for course in remaining: + if course in offered: + if prereq_available: + met, reason = prerequisites_met(course, completed_set, prereq_df) + if met: + schedule_summary = _format_schedule_rows(offered[course]) + eligible_and_offered.append((course, schedule_summary)) + else: + not_eligible.append((course, reason)) + else: + schedule_summary = _format_schedule_rows(offered[course]) + eligible_and_offered.append((course, schedule_summary)) + else: + if prereq_available: + met, reason = prerequisites_met(course, completed_set, prereq_df) + if met: + eligible_not_offered.append(course) + else: + not_eligible.append((course, reason)) + else: + # No schedule and no prereq data — just report as not offered. + no_schedule_data.append(course) + + # --- 6. Build report --- + lines = [f"## Course Recommendation for {matched_major}\n"] + lines.append(f"**Matched requirements file:** `{matched_major}.md`") + lines.append(f"**Total required courses:** {len(all_required)}") + lines.append(f"**Completed:** {len(all_required) - len(remaining)}") + lines.append(f"**Remaining:** {len(remaining)}\n") + + lines.append("### Recommended — eligible and offered next semester") + if eligible_and_offered: + for course, schedule in eligible_and_offered: + lines.append(f"- **{course}** — {schedule}") + else: + lines.append("- *(none)*") + lines.append("") + + lines.append("### Eligible but not offered next semester") + if eligible_not_offered: + for course in eligible_not_offered: + lines.append(f"- {course}") + else: + lines.append("- *(none)*") + lines.append("") + + lines.append("### Not eligible — prerequisites not yet met") + if not_eligible: + for course, reason in not_eligible: + lines.append(f"- **{course}**: {reason}") + else: + lines.append("- *(none)*") + lines.append("") + + if no_schedule_data: + lines.append("### Required courses with no schedule data") + for course in no_schedule_data: + lines.append(f"- {course}") + lines.append("") + + lines.append("---") + lines.append( + "*Note: Prerequisites are checked using DKUHub data. " + "Complex or conditional prerequisites (e.g. instructor consent, GPA requirements) " + "may not be captured — always confirm with your academic advisor.*" + ) + + return "\n".join(lines) + + +# --------------------------------------------------------------------------- +# CLI smoke-test +# --------------------------------------------------------------------------- + +if __name__ == "__main__": + import argparse + + from chatdku.setup import use_phoenix + + use_phoenix() + + parser = argparse.ArgumentParser(description="Test CourseRecommender") + parser.add_argument("--major", required=True, help="Student's major") + parser.add_argument( + "--completed", + nargs="*", + default=[], + help="Completed course codes, e.g. COMPSCI 101 MATH 105", + ) + args = parser.parse_args() + + __import__("pprint").pprint( + CourseRecommender(major=args.major, completed_courses=args.completed) + ) diff --git a/chatdku/core/tools/course_schedule.py b/chatdku/core/tools/course_schedule.py index b4566d62..7abf2f9b 100644 --- a/chatdku/core/tools/course_schedule.py +++ b/chatdku/core/tools/course_schedule.py @@ -17,8 +17,8 @@ ) from opentelemetry.trace import Status, StatusCode -from chatdku.setup import use_phoenix from chatdku.core.utils import span_ctx_start +from chatdku.config import config # --------------------------------------------------------------------------- @@ -67,104 +67,89 @@ def _lookup(course_raw: str, df: pd.DataFrame) -> list[dict]: # --------------------------------------------------------------------------- -def CourseScheduleLookupOuter(classdata_csv_path: str): +def CourseScheduleLookup(course_names: list[str]) -> str: """ - DSPy tool factory for looking up next-semester course schedule data. + Look up the course schedule for one or more courses at Duke Kunshan University. - Args: - classdata_csv_path: Path to the cleaned class-data CSV - (produced by scripts/clean_classdata.py). - """ - - def CourseScheduleLookup(course_names: list[str]) -> str: - """ - Look up the course schedule for one or more courses at Duke Kunshan University. + Given a list of course codes (e.g. ["COMPSCI 101", "CHINESE 101A"]), + returns the schedule information (sections, times, instructors, enrollment, + etc.) for each course from the upcoming semester's class data. - Given a list of course codes (e.g. ["COMPSCI 101", "CHINESE 101A"]), - returns the schedule information (sections, times, instructors, enrollment, - etc.) for each course from the upcoming semester's class data. + Handles formatting variations such as "COMPSCI101" or "COMPSCI-101". - Handles formatting variations such as "COMPSCI101" or "COMPSCI-101". - - Args: - course_names (list[str]): Courses to look up, e.g. ["STATS 202", "BIOL 305"]. + Args: + course_names (list[str]): Courses to look up, e.g. ["STATS 202", "BIOL 305"]. - Returns: - JSON string with schedule rows for every matched course, - or an informative message when a course is not found. - """ - with span_ctx_start( - "CourseScheduleLookup", OpenInferenceSpanKindValues.TOOL - ) as span: + Returns: + JSON string with schedule rows for every matched course, + or an informative message when a course is not found. + """ + classdata_csv_path = config.classdata_csv_path + with span_ctx_start( + "CourseScheduleLookup", OpenInferenceSpanKindValues.TOOL + ) as span: + span.set_attributes( + { + SpanAttributes.INPUT_VALUE: safe_json_dumps( + dict(course_names=course_names) + ), + SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value, + } + ) + + try: + df = pd.read_csv(classdata_csv_path) + except FileNotFoundError: + msg = f"Course schedule data file not found: {classdata_csv_path}" span.set_attributes( { - SpanAttributes.INPUT_VALUE: safe_json_dumps( - dict(course_names=course_names) - ), - SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value, + SpanAttributes.OUTPUT_VALUE: safe_json_dumps(dict(error=msg)), + SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value, } ) + span.set_status(Status(StatusCode.ERROR), msg) + raise FileNotFoundError(msg) + + try: + results: dict[str, list[dict] | str] = {} + for course in course_names: + rows = _lookup(course, df) + if rows: + results[course] = rows + else: + results[course] = f"No schedule found for '{course}'." + + output = json.dumps(results, default=str) + span.set_attributes( + { + SpanAttributes.OUTPUT_VALUE: output, + SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value, + } + ) + span.set_status(Status(StatusCode.OK)) + return output - try: - df = pd.read_csv(classdata_csv_path) - except FileNotFoundError: - msg = f"Course schedule data file not found: {classdata_csv_path}" - span.set_attributes( - { - SpanAttributes.OUTPUT_VALUE: safe_json_dumps(dict(error=msg)), - SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value, - } - ) - span.set_status(Status(StatusCode.ERROR), msg) - raise FileNotFoundError(msg) - - try: - results: dict[str, list[dict] | str] = {} - for course in course_names: - rows = _lookup(course, df) - if rows: - results[course] = rows - else: - results[course] = f"No schedule found for '{course}'." - - output = json.dumps(results, default=str) - span.set_attributes( - { - SpanAttributes.OUTPUT_VALUE: output, - SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value, - } - ) - span.set_status(Status(StatusCode.OK)) - return output - - except Exception as e: - span.set_attributes( - { - SpanAttributes.OUTPUT_VALUE: safe_json_dumps( - dict(error=str(e)) - ), - SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value, - } - ) - span.set_status(Status(StatusCode.ERROR), str(e)) - raise - - return CourseScheduleLookup + except Exception as e: + span.set_attributes( + { + SpanAttributes.OUTPUT_VALUE: safe_json_dumps(dict(error=str(e))), + SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value, + } + ) + span.set_status(Status(StatusCode.ERROR), str(e)) + raise # --------------------------------------------------------------------------- # CLI smoke-test # --------------------------------------------------------------------------- if __name__ == "__main__": + from chatdku.setup import use_phoenix + use_phoenix() import argparse parser = argparse.ArgumentParser(description="Test CourseScheduleLookup") - parser.add_argument( - "--csv", - default="/datapool/chatdku_external_data/cleaned_classdata.csv", - help="Path to the cleaned class-data CSV", - ) parser.add_argument( "courses", nargs="+", @@ -172,5 +157,4 @@ def CourseScheduleLookup(course_names: list[str]) -> str: ) args = parser.parse_args() - lookup = CourseScheduleLookupOuter(args.csv) - __import__("pprint").pprint((lookup(args.courses))) + __import__("pprint").pprint(CourseScheduleLookup(args.courses)) diff --git a/chatdku/core/tools/get_prerequisites.py b/chatdku/core/tools/get_prerequisites.py index b3ccb410..efbfa260 100644 --- a/chatdku/core/tools/get_prerequisites.py +++ b/chatdku/core/tools/get_prerequisites.py @@ -11,6 +11,7 @@ from opentelemetry.trace import Status, StatusCode from chatdku.core.utils import span_ctx_start +from chatdku.config import config logger = logging.getLogger(__name__) @@ -49,81 +50,55 @@ def get_prereq(course: str, data_file_path: str) -> str: return f"Unknown error in finding prerequisite for {course}." -def PrerequisiteLookupOuter(prereq_csv_path: str): +def PrerequisiteLookup(course_names: list[str]) -> str: """ - DSPy tool factory for looking up course prerequisites. - Returns a Phoenix-observable callable for the agent's tool list. + Look up the prerequisites for one or more courses at Duke Kunshan University. - Args: - prereq_csv_path: Path to the prerequisites CSV file. - """ + Given a list of course names (e.g. ["STATS 202", "COMPSCI 201", "MATH 206"]), + returns the prerequisite and anti-requisite requirements for each course. + Uses the latest available version of the course requirements. - def PrerequisiteLookup(course_names: list[str]) -> str: - """ - Look up the prerequisites for one or more courses at Duke Kunshan University. + Good tool for answering questions about what courses are needed + before taking specific courses. - Given a list of course names (e.g. ["STATS 202", "COMPSCI 201", "MATH 206"]), - returns the prerequisite and anti-requisite requirements for each course. - Uses the latest available version of the course requirements. - - Good tool for answering questions about what courses are needed - before taking specific courses. + Args: + course_names (list[str]): The courses to look up, e.g. ["STATS 202", "COMPSCI 101"]. - Args: - course_names (list[str]): The courses to look up, e.g. ["STATS 202", "COMPSCI 101"]. + Returns: + String describing the prerequisites for each course, separated by newlines. + """ + prereq_csv_path = config.prereq_csv_path + with span_ctx_start("PrerequisiteLookup", OpenInferenceSpanKindValues.TOOL) as span: + span.set_attributes( + { + SpanAttributes.INPUT_VALUE: safe_json_dumps( + dict(course_names=course_names) + ), + SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value, + } + ) - Returns: - String describing the prerequisites for each course, separated by newlines. - """ - with span_ctx_start( - "PrerequisiteLookup", OpenInferenceSpanKindValues.TOOL - ) as span: + try: + results = [get_prereq(course, prereq_csv_path) for course in course_names] + result = "\n".join(results) span.set_attributes( { - SpanAttributes.INPUT_VALUE: safe_json_dumps( - dict(course_names=course_names) - ), - SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value, + SpanAttributes.OUTPUT_VALUE: safe_json_dumps(dict(result=result)), + SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value, } ) - - try: - results = [ - get_prereq(course, prereq_csv_path) for course in course_names - ] - result = "\n".join(results) - span.set_attributes( - { - SpanAttributes.OUTPUT_VALUE: safe_json_dumps( - dict(result=result) - ), - SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value, - } - ) - span.set_status(Status(StatusCode.OK)) - return result - except Exception as e: - span.set_attributes( - { - SpanAttributes.OUTPUT_VALUE: safe_json_dumps( - dict(error=str(e)) - ), - SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value, - } - ) - span.set_status(Status(StatusCode.ERROR), str(e)) - raise e - - return PrerequisiteLookup + span.set_status(Status(StatusCode.OK)) + return result + except Exception as e: + span.set_attributes( + { + SpanAttributes.OUTPUT_VALUE: safe_json_dumps(dict(error=str(e))), + SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value, + } + ) + span.set_status(Status(StatusCode.ERROR), str(e)) + raise e if __name__ == "__main__": - import os - - local_csv = os.path.join( - os.path.dirname(__file__), - "chatdku_external_data", - "DK_SR_PREREQ_CRSE_CHATDKU.csv", - ) - lookup = PrerequisiteLookupOuter(local_csv) - print(lookup(["STATS 202", "COMPSCI 201"])) + print(PrerequisiteLookup(["STATS 202", "COMPSCI 201"])) diff --git a/chatdku/core/tools/major_requirements.py b/chatdku/core/tools/major_requirements.py index a08eab16..1270a89d 100644 --- a/chatdku/core/tools/major_requirements.py +++ b/chatdku/core/tools/major_requirements.py @@ -31,8 +31,10 @@ SpanAttributes, ) from opentelemetry.trace import Status, StatusCode +from thefuzz import fuzz, process from chatdku.core.utils import span_ctx_start +from chatdku.config import config logger = logging.getLogger(__name__) @@ -42,42 +44,49 @@ # --------------------------------------------------------------------------- -def _tokenize(s: str) -> set[str]: - """Lowercase, drop punctuation/separators, return word-token set.""" - s = s.lower() - s = re.sub(r"[/\\&,\-]", " ", s) - s = re.sub(r"[^a-z0-9 ]", "", s) - return set(s.split()) +def _build_stem_dict(stems: list[str]) -> dict[str, str]: + """Build a dictionary of stems to their normalized versions. + For example: {"data-science": "data science"} + """ + + def _replace_hyphens(stem: str) -> str: + return [stem.replace("-", " ")] + + stem_dict = {} + for stem in stems: + stem_dict[stem] = _replace_hyphens(stem) + return stem_dict + +def _clean_query(query: str) -> str: + query = query.lower() + query = re.sub(r"[/\\&,\-]", " ", query) + query = re.sub(r"[^a-z0-9 ]", "", query) + return query -def _jaccard(a: set[str], b: set[str]) -> float: - union = a | b - if not union: - return 0.0 - return len(a & b) / len(union) + +_MIN_MATCH_SCORE = 40 # below this, treat as no match def _best_match(query: str, stems: list[str]) -> str | None: """ - Return the filename stem that best matches *query* by Jaccard similarity - on word tokens. Returns None when no candidate shares any token with - the query. + Return the filename stem that best matches *query* by token-set ratio. + Returns None when the best score is below _MIN_MATCH_SCORE. """ - q_tokens = _tokenize(query) - if not q_tokens: - return None - - best_stem: str | None = None - best_score = 0.0 - - for stem in stems: - c_tokens = _tokenize(stem) - score = _jaccard(q_tokens, c_tokens) - if score > best_score: - best_score = score - best_stem = stem + stems_dict = _build_stem_dict(stems) + query = _clean_query(query) + + matches = process.extract( + query, + stems_dict, + scorer=fuzz.token_set_ratio, + limit=1, + ) - return best_stem if best_score > 0.0 else None + if not matches: + return None + score, key = matches[0][1], matches[0][2] + return key if score >= _MIN_MATCH_SCORE else None def _list_stems(requirements_dir: Path) -> list[str]: @@ -89,120 +98,105 @@ def _list_stems(requirements_dir: Path) -> list[str]: # --------------------------------------------------------------------------- -def MajorRequirementsLookupOuter(requirements_dir: str): +def MajorRequirementsLookup(major: str) -> str: """ - DSPy tool factory for looking up DKU major/track degree requirements. + Look up the graduation requirements for a Duke Kunshan University major. - Args: - requirements_dir: Path to the directory containing per-major - Markdown files from the UG Bulletin. - """ - req_dir = Path(requirements_dir) - - def MajorRequirementsLookup(major: str) -> str: - """ - Look up the graduation requirements for a Duke Kunshan University major. - - Given a major (and optional track) name, returns the full list of - required and elective courses from the UG Bulletin (2023-2024). - - Pass major="list" to get the names of all available majors/tracks. - Pass major="requirements for all majors" to retrieve the university-wide - core requirements that every student must complete regardless of major. - - Examples of valid major strings: - "data science" - "computation and design / computer science" - "behavioral science psychology" - "global health biology" - "requirements for all majors" - "list" - - Args: - major (str): Major (and optionally track) name, e.g. "data science". - Pass "list" to enumerate available majors. - - Returns: - Markdown text of the major's requirements, or an error message when - no match is found. - """ - with span_ctx_start( - "MajorRequirementsLookup", OpenInferenceSpanKindValues.TOOL - ) as span: - span.set_attributes( - { - SpanAttributes.INPUT_VALUE: safe_json_dumps(dict(major=major)), - SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value, - } - ) + Given a major (and optional track) name, returns the full list of + required and elective courses from the UG Bulletin (2023-2024). - try: - if not req_dir.is_dir(): - raise FileNotFoundError( - f"Requirements directory not found: {req_dir}" - ) - - stems = _list_stems(req_dir) - if not stems: - raise FileNotFoundError(f"No requirement files found in {req_dir}") - - # Special: list all available majors - if major.strip().lower() == "list": - result = "Available DKU majors/tracks:\n" + "\n".join( - f" - {s}" for s in stems - ) - span.set_attributes( - { - SpanAttributes.OUTPUT_VALUE: result, - SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.TEXT.value, - } - ) - span.set_status(Status(StatusCode.OK)) - return result - - matched = _best_match(major, stems) - if matched is None: - result = ( - f"No matching major found for '{major}'. " - "Call with major='list' to see all available majors." - ) - span.set_attributes( - { - SpanAttributes.OUTPUT_VALUE: result, - SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.TEXT.value, - } - ) - span.set_status(Status(StatusCode.OK)) - return result - - md_path = req_dir / f"{matched}.md" - content = md_path.read_text(encoding="utf-8") - result = f"# Requirements: {matched}\n\n{content}" + Pass major="list" to get the names of all available majors/tracks. + Pass major="requirements for all majors" to retrieve the university-wide + core requirements that every student must complete regardless of major. + Examples of valid major strings: + "data science" + "computation and design / computer science" + "behavioral science psychology" + "global health biology" + "requirements for all majors" + "list" + + Args: + major (str): Major (and optionally track) name, e.g. "data science". + Pass "list" to enumerate available majors. + + Returns: + Markdown text of the major's requirements, or an error message when + no match is found. + """ + req_dir = Path(config.major_req_dir) + with span_ctx_start( + "MajorRequirementsLookup", OpenInferenceSpanKindValues.TOOL + ) as span: + span.set_attributes( + { + SpanAttributes.INPUT_VALUE: safe_json_dumps(dict(major=major)), + SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value, + } + ) + + try: + if not req_dir.is_dir(): + raise FileNotFoundError(f"Requirements directory not found: {req_dir}") + + stems = _list_stems(req_dir) + if not stems: + raise FileNotFoundError(f"No requirement files found in {req_dir}") + + # Special: list all available majors + if major.strip().lower() == "list": + result = "Available DKU majors/tracks:\n" + "\n".join( + f" - {s}" for s in stems + ) span.set_attributes( { - SpanAttributes.OUTPUT_VALUE: safe_json_dumps( - dict(matched_file=matched, char_count=len(result)) - ), - SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value, + SpanAttributes.OUTPUT_VALUE: result, + SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.TEXT.value, } ) span.set_status(Status(StatusCode.OK)) return result - except Exception as e: + matched = _best_match(major, stems) + if matched is None: + result = ( + f"No matching major found for '{major}'. " + "Call with major='list' to see all available majors." + ) span.set_attributes( { - SpanAttributes.OUTPUT_VALUE: safe_json_dumps( - dict(error=str(e)) - ), - SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value, + SpanAttributes.OUTPUT_VALUE: result, + SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.TEXT.value, } ) - span.set_status(Status(StatusCode.ERROR), str(e)) - raise + span.set_status(Status(StatusCode.OK)) + return result - return MajorRequirementsLookup + md_path = req_dir / f"{matched}.md" + content = md_path.read_text(encoding="utf-8") + result = f"# Requirements: {matched}\n\n{content}" + + span.set_attributes( + { + SpanAttributes.OUTPUT_VALUE: safe_json_dumps( + dict(matched_file=matched, char_count=len(result)) + ), + SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value, + } + ) + span.set_status(Status(StatusCode.OK)) + return result + + except Exception as e: + span.set_attributes( + { + SpanAttributes.OUTPUT_VALUE: safe_json_dumps(dict(error=str(e))), + SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value, + } + ) + span.set_status(Status(StatusCode.ERROR), str(e)) + raise e # --------------------------------------------------------------------------- @@ -211,14 +205,13 @@ def MajorRequirementsLookup(major: str) -> str: if __name__ == "__main__": import argparse + from chatdku.setup import use_phoenix + + use_phoenix() + parser = argparse.ArgumentParser(description="Test MajorRequirementsLookup") - parser.add_argument( - "--dir", - default="/datapool/chatdku_external_data/doc_testing/output/ug_bulletin_2023-2024", - help="Path to the requirements markdown directory", - ) parser.add_argument("--major", required=True, help="Major name to look up") args = parser.parse_args() - lookup = MajorRequirementsLookupOuter(args.dir) - print(lookup(args.major)) + lookup = MajorRequirementsLookup(args.major) + __import__("pprint").pprint(lookup) diff --git a/chatdku/core/tools/retriever/keyword_retriever.py b/chatdku/core/tools/retriever/keyword_retriever.py index f24a3c5e..4bc667be 100644 --- a/chatdku/core/tools/retriever/keyword_retriever.py +++ b/chatdku/core/tools/retriever/keyword_retriever.py @@ -4,12 +4,8 @@ import sys from itertools import combinations -import nltk -from nltk.corpus import stopwords -from nltk.tokenize import word_tokenize from redis import Redis from redis.commands.search.query import Query -from redisvl.schema import IndexSchema from chatdku.config import config from chatdku.core.tools.retriever.base_retriever import BaseDocRetriever, NodeWithScore @@ -17,6 +13,8 @@ def _ensure_nltk_resource(resource_path: str, download_name: str) -> None: + import nltk + try: nltk.data.find(resource_path) except LookupError: @@ -34,8 +32,16 @@ def _ensure_nltk_resource(resource_path: str, download_name: str) -> None: ) -_ensure_nltk_resource("corpora/stopwords", "stopwords") -_ensure_nltk_resource("tokenizers/punkt_tab", "punkt_tab") +_nltk_ready = False + + +def _ensure_nltk_resources() -> None: + global _nltk_ready + if _nltk_ready: + return + _ensure_nltk_resource("corpora/stopwords", "stopwords") + _ensure_nltk_resource("tokenizers/punkt_tab", "punkt_tab") + _nltk_ready = True class KeywordRetriever(BaseDocRetriever): @@ -52,12 +58,23 @@ def __init__( search_mode, files, ) + # Load NLTK resources once at construction time so the first query + # doesn't pay the cost. Subsequent calls are O(1) via sys.modules. + _ensure_nltk_resources() + from nltk.corpus import stopwords as _sw + from nltk.tokenize import word_tokenize as _wt + + self._stopwords = _sw + self._word_tokenize = _wt def query(self, query: str | list[str]) -> list[NodeWithScore]: """ Retrieve texts from the database that contain the same keywords in the query. """ + stopwords = self._stopwords + word_tokenize = self._word_tokenize + client = Redis( host=config.redis_host, port=config.redis_port, diff --git a/chatdku/core/tools/syllabi_tool/classes_schema.json b/chatdku/core/tools/syllabi/classes_schema.json similarity index 100% rename from chatdku/core/tools/syllabi_tool/classes_schema.json rename to chatdku/core/tools/syllabi/classes_schema.json diff --git a/chatdku/core/tools/syllabi_tool/create_chatdku_readonly_user.sql b/chatdku/core/tools/syllabi/create_chatdku_readonly_user.sql similarity index 100% rename from chatdku/core/tools/syllabi_tool/create_chatdku_readonly_user.sql rename to chatdku/core/tools/syllabi/create_chatdku_readonly_user.sql diff --git a/chatdku/core/tools/syllabi_tool/create_table.sql b/chatdku/core/tools/syllabi/create_table.sql similarity index 100% rename from chatdku/core/tools/syllabi_tool/create_table.sql rename to chatdku/core/tools/syllabi/create_table.sql diff --git a/chatdku/core/tools/syllabi_tool/curriculum_schema.json b/chatdku/core/tools/syllabi/curriculum_schema.json similarity index 100% rename from chatdku/core/tools/syllabi_tool/curriculum_schema.json rename to chatdku/core/tools/syllabi/curriculum_schema.json diff --git a/chatdku/core/tools/syllabi_tool/example.py b/chatdku/core/tools/syllabi/example.py similarity index 100% rename from chatdku/core/tools/syllabi_tool/example.py rename to chatdku/core/tools/syllabi/example.py diff --git a/chatdku/core/tools/syllabi_tool/generate_sql.py b/chatdku/core/tools/syllabi/generate_sql.py similarity index 97% rename from chatdku/core/tools/syllabi_tool/generate_sql.py rename to chatdku/core/tools/syllabi/generate_sql.py index 5e093b1e..e6479b39 100644 --- a/chatdku/core/tools/syllabi_tool/generate_sql.py +++ b/chatdku/core/tools/syllabi/generate_sql.py @@ -120,7 +120,7 @@ def _collapse_repeated_lines(text: str, max_consecutive: int = 4) -> str: if prev is not None: if count > max_consecutive: out_lines.append(prev) - out_lines.append(f"...({count+1} repeated lines collapsed)...") + out_lines.append(f"...({count + 1} repeated lines collapsed)...") else: out_lines.extend([prev] * (count + 1)) prev = line @@ -130,7 +130,7 @@ def _collapse_repeated_lines(text: str, max_consecutive: int = 4) -> str: if prev is not None: if count > max_consecutive: out_lines.append(prev) - out_lines.append(f"...({count+1} repeated lines collapsed)...") + out_lines.append(f"...({count + 1} repeated lines collapsed)...") else: out_lines.extend([prev] * (count + 1)) diff --git a/chatdku/core/tools/syllabi_tool/get_schema.py b/chatdku/core/tools/syllabi/get_schema.py similarity index 100% rename from chatdku/core/tools/syllabi_tool/get_schema.py rename to chatdku/core/tools/syllabi/get_schema.py diff --git a/chatdku/core/tools/syllabi_tool/reqs_syllabi_agent b/chatdku/core/tools/syllabi/reqs_syllabi_agent similarity index 100% rename from chatdku/core/tools/syllabi_tool/reqs_syllabi_agent rename to chatdku/core/tools/syllabi/reqs_syllabi_agent diff --git a/chatdku/core/tools/syllabi_tool/run_local_ingest.sh b/chatdku/core/tools/syllabi/run_local_ingest.sh similarity index 100% rename from chatdku/core/tools/syllabi_tool/run_local_ingest.sh rename to chatdku/core/tools/syllabi/run_local_ingest.sh diff --git a/chatdku/core/tools/syllabi_tool/query_curriculum_db.py b/chatdku/core/tools/syllabi/syllabi_tool.py similarity index 89% rename from chatdku/core/tools/syllabi_tool/query_curriculum_db.py rename to chatdku/core/tools/syllabi/syllabi_tool.py index 0f15c52c..06294aad 100644 --- a/chatdku/core/tools/syllabi_tool/query_curriculum_db.py +++ b/chatdku/core/tools/syllabi/syllabi_tool.py @@ -8,26 +8,25 @@ from opentelemetry.trace import Status, StatusCode from chatdku.core.tools.retriever.base_retriever import NodeWithScore, nodes_to_OTLP -from chatdku.core.tools.syllabi_tool.generate_sql import GenerateSQL +from chatdku.core.tools.syllabi.generate_sql import GenerateSQL from chatdku.core.utils import span_ctx_start from chatdku.setup import DB table_name = "curriculum" -def QueryCurriculumOuter(N=3): +def SyllabusLookupOuter(N=3): db = DB() sql_agent = GenerateSQL() db_schema = fetch_schema(db=db) - def QueryCurriculum(query: str, current_user_message: str) -> tuple[str, dict]: + def SyllabusLookup(query: str, current_user_message: str) -> tuple[str, dict]: """ - Takes a natural language query about courses and classes offered - at Duke Kunshan University -> generates intermediate SQL query + Takes a natural language query about course syllabus -> generates intermediate SQL query passed into Postgres which has courses' syllabi -> Result formatted in natural language. It can answer what a specific course covers, what kind of assignments - are given, a course's grading policy, and a course's history (when it was offered). + are given, and a course's grading policy. Good tool for syllabus questions. @@ -95,7 +94,7 @@ def QueryCurriculum(query: str, current_user_message: str) -> tuple[str, dict]: return answer, internal_result - return QueryCurriculum + return SyllabusLookup def fetch_schema(db: DB) -> str: diff --git a/chatdku/core/tools/syllabi_tool/update_db.py b/chatdku/core/tools/syllabi/update_db.py similarity index 100% rename from chatdku/core/tools/syllabi_tool/update_db.py rename to chatdku/core/tools/syllabi/update_db.py diff --git a/chatdku/core/tui.py b/chatdku/core/tui.py new file mode 100644 index 00000000..143280e6 --- /dev/null +++ b/chatdku/core/tui.py @@ -0,0 +1,190 @@ +#!/usr/bin/env python3 +"""Terminal UI for the ChatDKU agent. + +Run: + python -m chatdku.core.tui +""" + +from __future__ import annotations + +import asyncio +from collections import deque +from pathlib import Path + +import pyfiglet +from rich.table import Table +from rich.text import Text +from textual.app import App, ComposeResult +from textual.binding import Binding +from textual.containers import VerticalScroll +from textual.widgets import Footer, Header, Input, Static + +from chatdku.core.agent import build_agent + +_LOGO_PATH = Path(__file__).parent / "ascii_logo" + + +def _build_splash() -> Table: + """Build the startup splash: ANSI logo beside a figlet 'ChatDKU' title.""" + try: + logo_text = Text.from_ansi(_LOGO_PATH.read_text()) + except OSError: + logo_text = Text("") + + figlet_str = "\n" * 4 + pyfiglet.figlet_format("ChatDKU", font="slant") + title_text = Text(figlet_str, style="bold #4aa7ff") + + grid = Table.grid(padding=(0, 2)) + grid.add_column(no_wrap=True) + grid.add_column(no_wrap=True) + grid.add_row(logo_text, title_text) + return grid + + +class Message(Static): + """A single chat bubble with a rounded, color-coded border.""" + + DEFAULT_CSS = """ + Message { + border: round #3a3f4b; + background: transparent; + padding: 0 1; + margin: 0 2; + width: auto; + max-width: 90%; + height: auto; + } + Message.user { border: round #7fe684; color: #d6f5d6; } + Message.agent { border: round #4aa7ff; color: #d6e6ff; } + Message.system { border: round #5c616d; color: #9aa0ab; } + Message.pending { border: round #4a4f5a; color: #7c8290; } + """ + + def __init__(self, role: str, content: str) -> None: + super().__init__(self._format(role, content), markup=False) + self.role = role + self.add_class(role) + + @staticmethod + def _format(role: str, content: str) -> str: + label = { + "user": "You", + "agent": "ChatDKU", + "system": "System", + "pending": "ChatDKU", + }.get(role, role) + return f"[{label}]\n{content}" + + def update_content(self, role: str, content: str) -> None: + self.update(self._format(role, content)) + + +class ChatDKUApp(App): + ENABLE_COMMAND_PALETTE = True + COLOR_SYSTEM = "truecolor" + + CSS = """ + Screen { layout: vertical; background: transparent; } + Header { background: #1a1d23; color: #c7cbd4; } + Footer { background: #1a1d23; color: #8a909c; } + #log { height: 1fr; background: transparent; } + #input { + dock: bottom; + margin: 0 1 1 1; + border: round #3a3f4b; + background: transparent; + color: #d6dae2; + } + #input:focus { border: round #7ab7ff; } + """ + + BINDINGS = [ + Binding("ctrl+c", "quit", "Quit", priority=True), + Binding("ctrl+l", "clear", "Clear"), + ] + + def __init__(self) -> None: + super().__init__() + self.agent = None # built lazily in a worker + self.queue: deque[str] = deque() + self.busy = False + + def compose(self) -> ComposeResult: + yield Header(show_clock=False) + yield VerticalScroll(id="log") + yield Input( + placeholder="Ask about DKU… (Enter to send, Ctrl+C to quit)", id="input" + ) + yield Footer() + + async def on_mount(self) -> None: + self.title = "ChatDKU" + self.sub_title = "TUI" + await self._post("system", "Booting agent… (this may take a few seconds)") + self.query_one("#input", Input).focus() + self.run_worker(self._boot, thread=True, exclusive=True, group="boot") + + def _boot(self) -> None: + self.agent = build_agent(streaming=False) + self.call_from_thread(self._boot_done) + + async def _boot_done(self) -> None: + log = self.query_one("#log", VerticalScroll) + await log.mount(Static(_build_splash())) + log.scroll_end(animate=False) + await self._post("system", "Ready.") + + async def _post(self, role: str, content: str) -> Message: + msg = Message(role, content) + log = self.query_one("#log", VerticalScroll) + await log.mount(msg) + log.scroll_end(animate=False) + return msg + + async def on_input_submitted(self, event: Input.Submitted) -> None: + text = event.value.strip() + if not text: + return + event.input.value = "" + await self._post("user", text) + self.queue.append(text) + if not self.busy: + await self._drain() + + async def _drain(self) -> None: + while self.queue: + query = self.queue.popleft() + self.busy = True + pending = await self._post("pending", "thinking…") + # Wait for boot before answering. + while self.agent is None: + await asyncio.sleep(0.1) + loop = asyncio.get_running_loop() + try: + answer = await loop.run_in_executor(None, self._run_agent, query) + except Exception as e: + answer = f"[error] {e}" + pending.update_content("agent", answer) + pending.remove_class("pending") + pending.add_class("agent") + self.query_one("#log", VerticalScroll).scroll_end(animate=False) + self.busy = False + + def _run_agent(self, query: str) -> str: + result = self.agent(current_user_message=query) + response = result.response + if isinstance(response, str): + return response + return "".join(response) + + async def action_clear(self) -> None: + log = self.query_one("#log", VerticalScroll) + await log.remove_children() + + +def main() -> None: + ChatDKUApp().run() + + +if __name__ == "__main__": + main() diff --git a/chatdku/django/chatdku_django/chat/tools.py b/chatdku/django/chatdku_django/chat/tools.py index 98a5bd0a..1f757788 100644 --- a/chatdku/django/chatdku_django/chat/tools.py +++ b/chatdku/django/chatdku_django/chat/tools.py @@ -1,5 +1,8 @@ -from chatdku.core.tools.llama_index_tools import KeywordRetrieverOuter, VectorRetrieverOuter -from chatdku.core.tools.syllabi_tool.query_curriculum_db import QueryCurriculumOuter +from chatdku.core.tools.llama_index_tools import ( + KeywordRetrieverOuter, + VectorRetrieverOuter, +) +from chatdku.core.tools.syllabi.syllabi_tool import SyllabusLookupOuter def get_tools(user_id: str, search_mode, docs): @@ -21,7 +24,7 @@ def get_tools(user_id: str, search_mode, docs): search_mode=search_mode, files=docs, ), - QueryCurriculumOuter(), + SyllabusLookupOuter(), ] return base_tools diff --git a/chatdku/django/chatdku_django/chat/views.py b/chatdku/django/chatdku_django/chat/views.py index aed58204..8a696830 100644 --- a/chatdku/django/chatdku_django/chat/views.py +++ b/chatdku/django/chatdku_django/chat/views.py @@ -30,12 +30,8 @@ from rest_framework.views import APIView from chatdku.core.agent import Agent -from chatdku.core.tools.llama_index_tools import KeywordRetrieverOuter, VectorRetrieverOuter -from chatdku.core.tools.syllabi_tool.query_curriculum_db import QueryCurriculumOuter from chat.tools import get_tools -from rest_framework.views import APIView -from rest_framework.response import Response from datetime import datetime from .models import WeeklyEvent @@ -396,31 +392,41 @@ def destroy(self, request, *args, **kwargs): return super().destroy(request, *args, **kwargs) - class WeeklyEventsView(APIView): - permission_classes = [IsAuthenticated] + permission_classes = [IsAuthenticated] def get(self, request): - start_date = request.query_params.get('start_date') - end_date = request.query_params.get('end_date') + start_date = request.query_params.get("start_date") + end_date = request.query_params.get("end_date") if not start_date or not end_date: - return Response({'error': 'Missing start_date or end_date'}, status=400) + return Response({"error": "Missing start_date or end_date"}, status=400) try: - start = datetime.strptime(start_date, '%Y-%m-%d').date() - end = datetime.strptime(end_date, '%Y-%m-%d').date() + start = datetime.strptime(start_date, "%Y-%m-%d").date() + end = datetime.strptime(end_date, "%Y-%m-%d").date() except ValueError: - return Response({'error': 'Invalid date format, use YYYY-MM-DD'}, status=400) - - events = WeeklyEvent.objects.using("ingestion").filter(event_date__range=(start, end)).order_by('event_date', 'start_time') - data = [{ - 'title': e.title, - 'date': e.event_date.isoformat(), - 'start_time': e.start_time.strftime('%H:%M:%S') if e.start_time else None, - 'end_time': e.end_time.strftime('%H:%M:%S') if e.end_time else None, - 'location': e.location, - 'sponsor': e.sponsor, - 'open_to': e.open_to, - 'speaker': e.speaker, - 'url': e.url, - } for e in events] - return Response({'events': data}) \ No newline at end of file + return Response( + {"error": "Invalid date format, use YYYY-MM-DD"}, status=400 + ) + + events = ( + WeeklyEvent.objects.using("ingestion") + .filter(event_date__range=(start, end)) + .order_by("event_date", "start_time") + ) + data = [ + { + "title": e.title, + "date": e.event_date.isoformat(), + "start_time": ( + e.start_time.strftime("%H:%M:%S") if e.start_time else None + ), + "end_time": e.end_time.strftime("%H:%M:%S") if e.end_time else None, + "location": e.location, + "sponsor": e.sponsor, + "open_to": e.open_to, + "speaker": e.speaker, + "url": e.url, + } + for e in events + ] + return Response({"events": data}) diff --git a/chatdku/django/chatdku_django/core/views.py b/chatdku/django/chatdku_django/core/views.py index 59960448..8c14f499 100644 --- a/chatdku/django/chatdku_django/core/views.py +++ b/chatdku/django/chatdku_django/core/views.py @@ -16,9 +16,7 @@ extend_schema, OpenApiResponse, ) -from core.tasks import update_user_chroma from .utils import slugify -from rest_framework import status import logging diff --git a/chatdku/ingestion/major_ingest.py b/chatdku/ingestion/major_ingest.py index 1c99d1d9..6de6ae49 100644 --- a/chatdku/ingestion/major_ingest.py +++ b/chatdku/ingestion/major_ingest.py @@ -98,10 +98,13 @@ def extract_majors( ) def make_major_pattern(major: str) -> re.Pattern: - return re.compile( - rf"{re.escape(major)}", - re.IGNORECASE, - ) + pattern = re.compile(rf"{re.escape(major)}", re.IGNORECASE) + # NOTE: Because of DKU's inconsistent formatting + # We have to create an exception for Data Science + if major == "Data Science": + pattern = re.compile(r"Data Science\s+Divisional") + + return pattern for page_num in range(page_start - 1, len(doc)): page = doc[page_num] @@ -135,9 +138,9 @@ def make_major_pattern(major: str) -> re.Pattern: def sanitize_filename(name: str) -> str: """Convert a major name to a safe filename.""" # Remove or replace unsafe characters - safe = re.sub(r"[^\w\s-]", "", name) + safe = re.sub(r"[^\w\s-]", "-", name) safe = re.sub(r"[-\s]+", "-", safe) - return safe.strip("-").lower() + return safe.strip().lower() def save_major_content(major_name: str, content: Dict, output_dir: Path): diff --git a/devsync.sh b/devsync.sh index e2b4c083..7eb77c9e 100644 --- a/devsync.sh +++ b/devsync.sh @@ -36,11 +36,20 @@ fi TARGET="${1:-}" if [[ -n "$TARGET" ]]; then - if [[ "$TARGET" != *"/"* && "$TARGET" != *.py ]]; then - # Looks like a module (e.g. chatdku.core.agent) — run with -m + if [[ "$TARGET" == *.py || "$TARGET" == */* ]]; then + : # file path — handled below + elif [[ "$TARGET" =~ ^[A-Za-z_][A-Za-z0-9_]*(\.[A-Za-z_][A-Za-z0-9_]*)*$ ]]; then + # Valid Python module name (e.g. chatdku.core.agent) — run with -m REMOTE_RUN_CMD="uv run python -m $(printf %q "$TARGET")" RUN_DESC="python -m $TARGET" + TARGET="" # signal: already handled else + # Anything else — treat as a natural-language query for the agent. + REMOTE_RUN_CMD="uv run python -m chatdku.core.agent $(printf %q "$TARGET")" + RUN_DESC="agent query: $TARGET" + TARGET="" + fi + if [[ -n "$TARGET" ]]; then # Treat as a file path if [[ "$TARGET" = /* ]]; then TARGET="${TARGET#"$LOCAL_DIR"/}" @@ -59,16 +68,13 @@ fi step "preparing remote directory $REMOTE_DIR on $SERVER" ssh "${SERVER}" "mkdir -p ${REMOTE_DIR}" -step "linking ~/.env → ${REMOTE_DIR}/.env" -ssh "${SERVER}" ' - if [ -f ~/.env ]; then - ln -sf ~/.env '"${REMOTE_DIR}"'/.env - else - echo "WARN: ~/.env not found on server — skipping link" - fi -' -if ssh "${SERVER}" '[ ! -f '"${REMOTE_DIR}"'/.env ]'; then - warn "no .env in ${REMOTE_DIR} — the agent may fail to start" +# Secrets are loaded automatically from /datapool/secrets/chatdku_env.sh via +# /etc/profile.d/chatdku.sh for all chatdku_devs group members. +# No ~/.env file is needed or expected — see Documentations/Shared-Secrets.md. +step "verifying shared secrets are loaded on $SERVER" +if ! ssh "${SERVER}" 'bash -l -c "[ -n \"\${REDIS_HOST:-}\" ]"' 2>/dev/null; then + warn "shared secrets not loaded on $SERVER — are you in the chatdku_devs group?" + warn "Run: groups | grep chatdku_devs (if missing, ask an admin to run add_user.sh)" fi info "syncing ${BOLD}$LOCAL_DIR${RESET}${CYAN} → ${BOLD}$SERVER:$REMOTE_DIR" diff --git a/pyproject.toml b/pyproject.toml index ff019d06..1c189f03 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,14 +33,16 @@ classifiers = [ dependencies = [ # core - "llama-index~=0.13.1", - "llama-index-storage-docstore-redis~=0.4.0", - "llama-index-vector-stores-redis~=0.6.0", - "llama-index-embeddings-text-embeddings-inference~=0.4.0", - "llama-index-llms-llama-cpp~=0.5.0", - "llama-index-retrievers-bm25~=0.6.2", + "llama-index~=0.14.20", + "llama-index-storage-docstore-redis~=0.5.0", + "llama-index-vector-stores-redis~=0.8.0", + "llama-index-embeddings-text-embeddings-inference~=0.5.0", + "llama-index-llms-llama-cpp~=0.6.0", + "llama-index-retrievers-bm25~=0.7.1", + "llama-index-readers-file~=0.6.0", "nltk>=3.9", "chromadb~=1.0.15", + "redis~=5.2.1", "arize-phoenix~=13.21.0", "arize-phoenix-evals==2.13.0", "opentelemetry-api~=1.41.0", @@ -52,16 +54,22 @@ dependencies = [ "tokenizers>=0.21.0", "docx2txt", "python-pptx", - "pymupdf", + "pymupdf", # for major ingestion + "pymupdf4llm", # needed for pdf -> md + "thefuzz", # for fuzzy searching major names in major ret "python-docx", "pdfplumber", - "redis~=5.2.1", "psycopg2-binary>=2.9.10", "pandas~=2.2.3", "dotenv>=0.9.9", "sqlalchemy", + # Development dependencies "pytest", "pre-commit", + # TUI + "textual>=0.83.0", + "pyfiglet>=1.0.2", + "thefuzz>=0.22.1", ] [project.optional-dependencies] @@ -97,3 +105,8 @@ packages = ["chatdku"] [tool.flake8] ignore = ["E203","W503"] max-line-length = 120 + +[dependency-groups] +dev = [ + "flake8>=7.3.0", +] diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..d6083818 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,218 @@ +"""Shared fixtures for ChatDKU tool tests.""" + +from contextlib import contextmanager +from unittest.mock import MagicMock + +import pandas as pd +import pytest + + +@pytest.fixture() +def mock_span_ctx(monkeypatch): + """Mock span_ctx_start so no real tracer/Phoenix is needed. + + Patches at every import site since each tool module binds the name at import time. + Returns the mock span for assertions on set_attributes / set_status. + """ + mock_span = MagicMock() + + @contextmanager + def fake_span_ctx_start(name, kind, parent_context=None): + yield mock_span + + targets = [ + "chatdku.core.utils.span_ctx_start", + "chatdku.core.tools.course_schedule.span_ctx_start", + "chatdku.core.tools.course_recommender.span_ctx_start", + "chatdku.core.tools.get_prerequisites.span_ctx_start", + "chatdku.core.tools.major_requirements.span_ctx_start", + "chatdku.core.tools.syllabi.query_curriculum_db.span_ctx_start", + "chatdku.core.tools.retriever.base_retriever.span_ctx_start", + ] + for target in targets: + try: + monkeypatch.setattr(target, fake_span_ctx_start) + except (AttributeError, ImportError): + pass # module not yet imported — safe to skip + + return mock_span + + +@pytest.fixture() +def mock_get_current_span(monkeypatch): + """Mock get_current_span for llama_index_tools which uses it directly.""" + mock_span = MagicMock() + monkeypatch.setattr( + "chatdku.core.tools.llama_index_tools.get_current_span", lambda: mock_span + ) + return mock_span + + +@pytest.fixture() +def sample_classdata_csv(tmp_path): + """Create a temporary class schedule CSV with representative data.""" + csv_path = tmp_path / "classdata.csv" + df = pd.DataFrame( + { + "Subject": ["COMPSCI", "COMPSCI", "MATH", "BIOL", "CHINESE"], + "Catalog": ["101", "201", "201", "305", "101A"], + "Section": ["01", "01", "01", "01", "01"], + "Component": ["LEC", "LEC", "LEC", "LAB", "LEC"], + "Instructor": [ + "Alice Smith", + "Bob Jones", + "Carol Lee", + "Dave Kim", + "Eve Wu", + ], + "Days": ["MWF", "TTh", "MWF", "TTh", "MWF"], + "Start Time": ["09:00", "10:30", "11:00", "14:00", "13:00"], + "End Time": ["09:50", "11:45", "11:50", "15:15", "13:50"], + "Enrollment": [30, 25, 40, 15, 20], + } + ) + df.to_csv(csv_path, index=False) + return str(csv_path) + + +@pytest.fixture() +def sample_classdata_real_csv(tmp_path): + """Classdata CSV matching the actual cleaned_classdata.csv column layout. + + Uses Mon/Tues/Wed/Thurs/Fri boolean columns and Mtg Start/Mtg End for times, + matching what clean_classdata.py produces on the server. + """ + csv_path = tmp_path / "classdata_real.csv" + df = pd.DataFrame( + { + "Course ID": [1001, 1002, 1003, 1004, 1005], + "Term": [2268] * 5, + "Session": ["7W1", "7W1", "7W2", "7W1", "7W2"], + "Section": ["001", "001", "001", "001", "001"], + "Subject": ["COMPSCI", "MATH", "MATH", "STATS", "GLOCHALL"], + "Catalog": ["201", "201", "202", "302", "201"], + "Descr": [ + "Intro to Programming and Data Structures", + "Multivariable Calculus", + "Linear Algebra", + "Principles of Machine Learning", + "Global Challenges", + ], + "Class Nbr": [100, 101, 102, 103, 104], + "Enrollment Status": ["Open"] * 5, + "Class Status": ["Active"] * 5, + "Enrollment Capacity": [40] * 5, + "Wait List Capacity": [8] * 5, + "Enrollment Total": [20] * 5, + "Wait List Total": [0] * 5, + "Seats Open": ["20/40"] * 5, + "Waitlist Open": ["8/8"] * 5, + "Attributes": [""] * 5, + "Prgrss Unt": [4.0] * 5, + "Grading": ["GRD"] * 5, + "Start Date": ["08/25/2026"] * 5, + "End Date": ["10/08/2026"] * 5, + "Mtg Start": [ + "9:00:00.000000AM", + "10:00:00.000000AM", + "2:00:00.000000PM", + "8:00:00.000000AM", + "9:00:00.000000AM", + ], + "Mtg End": [ + "9:50:00.000000AM", + "10:50:00.000000AM", + "2:50:00.000000PM", + "8:50:00.000000AM", + "9:50:00.000000AM", + ], + "Mon": ["Y", "Y", "N", "Y", "Y"], + "Tues": ["N", "N", "Y", "N", "N"], + "Wed": ["Y", "Y", "N", "Y", "Y"], + "Thurs": ["N", "N", "Y", "N", "N"], + "Fri": ["Y", "Y", "N", "Y", "Y"], + "Room No": ["IB1001"] * 5, + "Instructor": [ + "Smith,Alice", + "Jones,Bob", + "Lee,Carol", + "Kim,Dave", + "Wu,Eve", + ], + } + ) + df.to_csv(csv_path, index=False) + return str(csv_path) + + +@pytest.fixture() +def sample_prereq_csv(tmp_path): + """Create a temporary prerequisites CSV with UTF-16LE encoding. + + Column layout matches positional access in get_prereq: + col 0: ID + col 1: Effective Date (MM/DD/YYYY) + col 2: Subject + col 3: Catalog + cols 4-12: padding + col 13: Description (prerequisite text) + """ + csv_path = tmp_path / "prereq.csv" + rows = [ + # COMPSCI 201 with prereqs, two rows with different dates + [ + 1, + "01/15/2023", + "COMPSCI", + "201", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "Prereq: COMPSCI 101", + ], + [ + 2, + "09/01/2024", + "COMPSCI", + "201", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "Prereq: COMPSCI 101 or COMPSCI 102", + ], + # MATH 201 with prereqs + [ + 3, + "03/10/2024", + "MATH", + "201", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "Prereq: MATH 101", + ], + # BIOL 305 with empty description + [4, "06/01/2024", "BIOL", "305", "", "", "", "", "", "", "", "", "", ""], + ] + columns = [f"col{i}" for i in range(14)] + df = pd.DataFrame(rows, columns=columns) + df.to_csv(csv_path, index=False, encoding="utf-16le") + return str(csv_path) diff --git a/tests/test_agent_configuration.py b/tests/test_agent_configuration.py new file mode 100644 index 00000000..9ea60f69 --- /dev/null +++ b/tests/test_agent_configuration.py @@ -0,0 +1,189 @@ +"""Tests that verify the Planner + Executor are configured correctly for +policy-aware, agenda-extending course recommendations. + +These are structural/configuration tests — they do not invoke any LLM or +external services. +""" + +import pytest + +from chatdku.core.dspy_classes.executor import AssessSignature, Executor +from chatdku.core.dspy_classes.plan import PLANNER_DEMOS, PlannerSignature + + +# --------------------------------------------------------------------------- +# Planner configuration tests +# --------------------------------------------------------------------------- + + +class TestPlannerConfiguration: + """Verify the Planner is configured for policy-first course planning.""" + + def test_planner_signature_has_available_tools_field(self): + """Planner must expose tool descriptions so it can include them in plans.""" + fields = PlannerSignature.input_fields + assert "available_tools" in fields + + def test_planner_instructions_require_policy_retrieval_before_recommender(self): + """Planner docstring must instruct: retrieve policies FIRST, then CourseRecommender.""" + instructions = PlannerSignature.__doc__ + assert instructions is not None + # Should mention policy/year retrieval + assert any( + kw in instructions.lower() + for kw in ("policy", "policies", "mandatory courses", "year-specific") + ), "Planner must instruct Executor to retrieve year-specific policies" + # Should still mention CourseRecommender + assert ( + "CourseRecommender" in instructions + ), "Planner must still reference CourseRecommender as the baseline tool" + + def test_planner_instructions_mention_vector_or_keyword_retriever_for_policies( + self, + ): + """Planner must name VectorRetriever or KeywordRetriever for policy lookup.""" + instructions = PlannerSignature.__doc__ or "" + assert ( + "VectorRetriever" in instructions or "KeywordRetriever" in instructions + ), "Planner must instruct use of VectorRetriever/KeywordRetriever for policy lookup" + + def test_planner_missing_info_demo_asks_for_all_three(self): + """The 'missing info' demo must ask for major, year, and completed courses.""" + missing_info_demo = next( + (d for d in PLANNER_DEMOS if d.action_type == "send_message"), None + ) + assert ( + missing_info_demo is not None + ), "PLANNER_DEMOS must have a send_message example" + action = missing_info_demo.action.lower() + assert "major" in action, "Missing-info message must ask for major" + assert any( + kw in action for kw in ("year", "matriculation", "class of") + ), "Missing-info message must ask for year of matriculation" + assert any( + kw in action for kw in ("completed", "taken", "taking") + ), "Missing-info message must ask for completed courses" + + def test_planner_schedule_demo_is_policy_first(self): + """The full schedule planning demo must mention policy retrieval before CourseRecommender.""" + # Find the demo that has all three pieces of info (plan action, mentions Data Science) + plan_demos = [d for d in PLANNER_DEMOS if d.action_type == "plan"] + schedule_demo = next( + (d for d in plan_demos if "Class of" in d.current_user_message), None + ) + assert ( + schedule_demo is not None + ), "PLANNER_DEMOS must include a complete schedule planning example with class year" + action = schedule_demo.action + # Policy step should appear before CourseRecommender call in the action text + policy_keywords = [ + "policy", + "policies", + "mandatory", + "retrieve", + "VectorRetriever", + "KeywordRetriever", + "year-specific", + "requirements", + ] + has_policy_step = any(kw.lower() in action.lower() for kw in policy_keywords) + assert ( + has_policy_step + ), f"Schedule planning demo must include a policy-retrieval step.\nAction:\n{action}" + policy_pos = min( + ( + action.lower().find(kw.lower()) + for kw in policy_keywords + if kw.lower() in action.lower() + ), + default=-1, + ) + recommender_pos = action.find("CourseRecommender") + assert recommender_pos != -1, "Demo must mention CourseRecommender" + assert ( + policy_pos < recommender_pos + ), "Policy retrieval step must come BEFORE CourseRecommender in the demo plan" + + +# --------------------------------------------------------------------------- +# Executor configuration tests +# --------------------------------------------------------------------------- + + +class TestExecutorConfiguration: + """Verify the Executor supports dynamic agenda extensions.""" + + def test_assess_signature_has_agenda_extensions_output(self): + """AssessSignature must have an agenda_extensions output field.""" + output_fields = AssessSignature.output_fields + assert "agenda_extensions" in output_fields, ( + "AssessSignature must have agenda_extensions output to communicate " + "new investigation areas discovered during execution" + ) + + def test_assess_signature_has_current_agenda_not_plan(self): + """AssessSignature uses 'current_agenda' (not 'plan') as the input field name.""" + input_fields = AssessSignature.input_fields + assert "current_agenda" in input_fields, ( + "AssessSignature must use 'current_agenda' (not 'plan') to indicate " + "the agenda can grow beyond the original plan" + ) + assert ( + "plan" not in input_fields + ), "AssessSignature must not use the old 'plan' field name" + + def test_assess_signature_decision_field_exists_as_output(self): + """AssessSignature must have a 'decision' output field.""" + assert ( + "decision" in AssessSignature.output_fields + ), "AssessSignature must define a 'decision' output field" + # Verify it is not accidentally an input field. + assert "decision" not in AssessSignature.input_fields + + def test_assess_signature_docstring_mentions_agenda_extensions(self): + """AssessSignature docstring must describe discovering new investigation areas.""" + doc = AssessSignature.__doc__ or "" + assert any( + kw in doc.lower() + for kw in ( + "new investigation", + "agenda extension", + "revealed", + "discovered", + ) + ), "AssessSignature docstring must explain when to extend the agenda" + + def test_executor_get_token_limits_accepts_current_agenda(self, tmp_path): + """Executor.get_token_limits must accept 'current_agenda' (not 'plan').""" + + # Build a minimal Executor with a trivial tool. + def dummy_tool(query: str) -> str: + """A dummy tool for testing. Args: query (str): The query.""" + return "dummy" + + try: + executor = Executor([dummy_tool], max_iterations=1) + # Should not raise — the old code used 'plan' which would break here. + limits = executor.get_token_limits( + current_agenda="test plan", + current_user_message="test", + conversation_history="", + conversation_summary="", + trajectory="", + assessment="", + ) + assert isinstance(limits, dict) + assert len(limits) > 0 + except Exception as e: + pytest.fail(f"get_token_limits with current_agenda raised: {e}") + + def test_executor_max_iterations_default_is_five_or_more(self): + """Executor default max_iterations should be >= 5 for policy-aware planning.""" + import inspect + + sig = inspect.signature(Executor.__init__) + default = sig.parameters["max_iterations"].default + assert default >= 5, ( + f"Executor default max_iterations is {default}, expected >= 5 for policy-first plans " + "(policy retrieval + CourseRecommender + potential extensions require more iterations)" + ) diff --git a/tests/test_course_recommender.py b/tests/test_course_recommender.py new file mode 100644 index 00000000..436be77e --- /dev/null +++ b/tests/test_course_recommender.py @@ -0,0 +1,657 @@ +"""Tests for chatdku.core.tools.course_recommender. + +Covers: + - parse_course_codes (unit) + - prerequisites_met (unit) + - CourseRecommenderOuter / full integration (4 simple + complex scenarios) +""" + +import pandas as pd +import pytest +from opentelemetry.trace import StatusCode + +from chatdku.core.tools.course_recommender import ( + CourseRecommenderOuter, + parse_course_codes, + prerequisites_met, +) + + +# --------------------------------------------------------------------------- +# Shared fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture() +def req_dir(tmp_path): + """Requirements dir that mirrors the real Data Science major layout.""" + (tmp_path / "data-science.md").write_text( + """# Data Science + +## Interdisciplinary Courses +| Course Code | Course Name | Credits | +|-------------|-------------|---------| +| COMPSCI 201 | Intro to Programming and Data Structures | 4 | +| STATS 302 | Principles of Machine Learning | 4 | +| STATS 303 | Statistical Machine Learning | 4 | +| STATS 401 | Data Acquisition and Visualization | 4 | + +## Disciplinary Courses +| Course Code | Course Name | Credits | +|-------------|-------------|---------| +| MATH 201 | Multivariable Calculus | 4 | +| MATH 202 | Linear Algebra | 4 | +| MATH 206 | Probability and Statistics | 4 | +| COMPSCI 301 | Algorithms and Databases | 4 | +""" + ) + (tmp_path / "requirements-for-all-majors.md").write_text( + """# Requirements for All Majors + +## Common Core +| Academic Year | Course Code | Course Name | Credits | +|---------------|-------------|-------------|---------| +| First Year | GCHINA 101 | China in the World | 4 | +| Second Year | GLOCHALL 201 | Global Challenges | 4 | +| Third Year | ETHLDR 201 | Ethics and Citizenship | 4 | +""" + ) + return str(tmp_path) + + +@pytest.fixture() +def prereq_csv_with_ds_courses(tmp_path): + """Prerequisite CSV with realistic Data Science course prerequisites. + + Mirrors real data from DKUHub: + - MATH 201: "Prerequisite: MATH 101 or MATH 105" + - STATS 302: "Prerequisite: MATH 201, MATH 202, MATH 206, and COMPSCI 201" + - COMPSCI 301: "COMPSCI 201, Anti-requisites: COMPSCI 308 and 310" + - GLOCHALL 201: "Prerequisite: GCHINA 101 and sophomore standing" + - ETHLDR 201: "Prerequisite: GLOCHALL 201 and junior standing" + - All others: no entry (treated as no prerequisites) + """ + csv_path = tmp_path / "prereq_ds.csv" + rows = [ + [ + 1, + "09/01/2024", + "MATH", + "201", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "Prerequisite: MATH 101 or MATH 105", + ], + [ + 2, + "09/01/2024", + "STATS", + "302", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "Prerequisite: MATH 201, MATH 202, MATH 206, and COMPSCI 201. Anti-requisite: MATH 405", + ], + [ + 3, + "09/01/2024", + "COMPSCI", + "301", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "COMPSCI 201, Anti-requisites: COMPSCI 308 and 310", + ], + [ + 4, + "09/01/2024", + "GLOCHALL", + "201", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "Prerequisite: GCHINA 101 and sophomore standing", + ], + [ + 5, + "09/01/2024", + "ETHLDR", + "201", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "Prerequisite: GLOCHALL 201 and junior standing", + ], + ] + columns = [f"col{i}" for i in range(14)] + df = pd.DataFrame(rows, columns=columns) + df.to_csv(csv_path, index=False, encoding="utf-16le") + return str(csv_path) + + +# --------------------------------------------------------------------------- +# Unit tests: parse_course_codes +# --------------------------------------------------------------------------- + + +class TestParseCourseCode: + def test_extracts_standard_table_entry(self): + text = "| COMPSCI 201 | Intro to Programming | 4 |" + assert "COMPSCI 201" in parse_course_codes(text) + + def test_extracts_courses_from_bullet_list(self): + text = "- STATS 302: Principles of Machine Learning\n- MATH 201: Calculus" + codes = parse_course_codes(text) + assert "STATS 302" in codes + assert "MATH 201" in codes + + def test_ignores_unknown_subject_prefixes(self): + # "THE 123" and "FOR 999" are not DKU subject codes → should be filtered + text = "THE 123 and FOR 999 are not DKU courses. But COMPSCI 101 is." + codes = parse_course_codes(text) + assert "COMPSCI 101" in codes + assert "THE 123" not in codes + assert "FOR 999" not in codes + + def test_deduplicates_repeated_codes(self): + text = "MATH 201 is listed here. Also MATH 201 appears again." + codes = parse_course_codes(text) + assert codes.count("MATH 201") == 1 + + def test_extracts_glochall_subject(self): + text = "| GLOCHALL 201 | Global Challenges | 4 |" + assert "GLOCHALL 201" in parse_course_codes(text) + + def test_handles_empty_text(self): + assert parse_course_codes("") == [] + + def test_preserves_insertion_order(self): + text = "COMPSCI 201 then STATS 302 then MATH 201" + codes = parse_course_codes(text) + assert ( + codes.index("COMPSCI 201") + < codes.index("STATS 302") + < codes.index("MATH 201") + ) + + def test_does_not_match_two_digit_catalog(self): + # Catalog numbers must be 3 digits — "MATH 20" should not match + assert "MATH 20" not in parse_course_codes("MATH 20 is not a real code") + + def test_matches_catalog_with_letter_suffix(self): + assert "CHINESE 101A" in parse_course_codes("CHINESE 101A") + + +# --------------------------------------------------------------------------- +# Unit tests: prerequisites_met +# --------------------------------------------------------------------------- + + +class TestPrerequisitesMet: + @pytest.fixture(autouse=True) + def _prereq_df(self, prereq_csv_with_ds_courses): + self.prereq_df = pd.read_csv(prereq_csv_with_ds_courses, encoding="utf-16le") + + def test_no_prereq_entry_returns_eligible(self): + """STATS 401 is not in the prereq CSV — should be eligible with no prereqs.""" + met, reason = prerequisites_met("STATS 401", set(), self.prereq_df) + assert met is True + assert reason == "" + + def test_simple_or_prereq_satisfied_by_first_option(self): + """MATH 201 needs MATH 101 or MATH 105. MATH 101 completed → eligible.""" + completed = {"MATH 101"} + met, reason = prerequisites_met("MATH 201", completed, self.prereq_df) + assert met is True + + def test_simple_or_prereq_satisfied_by_second_option(self): + """MATH 201 needs MATH 101 or MATH 105. MATH 105 completed → eligible.""" + completed = {"MATH 105"} + met, reason = prerequisites_met("MATH 201", completed, self.prereq_df) + assert met is True + + def test_simple_or_prereq_not_satisfied(self): + """MATH 201 needs MATH 101 or MATH 105. Neither completed → not eligible.""" + met, reason = prerequisites_met("MATH 201", set(), self.prereq_df) + assert met is False + assert "MATH 101" in reason or "MATH 105" in reason + + def test_and_prereq_partially_met(self): + """STATS 302 needs MATH 201, MATH 202, MATH 206, COMPSCI 201 (AND logic). + Student only has MATH 201 and COMPSCI 201 → still not eligible.""" + completed = {"MATH 201", "COMPSCI 201"} + met, reason = prerequisites_met("STATS 302", completed, self.prereq_df) + assert met is False + # At least one missing prereq should be cited + assert "MATH 202" in reason or "MATH 206" in reason + + def test_and_prereq_fully_met(self): + """STATS 302 with all 4 prereqs completed → eligible.""" + completed = {"MATH 201", "MATH 202", "MATH 206", "COMPSCI 201"} + met, reason = prerequisites_met("STATS 302", completed, self.prereq_df) + assert met is True + + def test_prereq_entry_with_no_course_codes(self): + """A prereq text like 'sophomore standing' has no course codes — + our heuristic returns eligible with an unstructured note.""" + # Inject a row with standing-only prereq + extra = pd.DataFrame( + [ + [ + 99, + "01/01/2024", + "ECON", + "201", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "Sophomore standing required", + ] + ], + columns=[f"col{i}" for i in range(14)], + ) + df = pd.concat([self.prereq_df, extra], ignore_index=True) + met, reason = prerequisites_met("ECON 201", set(), df) + assert met is True # best-effort: pass through with note + assert "Unstructured" in reason or reason == "" + + +# --------------------------------------------------------------------------- +# Integration tests: CourseRecommenderOuter (TC1–TC6) +# --------------------------------------------------------------------------- + + +class TestCourseRecommenderScenarios: + """End-to-end test scenarios evaluating the full recommendation pipeline.""" + + # TC1 — Simple: freshman with no completed courses + def test_tc1_no_completed_courses( + self, + mock_span_ctx, + req_dir, + sample_classdata_real_csv, + prereq_csv_with_ds_courses, + ): + """TC1 (Simple): Student has no completed courses. + + Expected behaviour: + - Courses with unmet prerequisites → 'not eligible' section + - Courses with no prerequisites that are offered → 'recommended' + - Courses offered but requiring prereqs not yet met → 'not eligible' + """ + recommender = CourseRecommenderOuter( + requirements_dir=req_dir, + classdata_csv_path=sample_classdata_real_csv, + prereq_csv_path=prereq_csv_with_ds_courses, + ) + result = recommender(major="data science", completed_courses=[]) + + # Should produce a structured report, not an error + assert "## Course Recommendation" in result + assert "data-science" in result.lower() + + # COMPSCI 201 has no prereq entry → eligible. It IS offered. + assert "COMPSCI 201" in result + # MATH 201 needs MATH 101 or MATH 105 → not eligible with empty completions + assert "MATH 201" in result + # STATS 302 needs 4 prereqs → not eligible + assert "STATS 302" in result + + # The recommended section should exist + assert "Recommended" in result + + # TC2 — Normal: student with foundational courses done + def test_tc2_student_with_calculus_done( + self, + mock_span_ctx, + req_dir, + sample_classdata_real_csv, + prereq_csv_with_ds_courses, + ): + """TC2 (Normal): Student has completed MATH 101 (and thus MATH 201 is now eligible). + + Expected: + - MATH 201: prereq (MATH 101 or MATH 105) satisfied → recommended (it IS offered) + - COMPSCI 201: no prereq → recommended (offered) + - STATS 302: still needs MATH 201, MATH 202, MATH 206 → not eligible + """ + recommender = CourseRecommenderOuter( + requirements_dir=req_dir, + classdata_csv_path=sample_classdata_real_csv, + prereq_csv_path=prereq_csv_with_ds_courses, + ) + result = recommender(major="data science", completed_courses=["MATH 101"]) + + # MATH 201 should now appear under the recommended section + lines = result.split("\n") + recommended_section = False + math201_in_recommended = False + for line in lines: + if "Recommended" in line and "eligible" in line.lower(): + recommended_section = True + if ( + recommended_section + and "MATH 201" in line + and line.strip().startswith("-") + ): + math201_in_recommended = True + break + if line.startswith("###") and "Recommended" not in line: + recommended_section = False + + assert math201_in_recommended, ( + "MATH 201 should appear in the recommended section when MATH 101 is completed.\n" + f"Full output:\n{result}" + ) + + # STATS 302 should still be in 'not eligible' + assert ( + "prerequisites not met" in result.lower() + or "not eligible" in result.lower() + ) + + # TC3 — OR prerequisite: second option satisfies + def test_tc3_or_prereq_satisfied_by_alternate( + self, + mock_span_ctx, + req_dir, + sample_classdata_real_csv, + prereq_csv_with_ds_courses, + ): + """TC3 (OR prereq): Student has MATH 105 (not MATH 101). + + MATH 201 prerequisite is 'MATH 101 or MATH 105'. + Having only MATH 105 should still make the student eligible. + """ + recommender = CourseRecommenderOuter( + requirements_dir=req_dir, + classdata_csv_path=sample_classdata_real_csv, + prereq_csv_path=prereq_csv_with_ds_courses, + ) + result = recommender(major="data science", completed_courses=["MATH 105"]) + + lines = result.split("\n") + recommended_section = False + math201_in_recommended = False + for line in lines: + if "Recommended" in line and "eligible" in line.lower(): + recommended_section = True + if ( + recommended_section + and "MATH 201" in line + and line.strip().startswith("-") + ): + math201_in_recommended = True + break + if line.startswith("###") and "Recommended" not in line: + recommended_section = False + + assert math201_in_recommended, ( + "MATH 201 should be recommended when MATH 105 is done (OR prereq).\n" + f"Full output:\n{result}" + ) + + # TC4 — Complex multi-prereq chain: STATS 302 with all prereqs done + def test_tc4_all_prereqs_for_stats302( + self, + mock_span_ctx, + req_dir, + sample_classdata_real_csv, + prereq_csv_with_ds_courses, + ): + """TC4 (Complex): Student completed all prereqs for STATS 302. + + STATS 302 requires MATH 201, MATH 202, MATH 206, and COMPSCI 201. + With all four completed, STATS 302 should appear as recommended (it IS offered). + """ + recommender = CourseRecommenderOuter( + requirements_dir=req_dir, + classdata_csv_path=sample_classdata_real_csv, + prereq_csv_path=prereq_csv_with_ds_courses, + ) + completed = ["MATH 201", "MATH 202", "MATH 206", "COMPSCI 201"] + result = recommender(major="data science", completed_courses=completed) + + lines = result.split("\n") + recommended_section = False + stats302_in_recommended = False + for line in lines: + if "Recommended" in line and "eligible" in line.lower(): + recommended_section = True + if ( + recommended_section + and "STATS 302" in line + and line.strip().startswith("-") + ): + stats302_in_recommended = True + break + if line.startswith("###") and "Recommended" not in line: + recommended_section = False + + assert stats302_in_recommended, ( + "STATS 302 should be recommended once all 4 prereqs are completed.\n" + f"Full output:\n{result}" + ) + + # TC5 — Edge case: all required courses already completed + def test_tc5_all_courses_completed( + self, + mock_span_ctx, + req_dir, + sample_classdata_real_csv, + prereq_csv_with_ds_courses, + ): + """TC5 (Edge case): Student has completed every required course. + + Should return a 'you have completed all required courses' message instead + of an empty recommendation grid. + """ + # Extract all codes from the requirements fixtures to simulate full completion + all_ds = [ + "COMPSCI 201", + "STATS 302", + "STATS 303", + "STATS 401", + "MATH 201", + "MATH 202", + "MATH 206", + "COMPSCI 301", + ] + all_core = ["GCHINA 101", "GLOCHALL 201", "ETHLDR 201"] + completed = all_ds + all_core + + recommender = CourseRecommenderOuter( + requirements_dir=req_dir, + classdata_csv_path=sample_classdata_real_csv, + prereq_csv_path=prereq_csv_with_ds_courses, + ) + result = recommender(major="data science", completed_courses=completed) + + assert ( + "completed all required courses" in result.lower() + ), f"Expected completion message, got:\n{result}" + + # TC6 — Unknown major + def test_tc6_unknown_major( + self, + mock_span_ctx, + req_dir, + sample_classdata_real_csv, + prereq_csv_with_ds_courses, + ): + """TC6: User provides a major that doesn't exist in the requirements dir.""" + recommender = CourseRecommenderOuter( + requirements_dir=req_dir, + classdata_csv_path=sample_classdata_real_csv, + prereq_csv_path=prereq_csv_with_ds_courses, + ) + result = recommender(major="astrology and witchcraft", completed_courses=[]) + assert "No matching major" in result + + # TC7 — Common-core courses: ETHLDR 201 chain + def test_tc7_common_core_prereq_chain( + self, + mock_span_ctx, + req_dir, + sample_classdata_real_csv, + prereq_csv_with_ds_courses, + ): + """TC7 (Complex): Tests the common-core prerequisite chain. + + GLOCHALL 201 requires GCHINA 101. + ETHLDR 201 requires GLOCHALL 201. + Student who has GCHINA 101 but not GLOCHALL 201: + - GLOCHALL 201 → eligible (offered in fixture) + - ETHLDR 201 → not eligible (GLOCHALL 201 not yet completed) + """ + recommender = CourseRecommenderOuter( + requirements_dir=req_dir, + classdata_csv_path=sample_classdata_real_csv, + prereq_csv_path=prereq_csv_with_ds_courses, + ) + # Also complete DS courses so the report focuses on core courses + completed = [ + "GCHINA 101", # satisfies GLOCHALL 201's prereq + "COMPSCI 201", + "STATS 302", + "STATS 303", + "STATS 401", + "MATH 201", + "MATH 202", + "MATH 206", + "COMPSCI 301", + ] + result = recommender(major="data science", completed_courses=completed) + + # GLOCHALL 201 should be recommended (GCHINA 101 done, GLOCHALL offered) + lines = result.split("\n") + recommended_section = False + glochall_recommended = False + + for line in lines: + if "Recommended" in line and "eligible" in line.lower(): + recommended_section = True + if ( + recommended_section + and "GLOCHALL 201" in line + and line.strip().startswith("-") + ): + glochall_recommended = True + if line.startswith("###") and "Recommended" not in line: + recommended_section = False + + assert glochall_recommended, ( + "GLOCHALL 201 should be recommended (GCHINA 101 done, GLOCHALL 201 is offered).\n" + f"Full output:\n{result}" + ) + assert ( + "ETHLDR 201" in result + ), "ETHLDR 201 should appear somewhere in the report" + + +# --------------------------------------------------------------------------- +# CourseRecommenderOuter: infrastructure / span tests +# --------------------------------------------------------------------------- + + +class TestCourseRecommenderInfra: + def test_returns_callable( + self, + mock_span_ctx, + req_dir, + sample_classdata_real_csv, + prereq_csv_with_ds_courses, + ): + fn = CourseRecommenderOuter( + req_dir, sample_classdata_real_csv, prereq_csv_with_ds_courses + ) + assert callable(fn) + + def test_nonexistent_requirements_dir_raises( + self, mock_span_ctx, sample_classdata_real_csv, prereq_csv_with_ds_courses + ): + fn = CourseRecommenderOuter( + "/nonexistent/path", sample_classdata_real_csv, prereq_csv_with_ds_courses + ) + with pytest.raises(FileNotFoundError): + fn(major="data science", completed_courses=[]) + + def test_span_status_ok_on_success( + self, + mock_span_ctx, + req_dir, + sample_classdata_real_csv, + prereq_csv_with_ds_courses, + ): + fn = CourseRecommenderOuter( + req_dir, sample_classdata_real_csv, prereq_csv_with_ds_courses + ) + fn(major="data science", completed_courses=[]) + calls = mock_span_ctx.set_status.call_args_list + assert any(c.args[0].status_code == StatusCode.OK for c in calls if c.args) + + def test_span_attributes_set( + self, + mock_span_ctx, + req_dir, + sample_classdata_real_csv, + prereq_csv_with_ds_courses, + ): + fn = CourseRecommenderOuter( + req_dir, sample_classdata_real_csv, prereq_csv_with_ds_courses + ) + fn(major="data science", completed_courses=[]) + assert mock_span_ctx.set_attributes.called + + def test_report_contains_summary_counts( + self, + mock_span_ctx, + req_dir, + sample_classdata_real_csv, + prereq_csv_with_ds_courses, + ): + fn = CourseRecommenderOuter( + req_dir, sample_classdata_real_csv, prereq_csv_with_ds_courses + ) + result = fn(major="data science", completed_courses=[]) + # Report should have summary header stats + assert "Total required courses" in result + assert "Completed:" in result + assert "Remaining:" in result diff --git a/tests/test_course_schedule.py b/tests/test_course_schedule.py new file mode 100644 index 00000000..3ced34ee --- /dev/null +++ b/tests/test_course_schedule.py @@ -0,0 +1,147 @@ +"""Comprehensive tests for chatdku.core.tools.course_schedule.""" + +import json + +import pandas as pd +import pytest +from opentelemetry.trace import StatusCode + +from chatdku.core.tools.course_schedule import ( + CourseScheduleLookupOuter, + _lookup, + _parse_course, +) + + +# --------------------------------------------------------------------------- +# _parse_course (pure function — no mocks needed) +# --------------------------------------------------------------------------- + + +class TestParseCourse: + def test_with_space(self): + assert _parse_course("COMPSCI 101") == ("COMPSCI", "101") + + def test_no_separator(self): + assert _parse_course("COMPSCI101") == ("COMPSCI", "101") + + def test_with_hyphen(self): + assert _parse_course("COMPSCI-101") == ("COMPSCI", "101") + + def test_alpha_suffix(self): + assert _parse_course("Chinese 101A") == ("CHINESE", "101A") + + def test_strips_whitespace(self): + assert _parse_course(" COMPSCI 101 ") == ("COMPSCI", "101") + + def test_lowercase_normalised_to_upper(self): + assert _parse_course("math 201") == ("MATH", "201") + + def test_empty_string_raises(self): + with pytest.raises(ValueError): + _parse_course("") + + def test_only_symbols_raises(self): + with pytest.raises(ValueError): + _parse_course("!!!") + + def test_only_numbers_raises(self): + with pytest.raises(ValueError): + _parse_course("12345") + + +# --------------------------------------------------------------------------- +# _lookup (needs a DataFrame, no external mocks) +# --------------------------------------------------------------------------- + + +@pytest.fixture() +def schedule_df(): + return pd.DataFrame( + { + "Subject": ["COMPSCI", "COMPSCI", "MATH", "BIOL"], + "Catalog": ["101", "201", "201", "305"], + "Section": ["01", "02", "01", "01"], + "Instructor": ["Alice", "Bob", "Carol", "Dave"], + } + ) + + +class TestLookup: + def test_finds_matching_rows(self, schedule_df): + rows = _lookup("COMPSCI 101", schedule_df) + assert len(rows) == 1 + assert rows[0]["Subject"] == "COMPSCI" + assert rows[0]["Catalog"] == "101" + + def test_returns_empty_for_nonexistent(self, schedule_df): + assert _lookup("ASTROLOGY 1239", schedule_df) == [] + + def test_case_insensitive(self, schedule_df): + rows = _lookup("compsci 201", schedule_df) + assert len(rows) == 1 + + def test_multiple_sections(self, schedule_df): + # Add a second section for COMPSCI 101 + extra = pd.DataFrame( + { + "Subject": ["COMPSCI"], + "Catalog": ["101"], + "Section": ["02"], + "Instructor": ["Eve"], + } + ) + df = pd.concat([schedule_df, extra], ignore_index=True) + rows = _lookup("COMPSCI 101", df) + assert len(rows) == 2 + + +# --------------------------------------------------------------------------- +# CourseScheduleLookupOuter (needs mock_span_ctx + CSV fixture) +# --------------------------------------------------------------------------- + + +class TestCourseScheduleLookupOuter: + def test_returns_callable(self, mock_span_ctx, sample_classdata_csv): + fn = CourseScheduleLookupOuter(sample_classdata_csv) + assert callable(fn) + + def test_single_course_found(self, mock_span_ctx, sample_classdata_csv): + fn = CourseScheduleLookupOuter(sample_classdata_csv) + result = json.loads(fn(["COMPSCI 101"])) + assert "COMPSCI 101" in result + assert isinstance(result["COMPSCI 101"], list) + assert result["COMPSCI 101"][0]["Instructor"] == "Alice Smith" + + def test_multiple_courses(self, mock_span_ctx, sample_classdata_csv): + fn = CourseScheduleLookupOuter(sample_classdata_csv) + result = json.loads(fn(["COMPSCI 101", "MATH 201"])) + assert "COMPSCI 101" in result + assert "MATH 201" in result + + def test_course_not_found_message(self, mock_span_ctx, sample_classdata_csv): + fn = CourseScheduleLookupOuter(sample_classdata_csv) + result = json.loads(fn(["FAKE 999"])) + assert "No schedule found" in result["FAKE 999"] + + def test_mixed_found_and_not_found(self, mock_span_ctx, sample_classdata_csv): + fn = CourseScheduleLookupOuter(sample_classdata_csv) + result = json.loads(fn(["COMPSCI 101", "FAKE 999"])) + assert isinstance(result["COMPSCI 101"], list) + assert "No schedule found" in result["FAKE 999"] + + def test_file_not_found_raises(self, mock_span_ctx): + fn = CourseScheduleLookupOuter("/nonexistent/path.csv") + with pytest.raises(FileNotFoundError): + fn(["COMPSCI 101"]) + + def test_span_status_ok_on_success(self, mock_span_ctx, sample_classdata_csv): + fn = CourseScheduleLookupOuter(sample_classdata_csv) + fn(["COMPSCI 101"]) + calls = mock_span_ctx.set_status.call_args_list + assert any(c.args[0].status_code == StatusCode.OK for c in calls if c.args) + + def test_span_attributes_set(self, mock_span_ctx, sample_classdata_csv): + fn = CourseScheduleLookupOuter(sample_classdata_csv) + fn(["COMPSCI 101"]) + assert mock_span_ctx.set_attributes.called diff --git a/tests/test_course_schedule_ret.py b/tests/test_course_schedule_ret.py index d37b8459..b50d97f1 100644 --- a/tests/test_course_schedule_ret.py +++ b/tests/test_course_schedule_ret.py @@ -4,9 +4,6 @@ _parse_course, ) -CSV_PATH = "/tmp/cleaned_classdata.csv" -df = pd.read_csv(CSV_PATH) - def test_parse_course(): assert _parse_course("COMPSCI 101") == ("COMPSCI", "101") @@ -17,4 +14,10 @@ def test_parse_course(): def test_lookup(): + df = pd.DataFrame( + { + "Subject": ["COMPSCI", "MATH"], + "Catalog": ["101", "201"], + } + ) assert _lookup("ASTROLOGY 1239", df) == [] diff --git a/tests/test_llama_index_tools.py b/tests/test_llama_index_tools.py new file mode 100644 index 00000000..b8c6a118 --- /dev/null +++ b/tests/test_llama_index_tools.py @@ -0,0 +1,290 @@ +"""Tests for chatdku.core.tools.llama_index_tools (VectorRetrieverOuter, KeywordRetrieverOuter).""" + +from contextlib import contextmanager +from unittest.mock import MagicMock + +import pytest + +from chatdku.core.tools.retriever.base_retriever import NodeWithScore +from chatdku.core.tools.utils import QueryTimeoutError + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +SAMPLE_NODES = [ + NodeWithScore(node_id="1", text="doc one", metadata={"src": "a"}, score=0.9), + NodeWithScore(node_id="2", text="doc two", metadata={"src": "b"}, score=0.8), +] + + +@contextmanager +def fake_timeout(seconds=5): + """Drop-in replacement for the real timeout context manager.""" + + class FakeCtx: + def run(self, func, *args, **kwargs): + return func(*args, **kwargs) + + yield FakeCtx() + + +@contextmanager +def fake_timeout_that_expires(seconds=5): + """Simulates a timeout by raising QueryTimeoutError on .run().""" + + class FakeCtx: + def run(self, func, *args, **kwargs): + raise QueryTimeoutError(f"Query exceeded {seconds} second timeout") + + yield FakeCtx() + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture() +def _patch_vector_retriever(monkeypatch): + """Patch VectorRetriever class so no ChromaDB connection is needed.""" + mock_instance = MagicMock() + mock_instance.query_with_tell.return_value = SAMPLE_NODES + mock_cls = MagicMock(return_value=mock_instance) + monkeypatch.setattr( + "chatdku.core.tools.llama_index_tools.VectorRetriever", mock_cls + ) + return mock_instance + + +@pytest.fixture() +def _patch_keyword_retriever(monkeypatch): + """Patch KeywordRetriever class so no Redis connection is needed.""" + mock_instance = MagicMock() + mock_instance.query_with_tell.return_value = SAMPLE_NODES + mock_cls = MagicMock(return_value=mock_instance) + monkeypatch.setattr( + "chatdku.core.tools.llama_index_tools.KeywordRetriever", mock_cls + ) + return mock_instance + + +@pytest.fixture() +def _patch_rerank(monkeypatch): + """Patch the rerank function.""" + mock_rerank = MagicMock(return_value=SAMPLE_NODES[:1]) + monkeypatch.setattr("chatdku.core.tools.llama_index_tools.rerank", mock_rerank) + return mock_rerank + + +@pytest.fixture() +def _patch_timeout(monkeypatch): + """Replace the real timeout with a synchronous fake.""" + monkeypatch.setattr("chatdku.core.tools.llama_index_tools.timeout", fake_timeout) + + +@pytest.fixture() +def _patch_timeout_expires(monkeypatch): + """Replace the real timeout with one that always times out.""" + monkeypatch.setattr( + "chatdku.core.tools.llama_index_tools.timeout", fake_timeout_that_expires + ) + + +# --------------------------------------------------------------------------- +# VectorRetrieverOuter +# --------------------------------------------------------------------------- + + +class TestVectorRetrieverOuter: + @pytest.fixture(autouse=True) + def _setup( + self, + mock_get_current_span, + _patch_vector_retriever, + _patch_rerank, + _patch_timeout, + ): + self.mock_retriever = _patch_vector_retriever + self.mock_rerank = _patch_rerank + + def _make(self, **kwargs): + from chatdku.core.tools.llama_index_tools import VectorRetrieverOuter + + defaults = dict( + retriever_top_k=10, + use_reranker=False, + reranker_top_n=5, + user_id="Chat_DKU", + search_mode=0, + files=[], + ) + defaults.update(kwargs) + return VectorRetrieverOuter(**defaults) + + def test_returns_callable(self): + assert callable(self._make()) + + def test_query_returns_string(self): + fn = self._make() + result = fn("what is DKU?") + assert isinstance(result, str) + + def test_query_calls_retriever(self): + fn = self._make() + fn("what is DKU?") + self.mock_retriever.query_with_tell.assert_called_once() + + def test_with_reranker_calls_rerank(self): + fn = self._make(use_reranker=True) + fn("what is DKU?") + self.mock_rerank.assert_called_once() + + def test_without_reranker_skips_rerank(self): + fn = self._make(use_reranker=False) + fn("what is DKU?") + self.mock_rerank.assert_not_called() + + def test_invalid_search_mode_defaults_to_zero(self): + # Should not raise; logs a warning and defaults to 0 + fn = self._make(search_mode=5) + result = fn("test") + assert isinstance(result, str) + + def test_search_mode_nonzero_without_files_defaults(self): + # search_mode=1 but files=[] → should default to 0 + fn = self._make(search_mode=1, files=[]) + result = fn("test") + assert isinstance(result, str) + + def test_value_error_propagates(self): + self.mock_retriever.query_with_tell.side_effect = ValueError("bad input") + fn = self._make() + with pytest.raises(ValueError, match="bad input"): + fn("test") + + def test_retrieval_failure_raises_exception(self): + self.mock_retriever.query_with_tell.side_effect = RuntimeError( + "connection lost" + ) + fn = self._make() + with pytest.raises(Exception, match="Vector retrieval failed"): + fn("test") + + +class TestVectorRetrieverOuterTimeout: + def test_timeout_raises_exception( + self, + mock_get_current_span, + _patch_vector_retriever, + _patch_rerank, + _patch_timeout_expires, + ): + from chatdku.core.tools.llama_index_tools import VectorRetrieverOuter + + fn = VectorRetrieverOuter( + retriever_top_k=10, + use_reranker=False, + reranker_top_n=5, + user_id="Chat_DKU", + search_mode=0, + files=[], + ) + with pytest.raises(Exception, match="timed out"): + fn("test") + + +# --------------------------------------------------------------------------- +# KeywordRetrieverOuter +# --------------------------------------------------------------------------- + + +class TestKeywordRetrieverOuter: + @pytest.fixture(autouse=True) + def _setup( + self, + mock_get_current_span, + _patch_keyword_retriever, + _patch_rerank, + _patch_timeout, + ): + self.mock_retriever = _patch_keyword_retriever + self.mock_rerank = _patch_rerank + + def _make(self, **kwargs): + from chatdku.core.tools.llama_index_tools import KeywordRetrieverOuter + + defaults = dict( + retriever_top_k=10, + use_reranker=False, + reranker_top_n=5, + user_id="Chat_DKU", + search_mode=0, + files=[], + ) + defaults.update(kwargs) + return KeywordRetrieverOuter(**defaults) + + def test_returns_callable(self): + assert callable(self._make()) + + def test_query_string_returns_string(self): + fn = self._make() + result = fn("DKU courses") + assert isinstance(result, str) + + def test_query_list_converts_to_strings(self): + fn = self._make() + result = fn(["term1", 42, "term3"]) + assert isinstance(result, str) + # The function stringifies list items in-place + self.mock_retriever.query_with_tell.assert_called_once() + + def test_query_calls_retriever(self): + fn = self._make() + fn("test query") + self.mock_retriever.query_with_tell.assert_called_once() + + def test_with_reranker_calls_rerank(self): + fn = self._make(use_reranker=True) + fn("test") + self.mock_rerank.assert_called_once() + + def test_without_reranker_skips_rerank(self): + fn = self._make(use_reranker=False) + fn("test") + self.mock_rerank.assert_not_called() + + def test_invalid_search_mode_defaults_to_zero(self): + fn = self._make(search_mode=5) + result = fn("test") + assert isinstance(result, str) + + def test_retrieval_failure_raises_exception(self): + self.mock_retriever.query_with_tell.side_effect = RuntimeError("redis down") + fn = self._make() + with pytest.raises(Exception, match="Keyword retrieval failed"): + fn("test") + + +class TestKeywordRetrieverOuterTimeout: + def test_timeout_raises_exception( + self, + mock_get_current_span, + _patch_keyword_retriever, + _patch_rerank, + _patch_timeout_expires, + ): + from chatdku.core.tools.llama_index_tools import KeywordRetrieverOuter + + fn = KeywordRetrieverOuter( + retriever_top_k=10, + use_reranker=False, + reranker_top_n=5, + user_id="Chat_DKU", + search_mode=0, + files=[], + ) + with pytest.raises(Exception, match="Keyword retriever timeout"): + fn("test") diff --git a/tests/test_major_requirements.py b/tests/test_major_requirements.py new file mode 100644 index 00000000..b8f240b1 --- /dev/null +++ b/tests/test_major_requirements.py @@ -0,0 +1,133 @@ +"""Tests for chatdku.core.tools.major_requirements.""" + +import pytest +from opentelemetry.trace import StatusCode + +from chatdku.core.tools.major_requirements import ( + MajorRequirementsLookupOuter, + _best_match, + _list_stems, +) + + +# --------------------------------------------------------------------------- +# _best_match (pure — uses thefuzz token_set_ratio under the hood) +# --------------------------------------------------------------------------- + + +class TestBestMatch: + STEMS = [ + "data-science", + "computation-and-design-computer-science", + "behavioral-science-psychology", + "requirements-for-all-majors", + ] + + def test_exact_match(self): + assert _best_match("data science", self.STEMS) == "data-science" + + def test_partial_match(self): + result = _best_match("computer science", self.STEMS) + assert result == "computation-and-design-computer-science" + + def test_no_match_returns_none(self): + assert _best_match("astrology", self.STEMS) is None + + def test_empty_query_returns_none(self): + assert _best_match("", self.STEMS) is None + + def test_requirements_for_all(self): + result = _best_match("requirements for all majors", self.STEMS) + assert result == "requirements-for-all-majors" + + +# --------------------------------------------------------------------------- +# _list_stems +# --------------------------------------------------------------------------- + + +class TestListStems: + def test_returns_sorted_stems(self, tmp_path): + (tmp_path / "b-major.md").write_text("B") + (tmp_path / "a-major.md").write_text("A") + stems = _list_stems(tmp_path) + assert stems == ["a-major", "b-major"] + + def test_ignores_non_md_files(self, tmp_path): + (tmp_path / "readme.txt").write_text("text") + (tmp_path / "data.md").write_text("data") + stems = _list_stems(tmp_path) + assert stems == ["data"] + + def test_empty_dir(self, tmp_path): + assert _list_stems(tmp_path) == [] + + +# --------------------------------------------------------------------------- +# MajorRequirementsLookupOuter (needs mock_span_ctx + tmp dir with .md files) +# --------------------------------------------------------------------------- + + +@pytest.fixture() +def requirements_dir(tmp_path): + """Create a temporary requirements directory with sample .md files.""" + (tmp_path / "data-science.md").write_text( + "# Data Science\n\n- COMPSCI 101\n- STATS 202\n" + ) + (tmp_path / "computation-and-design-computer-science.md").write_text( + "# Computation and Design / Computer Science\n\n- COMPSCI 201\n" + ) + (tmp_path / "requirements-for-all-majors.md").write_text( + "# General Requirements\n\n- WRIT 101\n- MATH 101\n" + ) + return str(tmp_path) + + +class TestMajorRequirementsLookupOuter: + def test_returns_callable(self, mock_span_ctx, requirements_dir): + fn = MajorRequirementsLookupOuter(requirements_dir) + assert callable(fn) + + def test_list_returns_all_majors(self, mock_span_ctx, requirements_dir): + fn = MajorRequirementsLookupOuter(requirements_dir) + result = fn("list") + assert "data-science" in result + assert "computation-and-design-computer-science" in result + assert "requirements-for-all-majors" in result + + def test_lookup_returns_file_content(self, mock_span_ctx, requirements_dir): + fn = MajorRequirementsLookupOuter(requirements_dir) + result = fn("data science") + assert "COMPSCI 101" in result + assert "STATS 202" in result + + def test_lookup_prepends_requirements_header(self, mock_span_ctx, requirements_dir): + fn = MajorRequirementsLookupOuter(requirements_dir) + result = fn("data science") + assert result.startswith("# Requirements:") + + def test_no_match_returns_message(self, mock_span_ctx, requirements_dir): + fn = MajorRequirementsLookupOuter(requirements_dir) + result = fn("astrology") + assert "No matching major" in result + + def test_nonexistent_directory_raises(self, mock_span_ctx): + fn = MajorRequirementsLookupOuter("/nonexistent/path") + with pytest.raises(FileNotFoundError): + fn("data science") + + def test_empty_directory_raises(self, mock_span_ctx, tmp_path): + fn = MajorRequirementsLookupOuter(str(tmp_path)) + with pytest.raises(FileNotFoundError): + fn("data science") + + def test_span_status_ok_on_success(self, mock_span_ctx, requirements_dir): + fn = MajorRequirementsLookupOuter(requirements_dir) + fn("data science") + calls = mock_span_ctx.set_status.call_args_list + assert any(c.args[0].status_code == StatusCode.OK for c in calls if c.args) + + def test_span_attributes_set(self, mock_span_ctx, requirements_dir): + fn = MajorRequirementsLookupOuter(requirements_dir) + fn("data science") + assert mock_span_ctx.set_attributes.called diff --git a/tests/test_prerequisites.py b/tests/test_prerequisites.py new file mode 100644 index 00000000..71afdfe4 --- /dev/null +++ b/tests/test_prerequisites.py @@ -0,0 +1,85 @@ +"""Tests for chatdku.core.tools.get_prerequisites.""" + +import pytest +from opentelemetry.trace import StatusCode + +from chatdku.core.tools.get_prerequisites import PrerequisiteLookupOuter, get_prereq + + +# --------------------------------------------------------------------------- +# get_prereq (internal helper) +# --------------------------------------------------------------------------- + + +class TestGetPrereq: + def test_returns_prerequisite_description(self, sample_prereq_csv): + result = get_prereq("COMPSCI 201", sample_prereq_csv) + assert "(Source: DKUHub)" in result + assert "COMPSCI" in result + + def test_uses_latest_effective_date(self, sample_prereq_csv): + """Two rows for COMPSCI 201 — should pick the 09/01/2024 entry.""" + result = get_prereq("COMPSCI 201", sample_prereq_csv) + assert "COMPSCI 102" in result # only in the newer row + + def test_returns_not_found_for_unknown_course(self, sample_prereq_csv): + result = get_prereq("ASTRO 999", sample_prereq_csv) + assert "No prerequisites found" in result + + def test_empty_description_returns_not_found(self, sample_prereq_csv): + """BIOL 305 has an empty description in col 13.""" + result = get_prereq("BIOL 305", sample_prereq_csv) + assert "No prerequisites found" in result + + def test_file_not_found_raises(self): + with pytest.raises(FileNotFoundError): + get_prereq("COMPSCI 201", "/nonexistent/path.csv") + + def test_handles_extra_spaces_in_course_name(self, sample_prereq_csv): + result = get_prereq("COMPSCI 201", sample_prereq_csv) + # Should still parse — splits on underscore after space→underscore replacement + assert "COMPSCI" in result + + def test_known_course_with_prereqs(self, sample_prereq_csv): + result = get_prereq("MATH 201", sample_prereq_csv) + assert "MATH 101" in result + assert "(Source: DKUHub)" in result + + +# --------------------------------------------------------------------------- +# PrerequisiteLookupOuter (needs mock_span_ctx + sample CSV) +# --------------------------------------------------------------------------- + + +class TestPrerequisiteLookupOuter: + def test_returns_callable(self, mock_span_ctx, sample_prereq_csv): + fn = PrerequisiteLookupOuter(sample_prereq_csv) + assert callable(fn) + + def test_single_course_lookup(self, mock_span_ctx, sample_prereq_csv): + fn = PrerequisiteLookupOuter(sample_prereq_csv) + result = fn(["MATH 201"]) + assert "MATH 101" in result + + def test_multiple_courses_joined_by_newline(self, mock_span_ctx, sample_prereq_csv): + fn = PrerequisiteLookupOuter(sample_prereq_csv) + result = fn(["COMPSCI 201", "MATH 201"]) + assert "\n" in result + assert "COMPSCI" in result + assert "MATH" in result + + def test_file_not_found_propagates(self, mock_span_ctx): + fn = PrerequisiteLookupOuter("/nonexistent/path.csv") + with pytest.raises(FileNotFoundError): + fn(["COMPSCI 201"]) + + def test_span_status_ok_on_success(self, mock_span_ctx, sample_prereq_csv): + fn = PrerequisiteLookupOuter(sample_prereq_csv) + fn(["MATH 201"]) + calls = mock_span_ctx.set_status.call_args_list + assert any(c.args[0].status_code == StatusCode.OK for c in calls if c.args) + + def test_span_attributes_set(self, mock_span_ctx, sample_prereq_csv): + fn = PrerequisiteLookupOuter(sample_prereq_csv) + fn(["MATH 201"]) + assert mock_span_ctx.set_attributes.called diff --git a/tests/test_sql_agent.py b/tests/test_sql_agent.py index eebca60f..ab3d4b98 100644 --- a/tests/test_sql_agent.py +++ b/tests/test_sql_agent.py @@ -10,15 +10,15 @@ import pytest -from chatdku.core.tools.syllabi_tool.generate_sql import ( +from chatdku.core.tools.syllabi.generate_sql import ( _collapse_repeated_lines, _dedupe_lines, _truncate_long_output, ) # Import helpers directly after patching -from chatdku.core.tools.syllabi_tool.query_curriculum_db import ( - QueryCurriculumOuter, +from chatdku.core.tools.syllabi.syllabi_tool import ( + SyllabusLookupOuter, fetch_schema, ) @@ -27,7 +27,7 @@ # We patch setup() and use_phoenix() before importing the module. -query_curriculum_db = QueryCurriculumOuter() +query_curriculum_db = SyllabusLookupOuter() class TestCollapseRepeatedLines: @@ -149,7 +149,7 @@ def mock_db(monkeypatch): ] mock_db_cls = MagicMock(return_value=db_instance) monkeypatch.setattr( - "chatdku.core.tools.syllabi_tool.query_curriculum_db.DB", + "chatdku.core.tools.syllabi.query_curriculum_db.DB", mock_db_cls, ) return db_instance @@ -160,7 +160,7 @@ def mock_generate_sql(monkeypatch): sql_agent = MagicMock(return_value=FAKE_SQL) mock_cls = MagicMock(return_value=sql_agent) monkeypatch.setattr( - "chatdku.core.tools.syllabi_tool.query_curriculum_db.GenerateSQL", + "chatdku.core.tools.syllabi.query_curriculum_db.GenerateSQL", mock_cls, ) return sql_agent @@ -175,7 +175,7 @@ def mock_dspy_predict(monkeypatch): predictor_instance = MagicMock(return_value=fake_result) mock_predict_cls = MagicMock(return_value=predictor_instance) monkeypatch.setattr( - "chatdku.core.tools.syllabi_tool.query_curriculum_db.dspy.Predict", + "chatdku.core.tools.syllabi.query_curriculum_db.dspy.Predict", mock_predict_cls, ) return fake_result @@ -224,7 +224,7 @@ def test_sql_execution_error_handled_gracefully( Exception("DB is down"), # SQL execution fails ] monkeypatch.setattr( - "chatdku.core.tools.syllabi_tool.query_curriculum_db.DB", + "chatdku.core.tools.syllabi.query_curriculum_db.DB", MagicMock(return_value=db_instance), ) # Should not raise; error is caught and passed to the LM as text @@ -238,7 +238,7 @@ def test_think_section_stripped_from_result( fake_result.result = "internalClean answer." predictor_instance = MagicMock(return_value=fake_result) monkeypatch.setattr( - "chatdku.core.tools.syllabi_tool.query_curriculum_db.dspy.Predict", + "chatdku.core.tools.syllabi.query_curriculum_db.dspy.Predict", MagicMock(return_value=predictor_instance), ) result, internal = query_curriculum_db("Test query.", "Test query.") @@ -252,7 +252,7 @@ def test_repeated_lines_in_lm_output_collapsed( fake_result.result = "\n".join(["answer"] * 20) predictor_instance = MagicMock(return_value=fake_result) monkeypatch.setattr( - "chatdku.core.tools.syllabi_tool.query_curriculum_db.dspy.Predict", + "chatdku.core.tools.syllabi.query_curriculum_db.dspy.Predict", MagicMock(return_value=predictor_instance), ) result, internal = query_curriculum_db( @@ -266,7 +266,7 @@ def mock_db_outer(monkeypatch): db_instance = MagicMock() db_instance.execute.side_effect = [FAKE_SCHEMA_ROWS, FAKE_ROWS] monkeypatch.setattr( - "chatdku.core.tools.syllabi_tool.query_curriculum_db.DB", + "chatdku.core.tools.syllabi.query_curriculum_db.DB", MagicMock(return_value=db_instance), ) return db_instance @@ -276,7 +276,7 @@ def mock_db_outer(monkeypatch): def mock_generate_sql_outer(monkeypatch): sql_agent = MagicMock(return_value=FAKE_SQL) monkeypatch.setattr( - "chatdku.core.tools.syllabi_tool.query_curriculum_db.GenerateSQL", + "chatdku.core.tools.syllabi.query_curriculum_db.GenerateSQL", MagicMock(return_value=sql_agent), ) return sql_agent @@ -288,7 +288,7 @@ def mock_dspy_predict_outer(monkeypatch): fake_result.result = "Two math courses found." predictor_instance = MagicMock(return_value=fake_result) monkeypatch.setattr( - "chatdku.core.tools.syllabi_tool.query_curriculum_db.dspy.Predict", + "chatdku.core.tools.syllabi.query_curriculum_db.dspy.Predict", MagicMock(return_value=predictor_instance), ) return fake_result @@ -298,7 +298,7 @@ class TestQueryCurriculumOuter: def test_returns_tuple( self, mock_db_outer, mock_generate_sql_outer, mock_dspy_predict_outer ): - fn = QueryCurriculumOuter() + fn = SyllabusLookupOuter() result = fn("What courses are there?", "What courses are there?") assert isinstance(result, tuple) assert len(result) == 2 @@ -306,14 +306,14 @@ def test_returns_tuple( def test_result_is_string( self, mock_db_outer, mock_generate_sql_outer, mock_dspy_predict_outer ): - fn = QueryCurriculumOuter() + fn = SyllabusLookupOuter() result_str, internal = fn("List CS courses.", "List CS courses.") assert isinstance(result_str, str) def test_internal_result_contains_sql( self, mock_db_outer, mock_generate_sql_outer, mock_dspy_predict_outer ): - fn = QueryCurriculumOuter() + fn = SyllabusLookupOuter() _, internal = fn("Find courses.", "Find courses.") assert "sql" in internal assert isinstance(internal["sql"], str) @@ -325,16 +325,16 @@ def test_sql_error_reflected_in_output(self, monkeypatch, mock_generate_sql_oute Exception("timeout"), ] monkeypatch.setattr( - "chatdku.core.tools.syllabi_tool.query_curriculum_db.DB", + "chatdku.core.tools.syllabi.query_curriculum_db.DB", MagicMock(return_value=db_instance), ) fake_result = MagicMock() fake_result.result = "SQL execution error: timeout" monkeypatch.setattr( - "chatdku.core.tools.syllabi_tool.query_curriculum_db.dspy.Predict", + "chatdku.core.tools.syllabi.query_curriculum_db.dspy.Predict", MagicMock(return_value=MagicMock(return_value=fake_result)), ) - fn = QueryCurriculumOuter() + fn = SyllabusLookupOuter() result_str, _ = fn("Any query.", "Any query.") assert isinstance(result_str, str) @@ -344,10 +344,10 @@ def test_think_section_stripped( fake_result = MagicMock() fake_result.result = "skipReal answer." monkeypatch.setattr( - "chatdku.core.tools.syllabi_tool.query_curriculum_db.dspy.Predict", + "chatdku.core.tools.syllabi.query_curriculum_db.dspy.Predict", MagicMock(return_value=MagicMock(return_value=fake_result)), ) - fn = QueryCurriculumOuter() + fn = SyllabusLookupOuter() result_str, _ = fn("Q.", "Q.") assert "" not in result_str assert "Real answer." in result_str @@ -360,10 +360,10 @@ def test_no_tracer_attribute_on_config( monkeypatch, ): """Runs without error when config has no tracer (uses nullcontext).""" - import chatdku.core.tools.syllabi_tool.query_curriculum_db as mod + import chatdku.core.tools.syllabi.syllabi_tool as mod fake_config = MagicMock(spec=[]) # no tracer attribute monkeypatch.setattr(mod, "config", fake_config) - fn = QueryCurriculumOuter() + fn = SyllabusLookupOuter() result_str, internal = fn("Any query.", "Any query.") assert isinstance(result_str, str) diff --git a/utils/startup_timer.py b/utils/startup_timer.py new file mode 100644 index 00000000..689920f0 --- /dev/null +++ b/utils/startup_timer.py @@ -0,0 +1,80 @@ +#!/usr/bin/env python3 +"""Startup timing diagnostic — identifies slow initialization steps.""" +import os +import time + +os.environ.setdefault("LITELLM_LOCAL_MODEL_COST_MAP", "True") + +_t0 = time.perf_counter() + + +def lap(label: str, t_prev: float) -> float: + t = time.perf_counter() + print(f" {t - t_prev:6.2f}s {label}") + return t + + +print("=== ChatDKU startup timer ===") +t = _t0 + +# --- imports --- +print("\n[imports]") + +import dspy + +t = lap("import dspy", t) # noqa: E402,E401 +from chatdku.config import config + +t = lap("import config", t) # noqa: E402,E401 + +from chatdku.core.tools.retriever.keyword_retriever import KeywordRetriever + +t = lap("import KeywordRetriever (+ NLTK check)", t) # noqa: E402,E401,E501 +from chatdku.core.tools.retriever.vector_retriever import VectorRetriever + +t = lap("import VectorRetriever (+ chromadb)", t) # noqa: E402,E401,E501 +from chatdku.core.tools.major_requirements import MajorRequirementsLookupOuter + +t = lap("import MajorRequirementsLookupOuter", t) # noqa: E402,E401,E501 +from chatdku.core.tools.syllabi.syllabi_tool import SyllabusLookupOuter + +t = lap("import QueryCurriculumOuter (+ DB)", t) # noqa: E402,E401,E501 +t = lap("import PrerequisiteLookupOuter", t) # noqa: E402,E401,E501 +t = lap("import CourseScheduleLookupOuter", t) # noqa: E402,E401,E501 +from chatdku.setup import setup, use_phoenix + +t = lap("import setup, use_phoenix", t) # noqa: E402,E401 + +# --- initialization --- +print("\n[initialization]") + +setup() +t = lap("setup() — embed model + tokenizer", t) +use_phoenix() +t = lap("use_phoenix() — OTel register", t) + +lm = dspy.LM( + model="openai/" + config.backup_llm, + api_base=config.backup_llm_url, + api_key=config.llm_api_key, + model_type="chat", + max_tokens=config.output_window, + temperature=config.llm_temperature, +) +dspy.configure(lm=lm) +t = lap("dspy.LM() + configure()", t) + +user_id = "Chat_DKU" +KeywordRetriever(retriever_top_k=10, user_id=user_id, search_mode=0, files=[]) +t = lap("KeywordRetriever() init", t) + +VectorRetriever(retriever_top_k=10, user_id=user_id, search_mode=0, files=[]) +t = lap("VectorRetriever() init", t) + +MajorRequirementsLookupOuter(config.major_requirements_dir) +t = lap("MajorRequirementsLookupOuter() init", t) + +SyllabusLookupOuter() +t = lap("QueryCurriculumOuter() init (DB connect + schema fetch)", t) + +print(f"\n=== total: {t - _t0:.2f}s ===")