diff --git a/backend/__init__.py b/backend/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/main.py b/backend/main.py index 9021161..76ce400 100644 --- a/backend/main.py +++ b/backend/main.py @@ -1,18 +1,19 @@ # main.py +import asyncio +import json import os import time -import asyncio -from typing import Optional, Dict, Any -from datetime import datetime +from datetime import UTC, datetime +from typing import Any, Dict, Optional +import uvicorn from dotenv import load_dotenv from fastapi import FastAPI, HTTPException from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel, Field -import uvicorn -import json -from agents import NeuroscienceAssistant +# from agents import NeuroscienceAssistant +from backend.agents import NeuroscienceAssistant load_dotenv() @@ -34,12 +35,11 @@ # Initialize the assistant with vector search agent on startup assistant = NeuroscienceAssistant() + # Models class ChatMessage(BaseModel): query: str = Field(..., description="The user's query") - session_id: Optional[str] = Field( - default="default", description="Session ID" - ) + session_id: Optional[str] = Field(default="default", description="Session ID") reset: Optional[bool] = Field( default=False, description="If true, clears server-side session history before handling the message", @@ -53,10 +53,13 @@ class ChatResponse(BaseModel): # Lightweight health helpers + def _vector_check_sync() -> bool: - + try: - from retrieval import Retriever # local import to avoid import penalty on startup + from retrieval import \ + Retriever # local import to avoid import penalty on startup + r = Retriever() return bool(getattr(r, "is_enabled", False)) except Exception: @@ -65,6 +68,7 @@ def _vector_check_sync() -> bool: # Routes + @app.get("/", tags=["General"]) async def root(): return {"message": "KnowledgeSpace AI Backend is running", "version": "2.0.0"} @@ -75,7 +79,7 @@ async def health_check(): """Cheap health for Docker healthcheck / load balancers.""" return { "status": "healthy", - "timestamp": datetime.utcnow().isoformat(), + "timestamp": datetime.now(UTC).isoformat(), "service": "knowledge-space-agent-backend", "version": "2.0.0", } @@ -98,14 +102,18 @@ async def health(): components = { "vector_search": "enabled" if vector_enabled else "disabled", - "llm": "enabled" if (os.getenv("GOOGLE_API_KEY") or os.getenv("GCP_PROJECT_ID")) else "disabled", + "llm": ( + "enabled" + if (os.getenv("GOOGLE_API_KEY") or os.getenv("GCP_PROJECT_ID")) + else "disabled" + ), "keyword_search": "enabled", } return { "status": "healthy", "version": "2.0.0", "components": components, - "timestamp": datetime.utcnow().isoformat(), + "timestamp": datetime.now(UTC).isoformat(), } @@ -122,7 +130,7 @@ async def chat_endpoint(msg: ChatMessage): metadata = { "process_time": process_time, "session_id": msg.session_id, - "timestamp": datetime.utcnow().isoformat(), + "timestamp": datetime.now(UTC).isoformat(), "reset": bool(msg.reset), } return ChatResponse(response=response_text, metadata=metadata) @@ -138,8 +146,6 @@ async def chat_endpoint(msg: ChatMessage): ) - - @app.post("/api/session/reset", tags=["Chat"]) async def reset_session(payload: Dict[str, str]): sid = (payload or {}).get("session_id") or "default" @@ -154,7 +160,7 @@ async def reset_session(payload: Dict[str, str]): "main:app", host=os.getenv("HOST", "0.0.0.0"), port=int(os.getenv("PORT", "8000")), - reload=True, + reload=True, log_level="info", proxy_headers=True, ) diff --git a/pyproject.toml b/pyproject.toml index 67d020c..84b47ef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,3 +38,11 @@ dev = [ "flake8>=6.0.0", "mypy>=1.0.0", ] + + +[tool.pytest.ini_options] +pythonpath = ["."] +testpaths = ["tests"] + +[tool.setuptools.packages.find] +include = ["backend*"] \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..0b090fe --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,9 @@ +import pytest +from fastapi.testclient import TestClient + +from backend.main import app + + +@pytest.fixture +def client(): + return TestClient(app) diff --git a/tests/test_chat.py b/tests/test_chat.py new file mode 100644 index 0000000..1580176 --- /dev/null +++ b/tests/test_chat.py @@ -0,0 +1,42 @@ +# from unittest.mock import patch + + +# def test_chat_endpoint_success(client): +# with patch("backend.main.assistant") as mock_assistant: +# mock_assistant.chat.return_value = "Mocked response" + +# response = client.post( +# "/api/chat", +# json={"query": "Hello"} +# ) + +# assert response.status_code == 200 +# assert response.json()["response"] == "Mocked response" + + +# def test_chat_endpoint_validation_error(client): +# response = client.post( +# "/api/chat", +# json={} +# ) + +# assert response.status_code == 422 + + +from unittest.mock import AsyncMock, patch + + +def test_chat_endpoint_success(client): + with patch("backend.main.assistant") as mock_assistant: + mock_assistant.handle_chat = AsyncMock(return_value="Mocked response") + + response = client.post("/api/chat", json={"query": "Hello"}) + + assert response.status_code == 200 + assert response.json()["response"] == "Mocked response" + + +def test_chat_endpoint_validation_error(client): + response = client.post("/api/chat", json={}) + + assert response.status_code == 422 diff --git a/tests/test_general.py b/tests/test_general.py new file mode 100644 index 0000000..a068540 --- /dev/null +++ b/tests/test_general.py @@ -0,0 +1,10 @@ +def test_health_endpoint(client): + response = client.get("/health") + assert response.status_code == 200 + assert "status" in response.json() + + +def test_api_health_endpoint(client): + response = client.get("/api/health") + assert response.status_code == 200 + assert "status" in response.json() diff --git a/tests/test_health.py b/tests/test_health.py new file mode 100644 index 0000000..03182d2 --- /dev/null +++ b/tests/test_health.py @@ -0,0 +1,72 @@ +import os +import sys +from unittest.mock import MagicMock + +# Add backend folder to path +sys.path.append( + os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "backend")) +) + +# Mock assistant BEFORE importing main +import builtins +import types + +mock_agents = types.ModuleType("agents") +mock_agents.NeuroscienceAssistant = MagicMock +sys.modules["agents"] = mock_agents + +from fastapi.testclient import TestClient +from main import app + +client = TestClient(app) + + +def test_root_endpoint(): + response = client.get("/") + assert response.status_code == 200 + assert "KnowledgeSpace AI Backend is running" in response.json()["message"] + + +def test_health_endpoint(): + response = client.get("/health") + assert response.status_code == 200 + data = response.json() + assert data["status"] == "healthy" + assert "timestamp" in data + + +def test_api_health_endpoint(): + response = client.get("/api/health") + assert response.status_code == 200 + + data = response.json() + assert data["status"] == "healthy" + assert "components" in data + assert "version" in data + + +def test_chat_endpoint_success(monkeypatch): + async def mock_handle_chat(session_id, query, reset): + return "Mocked response" + + # Patch the assistant instance inside main + from main import assistant + + monkeypatch.setattr(assistant, "handle_chat", mock_handle_chat) + + response = client.post( + "/api/chat", json={"query": "What is neuroscience?", "session_id": "test123"} + ) + + assert response.status_code == 200 + data = response.json() + assert data["response"] == "Mocked response" + assert "metadata" in data + + +def test_chat_validation_error(): + response = client.post( + "/api/chat", json={"session_id": "abc"} # Missing required 'query' + ) + + assert response.status_code == 422