diff --git a/backend/agents.py b/backend/agents.py index 155ce0a..03a8865 100644 --- a/backend/agents.py +++ b/backend/agents.py @@ -63,9 +63,10 @@ def _require_llm_creds() -> None: _GENAI_CLIENT = None +_GENAI_CLIENT_LOCK = asyncio.Lock() -def _get_genai_client(): +async def _get_genai_client(): """ Build a google.genai client for either Vertex (ADC or creds file) or API-key mode. """ @@ -73,17 +74,28 @@ def _get_genai_client(): if _GENAI_CLIENT is not None: return _GENAI_CLIENT - if _use_vertex(): - project = os.getenv("GCP_PROJECT_ID") - location = os.getenv("GCP_REGION") or "europe-west4" - _ensure_google_creds_for_vertex() - _GENAI_CLIENT = genai.Client(vertexai=True, project=project, location=location) - else: - _GENAI_CLIENT = genai.Client(api_key=os.getenv("GOOGLE_API_KEY")) + async with _GENAI_CLIENT_LOCK: + #Double-check inside lock + if _GENAI_CLIENT is not None: + return _GENAI_CLIENT + + if _use_vertex(): + project = os.getenv("GCP_PROJECT_ID") + location = os.getenv("GCP_REGION") or "europe-west4" + _ensure_google_creds_for_vertex() + _GENAI_CLIENT = genai.Client( + vertexai=True, + project=project, + location=location + ) + + else: + _GENAI_CLIENT = genai.Client( + api_key=os.getenv("GOOGLE_API_KEY") + ) return _GENAI_CLIENT - FLASH_MODEL = os.getenv("GEMINI_FLASH_MODEL", "gemini-2.5-flash") FLASH_LITE_MODEL = os.getenv("GEMINI_FLASH_LITE_MODEL", "gemini-2.5-flash-lite") @@ -124,7 +136,7 @@ async def call_gemini_for_keywords(query: str) -> List[str]: No local greeting filters — prompt handles exclusions. Minimal trim+dedupe here. """ _require_llm_creds() - client = _get_genai_client() + client = await _get_genai_client() prompt = ( "Extract important search keywords and multi-word phrases from a neuroscience *data* query.\n" "Return STRICT JSON only:\n" @@ -166,7 +178,7 @@ async def call_gemini_rewrite_with_history(query: str, history: List[str]) -> st Keeps exact tokens and multi-word phrases intact. """ _require_llm_creds() - client = _get_genai_client() + client = await _get_genai_client() last_user_turns = [h for h in history if h.startswith("User: ")] ctx = "\n".join(last_user_turns[-6:]) prompt = ( @@ -201,7 +213,7 @@ async def call_gemini_detect_intents(query: str, history: List[str]) -> List[str - If any data-related tokens exist, prefer data_discovery. """ _require_llm_creds() - client = _get_genai_client() + client = await _get_genai_client() allowed = [i.value for i in QueryIntent] last_user_turns = [h for h in history if h.startswith("User: ")] ctx = "\n".join(last_user_turns[-6:]) @@ -242,7 +254,7 @@ async def call_gemini_for_final_synthesis( ) -> str: _require_llm_creds() - client = _get_genai_client() + client = await _get_genai_client() extras = [] if QueryIntent.ACCESS_DOWNLOAD.value in intents: