Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file added backend/__init__.py
Empty file.
42 changes: 24 additions & 18 deletions backend/main.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
# main.py
import asyncio
import json
import os
import time
import asyncio
from typing import Optional, Dict, Any
from datetime import datetime
from datetime import UTC, datetime
from typing import Any, Dict, Optional

import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
import uvicorn
import json

from agents import NeuroscienceAssistant
# from agents import NeuroscienceAssistant
from backend.agents import NeuroscienceAssistant

load_dotenv()

Expand All @@ -34,12 +35,11 @@
# Initialize the assistant with vector search agent on startup
assistant = NeuroscienceAssistant()


# Models
class ChatMessage(BaseModel):
query: str = Field(..., description="The user's query")
session_id: Optional[str] = Field(
default="default", description="Session ID"
)
session_id: Optional[str] = Field(default="default", description="Session ID")
reset: Optional[bool] = Field(
default=False,
description="If true, clears server-side session history before handling the message",
Expand All @@ -53,10 +53,13 @@ class ChatResponse(BaseModel):

# Lightweight health helpers


def _vector_check_sync() -> bool:

try:
from retrieval import Retriever # local import to avoid import penalty on startup
from retrieval import \
Retriever # local import to avoid import penalty on startup

r = Retriever()
return bool(getattr(r, "is_enabled", False))
except Exception:
Expand All @@ -65,6 +68,7 @@ def _vector_check_sync() -> bool:

# Routes


@app.get("/", tags=["General"])
async def root():
    """Liveness banner: confirm the backend is running and report its version."""
    payload = {
        "message": "KnowledgeSpace AI Backend is running",
        "version": "2.0.0",
    }
    return payload
Expand All @@ -75,7 +79,7 @@ async def health_check():
"""Cheap health for Docker healthcheck / load balancers."""
return {
"status": "healthy",
"timestamp": datetime.utcnow().isoformat(),
"timestamp": datetime.now(UTC).isoformat(),
"service": "knowledge-space-agent-backend",
"version": "2.0.0",
}
Expand All @@ -98,14 +102,18 @@ async def health():

components = {
"vector_search": "enabled" if vector_enabled else "disabled",
"llm": "enabled" if (os.getenv("GOOGLE_API_KEY") or os.getenv("GCP_PROJECT_ID")) else "disabled",
"llm": (
"enabled"
if (os.getenv("GOOGLE_API_KEY") or os.getenv("GCP_PROJECT_ID"))
else "disabled"
),
"keyword_search": "enabled",
}
return {
"status": "healthy",
"version": "2.0.0",
"components": components,
"timestamp": datetime.utcnow().isoformat(),
"timestamp": datetime.now(UTC).isoformat(),
}


Expand All @@ -122,7 +130,7 @@ async def chat_endpoint(msg: ChatMessage):
metadata = {
"process_time": process_time,
"session_id": msg.session_id,
"timestamp": datetime.utcnow().isoformat(),
"timestamp": datetime.now(UTC).isoformat(),
"reset": bool(msg.reset),
}
return ChatResponse(response=response_text, metadata=metadata)
Expand All @@ -138,8 +146,6 @@ async def chat_endpoint(msg: ChatMessage):
)




@app.post("/api/session/reset", tags=["Chat"])
async def reset_session(payload: Dict[str, str]):
sid = (payload or {}).get("session_id") or "default"
Expand All @@ -154,7 +160,7 @@ async def reset_session(payload: Dict[str, str]):
"main:app",
host=os.getenv("HOST", "0.0.0.0"),
port=int(os.getenv("PORT", "8000")),
reload=True,
reload=True,
log_level="info",
proxy_headers=True,
)
8 changes: 8 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,11 @@ dev = [
"flake8>=6.0.0",
"mypy>=1.0.0",
]


[tool.pytest.ini_options]
pythonpath = ["."]
testpaths = ["tests"]

[tool.setuptools.packages.find]
include = ["backend*"]
9 changes: 9 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import pytest
from fastapi.testclient import TestClient

from backend.main import app


@pytest.fixture
def client():
    """Yield a fresh FastAPI TestClient bound to the backend app for each test."""
    test_client = TestClient(app)
    return test_client
