Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file added backend/__init__.py
Empty file.
4 changes: 2 additions & 2 deletions backend/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@

from langgraph.graph import StateGraph, END

from ks_search_tool import general_search, general_search_async, global_fuzzy_keyword_search
from retrieval import Retriever
from backend.ks_search_tool import general_search, general_search_async, global_fuzzy_keyword_search
from backend.retrieval import Retriever

# LLM (Gemini) client setup
try:
Expand Down
63 changes: 43 additions & 20 deletions backend/main.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,24 @@
# main.py
import asyncio
import json
import os
import time
import asyncio
from typing import Optional, Dict, Any
from datetime import datetime
from datetime import UTC, datetime
from typing import Any, Dict, Optional

import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
import uvicorn
import json

from agents import NeuroscienceAssistant
# from agents import NeuroscienceAssistant
from backend.agents import NeuroscienceAssistant

from slowapi import Limiter
from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded
from fastapi.responses import JSONResponse

load_dotenv()

Expand All @@ -23,6 +29,17 @@
version="2.0.0",
)

limiter = Limiter(key_func=get_remote_address)
app.state.limiter = limiter


@app.exception_handler(RateLimitExceeded)
async def rate_limit_handler(request: Request, exc: RateLimitExceeded):
    """Translate slowapi's RateLimitExceeded into a JSON 429 response."""
    body = {"detail": "Too many requests. Please slow down."}
    return JSONResponse(status_code=429, content=body)

