Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file added backend/tests/__init__.py
Empty file.
57 changes: 57 additions & 0 deletions backend/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import os
import sys
import pytest
from unittest.mock import MagicMock

# Set test environment variables before any imports
os.environ["GOOGLE_API_KEY"] = "test-key-for-testing"
os.environ["RATE_LIMIT"] = "100/minute"
os.environ["CORS_ALLOW_ORIGINS"] = "*"

# Mock heavy dependencies that are not needed for API tests
sys.modules["torch"] = MagicMock()
sys.modules["google.cloud"] = MagicMock()
sys.modules["google.cloud.aiplatform"] = MagicMock()
sys.modules["google.cloud.bigquery"] = MagicMock()
sys.modules["vertexai"] = MagicMock()
sys.modules["vertexai.generative_models"] = MagicMock()
sys.modules["google.generativeai"] = MagicMock()
sys.modules["sentence_transformers"] = MagicMock()

# Mock the retrieval module entirely
mock_retriever = MagicMock()
mock_retriever.Retriever = MagicMock()
sys.modules["retrieval"] = mock_retriever

# Mock the agents module with a fake assistant
mock_assistant = MagicMock()
mock_assistant.NeuroscienceAssistant = MagicMock


class FakeAssistant:
"""Fake assistant that returns predictable responses for testing."""

async def handle_chat(self, session_id="default", query="", reset=False):
return f"Test response for: {query}"

def reset_session(self, session_id):
pass


mock_agents = MagicMock()
mock_agents.NeuroscienceAssistant = FakeAssistant
sys.modules["agents"] = mock_agents

# Now we can safely import the app
# Add backend to path so main.py can be found
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))

from main import app
from fastapi.testclient import TestClient


@pytest.fixture
def client():
"""Create a test client for the FastAPI app."""
with TestClient(app) as c:
yield c
86 changes: 86 additions & 0 deletions backend/tests/test_health.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
"""Tests for health check endpoints."""


class TestRootEndpoint:

def test_root_returns_200(self, client):
response = client.get("/")
assert response.status_code == 200

def test_root_contains_message(self, client):
response = client.get("/")
data = response.json()
assert "message" in data
assert "running" in data["message"].lower()

def test_root_contains_version(self, client):
response = client.get("/")
data = response.json()
assert "version" in data
assert data["version"] == "2.0.0"


class TestHealthEndpoint:

def test_health_returns_200(self, client):
response = client.get("/health")
assert response.status_code == 200

def test_health_status_is_healthy(self, client):
response = client.get("/health")
data = response.json()
assert data["status"] == "healthy"

def test_health_contains_timestamp(self, client):
response = client.get("/health")
data = response.json()
assert "timestamp" in data

def test_health_contains_service_name(self, client):
response = client.get("/health")
data = response.json()
assert data["service"] == "knowledge-space-agent-backend"

def test_health_contains_version(self, client):
response = client.get("/health")
data = response.json()
assert data["version"] == "2.0.0"


class TestApiHealthEndpoint:

def test_api_health_returns_200(self, client):
response = client.get("/api/health")
assert response.status_code == 200

def test_api_health_status_is_healthy(self, client):
response = client.get("/api/health")
data = response.json()
assert data["status"] == "healthy"

def test_api_health_contains_components(self, client):
response = client.get("/api/health")
data = response.json()
assert "components" in data

def test_api_health_components_have_expected_keys(self, client):
response = client.get("/api/health")
components = response.json()["components"]
assert "vector_search" in components
assert "llm" in components
assert "keyword_search" in components

def test_api_health_keyword_search_always_enabled(self, client):
response = client.get("/api/health")
components = response.json()["components"]
assert components["keyword_search"] == "enabled"

def test_api_health_contains_timestamp(self, client):
response = client.get("/api/health")
data = response.json()
assert "timestamp" in data

def test_api_health_contains_version(self, client):
response = client.get("/api/health")
data = response.json()
assert data["version"] == "2.0.0"
86 changes: 86 additions & 0 deletions backend/tests/test_query.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
"""Tests for chat and session endpoints."""


class TestChatEndpoint:

def test_chat_endpoint_exists(self, client):
response = client.post("/api/chat", json={"query": "test"})
assert response.status_code != 404

def test_chat_returns_response_field(self, client):
response = client.post("/api/chat", json={"query": "What is a neuron?"})
data = response.json()
assert "response" in data

def test_chat_returns_metadata(self, client):
response = client.post("/api/chat", json={"query": "test query"})
data = response.json()
assert "metadata" in data

def test_chat_metadata_contains_process_time(self, client):
response = client.post("/api/chat", json={"query": "test"})
if response.status_code == 200:
metadata = response.json().get("metadata", {})
assert "process_time" in metadata

def test_chat_metadata_contains_session_id(self, client):
response = client.post("/api/chat", json={"query": "test"})
if response.status_code == 200:
metadata = response.json().get("metadata", {})
assert "session_id" in metadata

def test_chat_with_custom_session_id(self, client):
response = client.post(
"/api/chat",
json={"query": "test", "session_id": "my-session-123"},
)
if response.status_code == 200:
metadata = response.json().get("metadata", {})
assert metadata.get("session_id") == "my-session-123"

def test_chat_missing_query_returns_422(self, client):
response = client.post("/api/chat", json={})
assert response.status_code == 422

def test_chat_get_method_not_allowed(self, client):
response = client.get("/api/chat")
assert response.status_code == 405


class TestSessionResetEndpoint:

def test_reset_endpoint_exists(self, client):
response = client.post(
"/api/session/reset",
json={"session_id": "test-session"},
)
assert response.status_code != 404

def test_reset_returns_ok_status(self, client):
response = client.post(
"/api/session/reset",
json={"session_id": "test-session"},
)
if response.status_code == 200:
data = response.json()
assert data["status"] == "ok"

def test_reset_returns_session_id(self, client):
response = client.post(
"/api/session/reset",
json={"session_id": "my-session"},
)
if response.status_code == 200:
data = response.json()
assert data["session_id"] == "my-session"


class TestUnknownRoutes:

def test_unknown_route_returns_404(self, client):
response = client.get("/this-does-not-exist")
assert response.status_code == 404

def test_unknown_api_route_returns_404(self, client):
response = client.get("/api/nonexistent")
assert response.status_code == 404
7 changes: 7 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,10 @@ dev = [
"flake8>=6.0.0",
"mypy>=1.0.0",
]

[tool.pytest.ini_options]
testpaths = ["backend/tests"]
python_files = ["test_*.py"]
python_classes = ["Test*"]
python_functions = ["test_*"]
addopts = "-v --tb=short"