42 changes: 42 additions & 0 deletions tests/test_chat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# from unittest.mock import patch


# def test_chat_endpoint_success(client):
# with patch("backend.main.assistant") as mock_assistant:
# mock_assistant.chat.return_value = "Mocked response"

# response = client.post(
# "/api/chat",
# json={"query": "Hello"}
# )

# assert response.status_code == 200
# assert response.json()["response"] == "Mocked response"


# def test_chat_endpoint_validation_error(client):
# response = client.post(
# "/api/chat",
# json={}
# )

# assert response.status_code == 422


from unittest.mock import AsyncMock, patch


def test_chat_endpoint_success(client):
    """Happy path: a mocked async assistant reply is surfaced in the JSON body."""
    with patch("backend.main.assistant") as fake_assistant:
        # handle_chat is awaited by the endpoint, so it must be an AsyncMock.
        fake_assistant.handle_chat = AsyncMock(return_value="Mocked response")

        resp = client.post("/api/chat", json={"query": "Hello"})
        body = resp.json()

        assert resp.status_code == 200
        assert body["response"] == "Mocked response"


def test_chat_endpoint_validation_error(client):
response = client.post("/api/chat", json={})

assert response.status_code == 422
10 changes: 10 additions & 0 deletions tests/test_general.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
def test_health_endpoint(client):
response = client.get("/health")
assert response.status_code == 200
assert "status" in response.json()


def test_api_health_endpoint(client):
response = client.get("/api/health")
assert response.status_code == 200
assert "status" in response.json()
72 changes: 72 additions & 0 deletions tests/test_health.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import os
import sys
from unittest.mock import MagicMock

# Make backend/ importable as top-level modules so `from main import app`
# below resolves to backend/main.py without installing the package.
# Add backend folder to path
sys.path.append(
    os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "backend"))
)

# Mock assistant BEFORE importing main
# Inject a stub "agents" module so importing main never loads the real
# NeuroscienceAssistant (which presumably needs external resources — confirm).
# NOTE(review): `builtins` is imported but never used in this file.
import builtins
import types

# The attribute is the MagicMock *class*, so main's `NeuroscienceAssistant()`
# call produces a MagicMock instance.
mock_agents = types.ModuleType("agents")
mock_agents.NeuroscienceAssistant = MagicMock
sys.modules["agents"] = mock_agents

from fastapi.testclient import TestClient
from main import app

# Shared synchronous test client over the FastAPI app, used by every test below.
client = TestClient(app)


def test_root_endpoint():
    """Root route returns 200 with the service banner message."""
    resp = client.get("/")
    body = resp.json()
    assert resp.status_code == 200
    assert "KnowledgeSpace AI Backend is running" in body["message"]


def test_health_endpoint():
    """Cheap /health probe: healthy status flag and a timestamp are present."""
    resp = client.get("/health")
    assert resp.status_code == 200
    payload = resp.json()
    assert payload["status"] == "healthy"
    assert "timestamp" in payload


def test_api_health_endpoint():
    """Deep /api/health probe: status, component map and version are reported."""
    resp = client.get("/api/health")
    assert resp.status_code == 200

    payload = resp.json()
    assert payload["status"] == "healthy"
    for key in ("components", "version"):
        assert key in payload


def test_chat_endpoint_success(monkeypatch):
    """Chat route: a patched async assistant reply flows through to the client."""

    async def fake_handle_chat(session_id, query, reset):
        return "Mocked response"

    # Patch the live assistant instance that main.py created at import time.
    from main import assistant

    monkeypatch.setattr(assistant, "handle_chat", fake_handle_chat)

    resp = client.post(
        "/api/chat", json={"query": "What is neuroscience?", "session_id": "test123"}
    )
    payload = resp.json()

    assert resp.status_code == 200
    assert payload["response"] == "Mocked response"
    assert "metadata" in payload


def test_chat_validation_error():
    """A body without the required 'query' field must fail validation (422)."""
    payload = {"session_id": "abc"}  # deliberately omits 'query'
    resp = client.post("/api/chat", json=payload)
    assert resp.status_code == 422