app.add_middleware(
CORSMiddleware,
allow_origins=[o.strip() for o in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")],
Expand All @@ -34,12 +51,11 @@
# Initialize the assistant with vector search agent on startup
assistant = NeuroscienceAssistant()


# Models
class ChatMessage(BaseModel):
query: str = Field(..., description="The user's query")
session_id: Optional[str] = Field(
default="default", description="Session ID"
)
session_id: Optional[str] = Field(default="default", description="Session ID")
reset: Optional[bool] = Field(
default=False,
description="If true, clears server-side session history before handling the message",
Expand All @@ -53,10 +69,13 @@ class ChatResponse(BaseModel):

# Lightweight health helpers


def _vector_check_sync() -> bool:

try:
from retrieval import Retriever # local import to avoid import penalty on startup
        from backend.retrieval import \
            Retriever  # local import to avoid import penalty on startup

r = Retriever()
return bool(getattr(r, "is_enabled", False))
except Exception:
Expand All @@ -65,6 +84,7 @@ def _vector_check_sync() -> bool:

# Routes


@app.get("/", tags=["General"])
async def root():
    """Liveness banner for the API root."""
    payload = {
        "message": "KnowledgeSpace AI Backend is running",
        "version": "2.0.0",
    }
    return payload
Expand All @@ -75,7 +95,7 @@ async def health_check():
"""Cheap health for Docker healthcheck / load balancers."""
return {
"status": "healthy",
"timestamp": datetime.utcnow().isoformat(),
"timestamp": datetime.now(UTC).isoformat(),
"service": "knowledge-space-agent-backend",
"version": "2.0.0",
}
Expand All @@ -98,19 +118,24 @@ async def health():

components = {
"vector_search": "enabled" if vector_enabled else "disabled",
"llm": "enabled" if (os.getenv("GOOGLE_API_KEY") or os.getenv("GCP_PROJECT_ID")) else "disabled",
"llm": (
"enabled"
if (os.getenv("GOOGLE_API_KEY") or os.getenv("GCP_PROJECT_ID"))
else "disabled"
),
"keyword_search": "enabled",
}
return {
"status": "healthy",
"version": "2.0.0",
"components": components,
"timestamp": datetime.utcnow().isoformat(),
"timestamp": datetime.now(UTC).isoformat(),
}


@app.post("/api/chat", response_model=ChatResponse, tags=["Chat"])
async def chat_endpoint(msg: ChatMessage):
@limiter.limit("10/minute")
async def chat_endpoint(request: Request, msg: ChatMessage):
try:
start_time = time.time()
response_text = await assistant.handle_chat(
Expand All @@ -122,7 +147,7 @@ async def chat_endpoint(msg: ChatMessage):
metadata = {
"process_time": process_time,
"session_id": msg.session_id,
"timestamp": datetime.utcnow().isoformat(),
"timestamp": datetime.now(UTC).isoformat(),
"reset": bool(msg.reset),
}
return ChatResponse(response=response_text, metadata=metadata)
Expand All @@ -138,8 +163,6 @@ async def chat_endpoint(msg: ChatMessage):
)




@app.post("/api/session/reset", tags=["Chat"])
async def reset_session(payload: Dict[str, str]):
sid = (payload or {}).get("session_id") or "default"
Expand All @@ -154,7 +177,7 @@ async def reset_session(payload: Dict[str, str]):
"main:app",
host=os.getenv("HOST", "0.0.0.0"),
port=int(os.getenv("PORT", "8000")),
reload=True,
reload=True,
log_level="info",
proxy_headers=True,
)
9 changes: 9 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ dependencies = [
"torch>=2.7.1",
"tqdm>=4.67.1",
"uvicorn>=0.34.3",
"slowapi = ^0.1.9"
]

[project.optional-dependencies]
Expand All @@ -38,3 +39,11 @@ dev = [
"flake8>=6.0.0",
"mypy>=1.0.0",
]


[tool.pytest.ini_options]
pythonpath = ["."]
testpaths = ["tests"]

[tool.setuptools.packages.find]
include = ["backend*"]
9 changes: 9 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import pytest
from fastapi.testclient import TestClient

from backend.main import app


@pytest.fixture
def client():
    """Provide a FastAPI test client bound to the backend application."""
    test_client = TestClient(app)
    return test_client
42 changes: 42 additions & 0 deletions tests/test_chat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# from unittest.mock import patch


# def test_chat_endpoint_success(client):
# with patch("backend.main.assistant") as mock_assistant:
# mock_assistant.chat.return_value = "Mocked response"

# response = client.post(
# "/api/chat",
# json={"query": "Hello"}
# )

# assert response.status_code == 200
# assert response.json()["response"] == "Mocked response"


# def test_chat_endpoint_validation_error(client):
# response = client.post(
# "/api/chat",
# json={}
# )

# assert response.status_code == 422


from unittest.mock import AsyncMock, patch


def test_chat_endpoint_success(client):
    """A valid payload returns 200 with the (mocked) assistant reply."""
    fake_assistant = AsyncMock()
    fake_assistant.handle_chat.return_value = "Mocked response"

    with patch("backend.main.assistant", new=fake_assistant):
        response = client.post("/api/chat", json={"query": "Hello"})

    assert response.status_code == 200
    assert response.json()["response"] == "Mocked response"


def test_chat_endpoint_validation_error(client):
response = client.post("/api/chat", json={})

assert response.status_code == 422
10 changes: 10 additions & 0 deletions tests/test_general.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
def test_health_endpoint(client):
response = client.get("/health")
assert response.status_code == 200
assert "status" in response.json()


def test_api_health_endpoint(client):
response = client.get("/api/health")
assert response.status_code == 200
assert "status" in response.json()
69 changes: 69 additions & 0 deletions tests/test_health.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# import sys
# import os
# from unittest.mock import MagicMock

# # Add backend folder to path
# sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "backend")))

# # Mock assistant BEFORE importing main
# import builtins
# import types

# mock_agents = types.ModuleType("agents")
# mock_agents.NeuroscienceAssistant = MagicMock
# sys.modules["agents"] = mock_agents

# from main import app
# from fastapi.testclient import TestClient

# client = TestClient(app)


# def test_root_endpoint():
# response = client.get("/")
# assert response.status_code == 200
# assert "KnowledgeSpace AI Backend is running" in response.json()["message"]


# def test_health_endpoint():
# response = client.get("/health")
# assert response.status_code == 200
# data = response.json()
# assert data["status"] == "healthy"
# assert "timestamp" in data

# def test_api_health_endpoint():
# response = client.get("/api/health")
# assert response.status_code == 200

# data = response.json()
# assert data["status"] == "healthy"
# assert "components" in data
# assert "version" in data

# def test_chat_endpoint_success(monkeypatch):
# async def mock_handle_chat(session_id, query, reset):
# return "Mocked response"

# # Patch the assistant instance inside main
# from main import assistant
# monkeypatch.setattr(assistant, "handle_chat", mock_handle_chat)

# response = client.post(
# "/api/chat",
# json={"query": "What is neuroscience?", "session_id": "test123"}
# )

# assert response.status_code == 200
# data = response.json()
# assert data["response"] == "Mocked response"
# assert "metadata" in data


# def test_chat_validation_error():
# response = client.post(
# "/api/chat",
# json={"session_id": "abc"} # Missing required 'query'
# )

# assert response.status_code == 422