diff --git a/backend/apps/chat/api/chat.py b/backend/apps/chat/api/chat.py index 6f34b69b..8e672a46 100644 --- a/backend/apps/chat/api/chat.py +++ b/backend/apps/chat/api/chat.py @@ -26,10 +26,11 @@ async def chats(session: SessionDep, current_user: CurrentUser): @router.get("/{chart_id}") -async def get_chat(session: SessionDep, current_user: CurrentUser, chart_id: int, current_assistant: CurrentAssistant): +async def get_chat(session: SessionDep, current_user: CurrentUser, chart_id: int, current_assistant: CurrentAssistant, + trans: Trans): def inner(): return get_chat_with_records(chart_id=chart_id, session=session, current_user=current_user, - current_assistant=current_assistant) + current_assistant=current_assistant, trans=trans) return await asyncio.to_thread(inner) @@ -108,7 +109,7 @@ async def start_chat(session: SessionDep, current_user: CurrentUser): @router.post("/recommend_questions/{chat_record_id}") async def recommend_questions(session: SessionDep, current_user: CurrentUser, chat_record_id: int, - current_assistant: CurrentAssistant, articles_number: Optional[int] = 4): + current_assistant: CurrentAssistant, articles_number: Optional[int] = 4): def _return_empty(): yield 'data:' + orjson.dumps({'content': '[]', 'type': 'recommended_question'}).decode() + '\n\n' @@ -134,6 +135,7 @@ def _err(_e: Exception): return StreamingResponse(llm_service.await_result(), media_type="text/event-stream") + @router.get("/recent_questions/{datasource_id}") async def recommend_questions(session: SessionDep, current_user: CurrentUser, datasource_id: int): return list_recent_questions(session=session, current_user=current_user, datasource_id=datasource_id) diff --git a/backend/apps/chat/curd/chat.py b/backend/apps/chat/curd/chat.py index d6e2bfe1..9ff45ea4 100644 --- a/backend/apps/chat/curd/chat.py +++ b/backend/apps/chat/curd/chat.py @@ -12,7 +12,7 @@ from apps.datasource.crud.recommended_problem import get_datasource_recommended, get_datasource_recommended_chart from apps.datasource.models.datasource import CoreDatasource, DsRecommendedProblem from apps.system.crud.assistant import AssistantOutDsFactory -from common.core.deps import CurrentAssistant, SessionDep, CurrentUser +from common.core.deps import CurrentAssistant, SessionDep, CurrentUser, Trans from common.utils.utils import extract_nested_json @@ -191,7 +191,7 @@ def get_chat_with_records_with_data(session: SessionDep, chart_id: int, current_ def get_chat_with_records(session: SessionDep, chart_id: int, current_user: CurrentUser, - current_assistant: CurrentAssistant, with_data: bool = False) -> ChatInfo: + current_assistant: CurrentAssistant, with_data: bool = False,trans: Trans = None) -> ChatInfo: chat = session.get(Chat, chart_id) if not chat: raise Exception(f"Chat with id {chart_id} not found") @@ -200,7 +200,7 @@ def get_chat_with_records(session: SessionDep, chart_id: int, current_user: Curr if current_assistant and current_assistant.type in dynamic_ds_types: out_ds_instance = AssistantOutDsFactory.get_instance(current_assistant) - ds = out_ds_instance.get_ds(chat.datasource) + ds = out_ds_instance.get_ds(chat.datasource,trans) else: ds = session.get(CoreDatasource, chat.datasource) if chat.datasource else None diff --git a/backend/apps/system/crud/assistant.py b/backend/apps/system/crud/assistant.py index 1f1e5ec6..da807965 100644 --- a/backend/apps/system/crud/assistant.py +++ b/backend/apps/system/crud/assistant.py @@ -20,6 +20,7 @@ from common.core.sqlbot_cache import cache from common.utils.aes_crypto import simple_aes_decrypt from common.utils.utils import equals_ignore_case, string_to_numeric_hash +from common.core.deps import Trans @cache(namespace=CacheNamespace.EMBEDDED_INFO, cacheName=CacheName.ASSISTANT_INFO, keyExpression="assistant_id") @@ -143,12 +144,12 @@ def get_ds_from_api(self): raise Exception(f"Failed to get datasource list from {endpoint}, error: {result_json.get('message')}") else: raise Exception(f"Failed to get datasource list from {endpoint}, status code: {res.status_code}") - + def get_first_element(self, text: str): parts = re.split(r'[,;]', text.strip()) first_domain = parts[0].strip() return first_domain - + def get_complete_endpoint(self, endpoint: str) -> str | None: if endpoint.startswith("http://") or endpoint.startswith("https://"): return endpoint @@ -158,8 +159,8 @@ def get_complete_endpoint(self, endpoint: str) -> str | None: if ',' in domain_text or ';' in domain_text: return (self.request_origin.strip('/') if self.request_origin else self.get_first_element(domain_text).strip('/')) + endpoint else: - return f"{domain_text}{endpoint}" - + return f"{domain_text}{endpoint}" + def get_simple_ds_list(self): if self.ds_list: return [{'id': ds.id, 'name': ds.name, 'description': ds.comment} for ds in self.ds_list] @@ -205,14 +206,14 @@ def get_db_schema(self, ds_id: int, question: str, embedding: bool = True) -> st return schema_str - def get_ds(self, ds_id: int): + def get_ds(self, ds_id: int,trans: Trans = None): if self.ds_list: for ds in self.ds_list: if ds.id == ds_id: return ds else: raise Exception("Datasource list is not found.") - raise Exception(f"Datasource with id {ds_id} not found.") + raise Exception(f"Datasource id {ds_id} is not found." if trans is None else trans('i18n_data_training.datasource_id_not_found', key=ds_id)) def convert2schema(self, ds_dict: dict, config: dict[any]) -> AssistantOutDsSchema: id_marker: str = ''