From 9c10c88778b285782376ed4234b9f7ed8bf289c4 Mon Sep 17 00:00:00 2001 From: Kirill Krivov Date: Wed, 30 Apr 2025 09:27:18 +0300 Subject: [PATCH] Prevent db_uid from consisting of None --- core/message/from_message.py | 5 +- scenarios/logging/logger_constants.py | 14 ++-- .../field/field_filler_description.py | 2 +- smart_kit/start_points/base_main_loop.py | 74 ++++++++++------- .../start_points/main_loop_async_http.py | 79 ++++++++----------- 5 files changed, 90 insertions(+), 84 deletions(-) diff --git a/core/message/from_message.py b/core/message/from_message.py index 51b1c4cd..65fd69c8 100644 --- a/core/message/from_message.py +++ b/core/message/from_message.py @@ -162,7 +162,9 @@ def session_id(self) -> Optional[str]: # database user_id @property - def db_uid(self) -> str: + def db_uid(self) -> Optional[str]: + if self.uid is None or self.channel is None: + return None return "{}_{}".format(self.uid, self.channel) @property @@ -244,7 +246,6 @@ def callback_id(self) -> Optional[str]: def has_callback_id(self): return self._callback_id is not None or self.headers.get(self._callback_id_header_name) is not None - # noinspection PyMethodMayBeStatic def generate_new_callback_id(self) -> str: from smart_kit.start_points.main_loop_http import HttpMainLoop if issubclass(get_app_config().MAIN_LOOP, HttpMainLoop): diff --git a/scenarios/logging/logger_constants.py b/scenarios/logging/logger_constants.py index 80282d25..1f963853 100644 --- a/scenarios/logging/logger_constants.py +++ b/scenarios/logging/logger_constants.py @@ -57,10 +57,10 @@ CHECK_HOSTNAME = "check_hostname" NORMALIZE_INTENT_VALUE = "normalize_intent" -SKIPPED_INTENT_VALUE = 'skipped_intent' -INVALID_INTENT_VALUE = 'invalid_intent' -CONTAINER_NAME_VALUE = 'container_name' -CONTAINER_REQUIREMENT_CHECK_VALUE = 'container_requirement_check_value' +SKIPPED_INTENT_VALUE = "skipped_intent" +INVALID_INTENT_VALUE = "invalid_intent" +CONTAINER_NAME_VALUE = "container_name" +CONTAINER_REQUIREMENT_CHECK_VALUE = "container_requirement_check_value" AB_GROUPS_MESSAGE = "%(key_name)s=%(ab_groups)s" @@ -70,7 +70,7 @@ DIALOG_SCENARIO_MESSAGE = "Dialog manager run last scenario info: %(key_name)s=%(chosen_scenario)s, " \ "distance=%(distance)s, scenario_description=%(scenario_description)s, " \ "root_id=%(root_id)s" -INVALID_INTENT_MESSAGE = 'Invalid intent, intent=%(intent_id)s' +INVALID_INTENT_MESSAGE = "Invalid intent, intent=%(intent_id)s" CLASSIFIER_MESSAGE = "classifier: %(classifier_name)s, result: %(result)s, " \ "weights: %(score)s, time: %(time)s ms" CHOSEN_SCENARIO_MESSAGE = "%(key_name)s=%(chosen_scenario)s" @@ -88,8 +88,8 @@ FILLER_RESULT_MESSAGE = "%(key_name)s=%(filler_result)s" CHOSEN_NODE_ID_MESSAGE = "%(key_name)s=%(chosen_node_id)s" SKIPPED_INTENT_MESSAGE = "Skipped intent %(intent_id)s, scenario is not available" -CONTAINER_REQUIREMENT_CHECK_MESSAGE = 'Classifier container requirement check: result: ' \ - '%(requirement_check)s, container: %(container_name)s' +CONTAINER_REQUIREMENT_CHECK_MESSAGE = "Classifier container requirement check: result: " \ + "%(requirement_check)s, container: %(container_name)s" MESSAGE_ID = "message_id" CALLBACK_ID_HEADER = "app_callback_id" diff --git a/scenarios/scenario_models/field/field_filler_description.py b/scenarios/scenario_models/field/field_filler_description.py index 90b41b87..26408856 100644 --- a/scenarios/scenario_models/field/field_filler_description.py +++ b/scenarios/scenario_models/field/field_filler_description.py @@ -48,7 +48,7 @@ def _log_params(self): } def extract(self, text_preprocessing_result: BaseTextPreprocessingResult, user: User, - params: Dict[str, Any] = None) -> None: + params: Dict[str, Any] = None) -> Any: return None def on_extract_error(self, text_preprocessing_result, user, params=None): diff --git a/smart_kit/start_points/base_main_loop.py b/smart_kit/start_points/base_main_loop.py index 16436113..8f5ef08b 100644 --- a/smart_kit/start_points/base_main_loop.py +++ b/smart_kit/start_points/base_main_loop.py @@ -1,6 +1,4 @@ -# coding=utf-8 - -from typing import Type, Iterable +from typing import Type, Iterable, Optional import asyncio import signal @@ -12,9 +10,9 @@ from core.monitoring.monitoring import monitoring from core.monitoring.healthcheck_handler import RootResource from core.monitoring.twisted_server import TwistedServer -from core.model.base_user import BaseUser from core.basic_models.parametrizers.parametrizer import BasicParametrizer from core.message.msg_validator import MessageValidator +from scenarios.user.user_model import User from smart_kit.start_points.postprocess import PostprocessMainLoop from smart_kit.models.smartapp_model import SmartAppModel @@ -23,7 +21,7 @@ class BaseMainLoop: def __init__( self, model: SmartAppModel, - user_cls: Type[BaseUser], + user_cls: Type[User], parametrizer_cls: Type[BasicParametrizer], postprocessor_cls: Type[PostprocessMainLoop], settings, @@ -98,20 +96,26 @@ def _init_monitoring_config(self, template_settings): monitoring.apply_config(monitoring_config) monitoring.init_metrics(app_name=self.app_name) - async def load_user(self, db_uid, message: SmartAppFromMessage): - db_data = None - load_error = False + async def load_user(self, db_uid: Optional[str], message: SmartAppFromMessage) -> User: + if db_uid is None: + log("Failed to load user data as db_uid is None. Will use empty user.", level="ERROR") + return self.get_user(message, db_data=None, load_error=True) + return await self._load_user(db_uid, message) + + async def _load_user(self, db_uid: str, message: SmartAppFromMessage) -> User: try: db_data = await self.db_adapter.get(db_uid) except (DBAdapterException, ValueError): log("Failed to get user data", params={log_const.KEY_NAME: log_const.FAILED_DB_INTERACTION, log_const.REQUEST_VALUE: message.as_str}, level="ERROR") - load_error = True monitoring.counter_load_error(self.app_name) # to skip message when load failed raise + return self.get_user(message, db_data, load_error=False) + + def get_user(self, message: SmartAppFromMessage, db_data: Optional[dict], load_error: bool) -> User: return self.user_cls( - message.uid, + id=message.uid, message=message, db_data=db_data, settings=self.settings, @@ -120,27 +124,37 @@ async def load_user(self, db_uid, message: SmartAppFromMessage): load_error=load_error ) - async def save_user(self, db_uid, user, message: SmartAppFromMessage): - no_collisions = True + async def save_user(self, db_uid: Optional[str], user: User, message: SmartAppFromMessage) -> bool: + """ + :return: True if there were no any collisions when saving, False otherwise + """ + if db_uid is None: + log("User %(uid)s will not be saved as db_uid is None", + user=user, level="ERROR", params={"uid": user.id}) + return True + return await self._save_user(db_uid, user, message) + + async def _save_user(self, db_uid: str, user: User, message: SmartAppFromMessage) -> bool: if user.do_not_save: - log("User %(uid)s will not saved", user=user, params={"uid": user.id, - log_const.KEY_NAME: "user_will_not_saved"}) - else: - no_collisions = True - try: - str_data = user.raw_str - if user.initial_db_data and self.user_save_check_for_collisions: - no_collisions = await self.db_adapter.replace_if_equals(db_uid, - sample=user.initial_db_data, - data=str_data) - else: - await self.db_adapter.save(db_uid, str_data) - except (DBAdapterException, ValueError): - log("Failed to set user data", params={log_const.KEY_NAME: log_const.FAILED_DB_INTERACTION, - log_const.REQUEST_VALUE: message.as_str}, level="ERROR") - monitoring.counter_save_error(self.app_name) - if not no_collisions: - monitoring.counter_save_collision(self.app_name) + log("User %(uid)s will not be saved", user=user, + params={"uid": user.id, log_const.KEY_NAME: "user_will_not_saved"}) + return True + + no_collisions = True + try: + str_data = user.raw_str + if user.initial_db_data and self.user_save_check_for_collisions: + no_collisions = await self.db_adapter.replace_if_equals(db_uid, + sample=user.initial_db_data, + data=str_data) + else: + await self.db_adapter.save(db_uid, str_data) + except (DBAdapterException, ValueError): + log("Failed to set user data", params={log_const.KEY_NAME: log_const.FAILED_DB_INTERACTION, + log_const.REQUEST_VALUE: message.as_str}, level="ERROR") + monitoring.counter_save_error(self.app_name) + if not no_collisions: + monitoring.counter_save_collision(self.app_name) return no_collisions def run(self): diff --git a/smart_kit/start_points/main_loop_async_http.py b/smart_kit/start_points/main_loop_async_http.py index 8ced8cf5..d87bbf02 100644 --- a/smart_kit/start_points/main_loop_async_http.py +++ b/smart_kit/start_points/main_loop_async_http.py @@ -1,6 +1,5 @@ import asyncio import json -import typing from functools import cached_property import aiohttp @@ -12,6 +11,7 @@ from core.message.from_message import SmartAppFromMessage from core.monitoring.monitoring import monitoring from core.utils.stats_timer import StatsTimer +from scenarios.user.user_model import User from smart_kit.message.smartapp_to_message import SmartAppToMessage from smart_kit.start_points.main_loop_http import BaseHttpMainLoop @@ -41,7 +41,7 @@ async def close_db(self, app): def masking_fields(self): return self.settings["template_settings"].get("masking_fields") - async def load_user(self, db_uid, message): + async def _load_user(self, db_uid: str, message: SmartAppFromMessage) -> User: db_data = None load_error = False try: @@ -54,51 +54,42 @@ async def load_user(self, db_uid, message): log_const.REQUEST_VALUE: message.as_str}, level="ERROR") load_error = True monitoring.counter_load_error(self.app_name) - return self.user_cls( - message.uid, - message=message, - db_data=db_data, - settings=self.settings, - descriptions=self.model.scenario_descriptions, - parametrizer_cls=self.parametrizer_cls, - load_error=load_error - ) + return self.get_user(message, db_data, load_error) - async def save_user(self, db_uid, user, message): - no_collisions = True + async def _save_user(self, db_uid, user, message): if user.do_not_save: - log("User %(uid)s will not saved", user=user, params={"uid": user.id, - log_const.KEY_NAME: "user_will_not_saved"}) - else: + log("User %(uid)s will not be saved", user=user, + params={"uid": user.id, log_const.KEY_NAME: "user_will_not_saved"}) + return True + + no_collisions = True + try: + str_data = user.raw_str - no_collisions = True - try: - str_data = user.raw_str - - if self.db_adapter.IS_ASYNC: - if user.initial_db_data and self.user_save_check_for_collisions: - no_collisions = await self.db_adapter.replace_if_equals( - db_uid, - sample=user.initial_db_data, - data=str_data - ) - else: - await self.db_adapter.save(db_uid, str_data) + if self.db_adapter.IS_ASYNC: + if user.initial_db_data and self.user_save_check_for_collisions: + no_collisions = await self.db_adapter.replace_if_equals( + db_uid, + sample=user.initial_db_data, + data=str_data + ) + else: + await self.db_adapter.save(db_uid, str_data) + else: + if user.initial_db_data and self.user_save_check_for_collisions: + no_collisions = await self.db_adapter.replace_if_equals( + db_uid, + sample=user.initial_db_data, + data=str_data + ) else: - if user.initial_db_data and self.user_save_check_for_collisions: - no_collisions = await self.db_adapter.replace_if_equals( - db_uid, - sample=user.initial_db_data, - data=str_data - ) - else: - await self.db_adapter.save(db_uid, str_data) - except (DBAdapterException, ValueError): - log("Failed to set user data", params={log_const.KEY_NAME: log_const.FAILED_DB_INTERACTION, - log_const.REQUEST_VALUE: message.as_str}, level="ERROR") - monitoring.counter_save_error(self.app_name) - if not no_collisions: - monitoring.counter_save_collision(self.app_name) + await self.db_adapter.save(db_uid, str_data) + except (DBAdapterException, ValueError): + log("Failed to set user data", params={log_const.KEY_NAME: log_const.FAILED_DB_INTERACTION, + log_const.REQUEST_VALUE: message.as_str}, level="ERROR") + monitoring.counter_save_error(self.app_name) + if not no_collisions: + monitoring.counter_save_collision(self.app_name) return no_collisions def run(self): @@ -111,7 +102,7 @@ def run(self): def stop(self, signum, frame): pass - async def handle_message(self, message: SmartAppFromMessage) -> typing.Tuple[int, str, SmartAppToMessage]: + async def handle_message(self, message: SmartAppFromMessage) -> tuple[int, str, SmartAppToMessage]: if not message.validate(): answer = SmartAppToMessage( self.BAD_REQUEST_COMMAND, message=message, request=None, masking_fields=self.masking_fields)