diff --git a/backend/__init__.py b/backend/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/agents.py b/backend/agents.py index 80743e4..aad35c2 100644 --- a/backend/agents.py +++ b/backend/agents.py @@ -8,8 +8,8 @@ from langgraph.graph import StateGraph, END -from ks_search_tool import general_search, general_search_async, global_fuzzy_keyword_search -from retrieval import Retriever +from backend.ks_search_tool import general_search, general_search_async, global_fuzzy_keyword_search +from backend.retrieval import Retriever # LLM (Gemini) client setup try: diff --git a/backend/main.py b/backend/main.py index 9021161..d89a41b 100644 --- a/backend/main.py +++ b/backend/main.py @@ -1,18 +1,24 @@ # main.py +import asyncio +import json import os import time -import asyncio -from typing import Optional, Dict, Any -from datetime import datetime +from datetime import UTC, datetime +from typing import Any, Dict, Optional +import uvicorn from dotenv import load_dotenv -from fastapi import FastAPI, HTTPException +from fastapi import FastAPI, HTTPException, Request from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel, Field -import uvicorn -import json -from agents import NeuroscienceAssistant +# from agents import NeuroscienceAssistant +from backend.agents import NeuroscienceAssistant + +from slowapi import Limiter +from slowapi.util import get_remote_address +from slowapi.errors import RateLimitExceeded +from fastapi.responses import JSONResponse load_dotenv() @@ -23,6 +29,17 @@ version="2.0.0", ) +limiter = Limiter(key_func=get_remote_address) +app.state.limiter = limiter + + +@app.exception_handler(RateLimitExceeded) +async def rate_limit_handler(request: Request, exc: RateLimitExceeded): + return JSONResponse( + status_code=429, + content={"detail": "Too many requests. 
Please slow down."}, + ) + app.add_middleware( CORSMiddleware, allow_origins=[o.strip() for o in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")], @@ -34,12 +51,11 @@ # Initialize the assistant with vector search agent on startup assistant = NeuroscienceAssistant() + # Models class ChatMessage(BaseModel): query: str = Field(..., description="The user's query") - session_id: Optional[str] = Field( - default="default", description="Session ID" - ) + session_id: Optional[str] = Field(default="default", description="Session ID") reset: Optional[bool] = Field( default=False, description="If true, clears server-side session history before handling the message", @@ -53,10 +69,13 @@ class ChatResponse(BaseModel): # Lightweight health helpers + def _vector_check_sync() -> bool: - + try: - from retrieval import Retriever # local import to avoid import penalty on startup + from backend.retrieval import \ + Retriever # local import to avoid import penalty on startup + r = Retriever() return bool(getattr(r, "is_enabled", False)) except Exception: @@ -65,6 +84,7 @@ def _vector_check_sync() -> bool: # Routes + @app.get("/", tags=["General"]) async def root(): return {"message": "KnowledgeSpace AI Backend is running", "version": "2.0.0"} @@ -75,7 +95,7 @@ async def health_check(): """Cheap health for Docker healthcheck / load balancers.""" return { "status": "healthy", - "timestamp": datetime.utcnow().isoformat(), + "timestamp": datetime.now(UTC).isoformat(), "service": "knowledge-space-agent-backend", "version": "2.0.0", } @@ -98,19 +118,24 @@ async def health(): components = { "vector_search": "enabled" if vector_enabled else "disabled", - "llm": "enabled" if (os.getenv("GOOGLE_API_KEY") or os.getenv("GCP_PROJECT_ID")) else "disabled", + "llm": ( + "enabled" + if (os.getenv("GOOGLE_API_KEY") or os.getenv("GCP_PROJECT_ID")) + else "disabled" + ), "keyword_search": "enabled", } return { "status": "healthy", "version": "2.0.0", "components": components, - "timestamp": 
datetime.utcnow().isoformat(), + "timestamp": datetime.now(UTC).isoformat(), } @app.post("/api/chat", response_model=ChatResponse, tags=["Chat"]) -async def chat_endpoint(msg: ChatMessage): +@limiter.limit("10/minute") +async def chat_endpoint(request: Request, msg: ChatMessage): try: start_time = time.time() response_text = await assistant.handle_chat( @@ -122,7 +147,7 @@ async def chat_endpoint(msg: ChatMessage): metadata = { "process_time": process_time, "session_id": msg.session_id, - "timestamp": datetime.utcnow().isoformat(), + "timestamp": datetime.now(UTC).isoformat(), "reset": bool(msg.reset), } return ChatResponse(response=response_text, metadata=metadata) @@ -138,8 +163,6 @@ async def chat_endpoint(msg: ChatMessage): ) - - @app.post("/api/session/reset", tags=["Chat"]) async def reset_session(payload: Dict[str, str]): sid = (payload or {}).get("session_id") or "default" @@ -154,7 +177,7 @@ async def reset_session(payload: Dict[str, str]): "main:app", host=os.getenv("HOST", "0.0.0.0"), port=int(os.getenv("PORT", "8000")), - reload=True, + reload=True, log_level="info", proxy_headers=True, ) diff --git a/pyproject.toml b/pyproject.toml index 67d020c..1dc69a9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,7 @@ dependencies = [ "torch>=2.7.1", "tqdm>=4.67.1", "uvicorn>=0.34.3", + "slowapi>=0.1.9", ] [project.optional-dependencies] @@ -38,3 +39,11 @@ dev = [ "flake8>=6.0.0", "mypy>=1.0.0", ] + + +[tool.pytest.ini_options] +pythonpath = ["."] +testpaths = ["tests"] + +[tool.setuptools.packages.find] +include = ["backend*"] \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..0b090fe --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,9 @@ +import pytest +from fastapi.testclient import TestClient + +from backend.main import app + + +@pytest.fixture +def client(): + return TestClient(app) diff --git a/tests/test_chat.py b/tests/test_chat.py new file mode 100644 index 0000000..1580176 --- 
/dev/null +++ b/tests/test_chat.py @@ -0,0 +1,42 @@ +# from unittest.mock import patch + + +# def test_chat_endpoint_success(client): +# with patch("backend.main.assistant") as mock_assistant: +# mock_assistant.chat.return_value = "Mocked response" + +# response = client.post( +# "/api/chat", +# json={"query": "Hello"} +# ) + +# assert response.status_code == 200 +# assert response.json()["response"] == "Mocked response" + + +# def test_chat_endpoint_validation_error(client): +# response = client.post( +# "/api/chat", +# json={} +# ) + +# assert response.status_code == 422 + + +from unittest.mock import AsyncMock, patch + + +def test_chat_endpoint_success(client): + with patch("backend.main.assistant") as mock_assistant: + mock_assistant.handle_chat = AsyncMock(return_value="Mocked response") + + response = client.post("/api/chat", json={"query": "Hello"}) + + assert response.status_code == 200 + assert response.json()["response"] == "Mocked response" + + +def test_chat_endpoint_validation_error(client): + response = client.post("/api/chat", json={}) + + assert response.status_code == 422 diff --git a/tests/test_general.py b/tests/test_general.py new file mode 100644 index 0000000..a068540 --- /dev/null +++ b/tests/test_general.py @@ -0,0 +1,10 @@ +def test_health_endpoint(client): + response = client.get("/health") + assert response.status_code == 200 + assert "status" in response.json() + + +def test_api_health_endpoint(client): + response = client.get("/api/health") + assert response.status_code == 200 + assert "status" in response.json() diff --git a/tests/test_health.py b/tests/test_health.py new file mode 100644 index 0000000..f466804 --- /dev/null +++ b/tests/test_health.py @@ -0,0 +1,69 @@ +# import sys +# import os +# from unittest.mock import MagicMock + +# # Add backend folder to path +# sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "backend"))) + +# # Mock assistant BEFORE importing main +# import builtins +# import types 
+ +# mock_agents = types.ModuleType("agents") +# mock_agents.NeuroscienceAssistant = MagicMock +# sys.modules["agents"] = mock_agents + +# from main import app +# from fastapi.testclient import TestClient + +# client = TestClient(app) + + +# def test_root_endpoint(): +# response = client.get("/") +# assert response.status_code == 200 +# assert "KnowledgeSpace AI Backend is running" in response.json()["message"] + + +# def test_health_endpoint(): +# response = client.get("/health") +# assert response.status_code == 200 +# data = response.json() +# assert data["status"] == "healthy" +# assert "timestamp" in data + +# def test_api_health_endpoint(): +# response = client.get("/api/health") +# assert response.status_code == 200 + +# data = response.json() +# assert data["status"] == "healthy" +# assert "components" in data +# assert "version" in data + +# def test_chat_endpoint_success(monkeypatch): +# async def mock_handle_chat(session_id, query, reset): +# return "Mocked response" + +# # Patch the assistant instance inside main +# from main import assistant +# monkeypatch.setattr(assistant, "handle_chat", mock_handle_chat) + +# response = client.post( +# "/api/chat", +# json={"query": "What is neuroscience?", "session_id": "test123"} +# ) + +# assert response.status_code == 200 +# data = response.json() +# assert data["response"] == "Mocked response" +# assert "metadata" in data + + +# def test_chat_validation_error(): +# response = client.post( +# "/api/chat", +# json={"session_id": "abc"} # Missing required 'query' +# ) + +# assert response.status_code == 422 \ No newline at end of file