diff --git a/backend/retrieval.py b/backend/retrieval.py index a8702ab..972e781 100644 --- a/backend/retrieval.py +++ b/backend/retrieval.py @@ -89,10 +89,10 @@ def __init__(self): self.query_char_limit = 8000 # Enable only if everything is present - self.is_enabled = all( + self._is_enabled = all( [self.project_id, self.region, self.index_endpoint_full, self.deployed_id] ) - if not self.is_enabled: + if not self._is_enabled: logger.warning( "Vector search disabled due to incomplete GCP env: " f"project={bool(self.project_id)}, region={bool(self.region)}, " @@ -109,7 +109,7 @@ def __init__(self): self.bq = bigquery.Client(project=self.project_id) except Exception as e: logger.error(f"GCP client initialization failed: {e}") - self.is_enabled = False + self._is_enabled = False return try: @@ -123,7 +123,11 @@ def __init__(self): logger.info(f"Vector search initialized on device={self.device} using {self.embed_model_name}") except Exception as e: logger.error(f"Embedding model initialization failed: {e}") - self.is_enabled = False + self._is_enabled = False + + @property + def is_enabled(self) -> bool: + return getattr(self, '_is_enabled', False) # Embedding def _embed(self, text: str) -> List[float]: diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py new file mode 100644 index 0000000..191a027 --- /dev/null +++ b/backend/tests/conftest.py @@ -0,0 +1,49 @@ +import os +import sys +import pytest +from unittest.mock import patch, MagicMock + +# Add the backend directory to sys.path so tests can import from it +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) + +# Mock heavy dependencies that are not installed in the global env +# to allow unit tests to run without them. +mocked_modules = [ + 'langgraph', + 'langgraph.graph', + 'torch', + 'google', + 'google.cloud', + 'google.cloud.aiplatform', + 'google.cloud.bigquery', + 'google.genai', + 'google.genai.types', + 'transformers', + 'ks_search_tool' +] + +for mod in mocked_modules: + sys.modules[mod] = MagicMock() + +# agents.py imports END from langgraph.graph +sys.modules['langgraph.graph'].END = "END" + +@pytest.fixture(autouse=True) +def mock_env_vars(): + env_patcher = patch.dict(os.environ, { + "GOOGLE_API_KEY": "test-key-123", + "GCP_PROJECT_ID": "", + "GEMINI_USE_VERTEX": "false", + "CORS_ALLOW_ORIGINS": "*", + "ENVIRONMENT": "test" + }, clear=False) + + env_patcher.start() + yield + env_patcher.stop() + +@pytest.fixture +def test_client(): + from main import app + from fastapi.testclient import TestClient + return TestClient(app) diff --git a/backend/tests/test_agents.py b/backend/tests/test_agents.py new file mode 100644 index 0000000..73aec5f --- /dev/null +++ b/backend/tests/test_agents.py @@ -0,0 +1,59 @@ +from unittest.mock import patch, AsyncMock, MagicMock + +from agents import ( + _is_more_query, + QueryIntent, + fuse_results, + AgentState, + NeuroscienceAssistant +) + +def test_is_more_query(): + assert _is_more_query("next 10") == 10 + assert _is_more_query("show 5") == 5 + assert _is_more_query("more 20") == 20 + + assert _is_more_query("more") is None + assert _is_more_query("continue") is None + + assert _is_more_query("find rat electrophysiology") is None + assert _is_more_query("") is None + +def test_fuse_results(): + state: AgentState = { + "session_id": "test_session", + "query": "rat data", + "history": [], + "keywords": [], + "effective_query": "rat data", + "intents": [], + "ks_results": [{"_id": "doc_common", "_score": 10.0}, {"_id": "doc_ks_only", "_score": 5.0}], + "vector_results": [{"id": "doc_common", "similarity": 0.8}, {"id": "doc_vec_only", "similarity": 0.9}], + "final_results": [], + "all_results": [], + "start_number": 1, + "previous_text": "", + "final_response": "", + } + + new_state = fuse_results(state) + all_res = new_state["all_results"] + + assert len(all_res) == 3 + + # doc_common score: vector (0.8 * 0.6 = 0.48) + ks (10.0 * 0.4 = 4.0) = 4.48 + # doc_vec_only score: vector (0.9 * 0.6 = 0.54) + ks (0) = 0.54 + # doc_ks_only score: vector (0) + ks (5.0 * 0.4 = 2.0) = 2.0 + doc_ids = [res.get("id") or res.get("_id") for res in all_res] + assert doc_ids == ["doc_common", "doc_ks_only", "doc_vec_only"] + +def test_neuroscience_assistant_reset(): + assistant = NeuroscienceAssistant() + + assistant.chat_history["session_123"] = ["User: Hello", "Assistant: Hi"] + assistant.session_memory["session_123"] = {"page": 1} + + assistant.reset_session("session_123") + + assert "session_123" not in assistant.chat_history + assert "session_123" not in assistant.session_memory diff --git a/backend/tests/test_main.py b/backend/tests/test_main.py new file mode 100644 index 0000000..bc9a374 --- /dev/null +++ b/backend/tests/test_main.py @@ -0,0 +1,57 @@ +from unittest.mock import patch, AsyncMock + +def test_root_endpoint(test_client): + response = test_client.get("/") + assert response.status_code == 200 + assert "KnowledgeSpace AI Backend is running" in response.json()["message"] + +def test_health_check_endpoint(test_client): + response = test_client.get("/health") + assert response.status_code == 200 + assert response.json()["status"] == "healthy" + +@patch("main.asyncio.wait_for") +def test_api_health_endpoint(mock_wait_for, test_client): + mock_wait_for.return_value = True + + response = test_client.get("/api/health") + assert response.status_code == 200 + + data = response.json() + assert data["status"] == "healthy" + assert data["components"]["vector_search"] == "enabled" + assert data["components"]["llm"] == "enabled" # enabled via conftest env patch + +@patch("main.assistant") +def test_chat_endpoint_success(mock_assistant, test_client): + mock_assistant.handle_chat = AsyncMock(return_value="Found 3 datasets for rat hippocampus...") + + payload = { + "query": "find rat hippocampus data", + "session_id": "session_123", + "reset": False + } + + response = test_client.post("/api/chat", json=payload) + + assert response.status_code == 200 + data = response.json() + assert data["response"] == "Found 3 datasets for rat hippocampus..." + assert "process_time" in data["metadata"] + assert data["metadata"]["session_id"] == "session_123" + assert data["metadata"]["reset"] is False + + mock_assistant.handle_chat.assert_called_once_with( + session_id="session_123", + query="find rat hippocampus data", + reset=False + ) + +@patch("main.assistant") +def test_session_reset_endpoint(mock_assistant, test_client): + response = test_client.post("/api/session/reset", json={"session_id": "session_456"}) + + assert response.status_code == 200 + assert response.json()["status"] == "ok" + + mock_assistant.reset_session.assert_called_once_with("session_456") diff --git a/backend/tests/test_retrieval.py b/backend/tests/test_retrieval.py new file mode 100644 index 0000000..3c29dcc --- /dev/null +++ b/backend/tests/test_retrieval.py @@ -0,0 +1,43 @@ +from unittest.mock import patch, MagicMock + +from retrieval import VertexRetriever, get_retriever +from local_retriever import LocalRetriever + +def test_local_retriever(): + retriever = LocalRetriever() + assert retriever.is_enabled is True + assert retriever.search("test query") == [] + +@patch("retrieval.os.getenv") +@patch("retrieval.aiplatform") +@patch("retrieval.bigquery") +@patch("retrieval.AutoTokenizer") +@patch("retrieval.AutoModel") +@patch("retrieval.torch.cuda.is_available", return_value=False) +def test_vertex_retriever_initialization(mock_cuda, mock_model, mock_tokenizer, mock_bq, mock_ai, mock_getenv): + def mock_env(key, default=""): + mapping = { + "GCP_PROJECT_ID": "test-project", + "GCP_REGION": "us-central1", + "INDEX_ENDPOINT_ID_FULL": "projects/123/locations/us-central1/indexEndpoints/456", + "DEPLOYED_INDEX_ID": "test_index", + } + return mapping.get(key, default) + + mock_getenv.side_effect = mock_env + + # prevent actual model downloads during test + mock_model.from_pretrained.return_value.eval.return_value.to.return_value = MagicMock() + + retriever = VertexRetriever() + + assert retriever.is_enabled is True + mock_ai.init.assert_called_once_with(project="test-project", location="us-central1") + mock_bq.Client.assert_called_once_with(project="test-project") + mock_tokenizer.from_pretrained.assert_called_once() + mock_model.from_pretrained.assert_called_once() + +def test_get_retriever_fallback(): + # fallback happens because we stripped GCP env vars in conftest + retriever = get_retriever() + assert isinstance(retriever, LocalRetriever)