diff --git a/app/ldap_protocol/ldap_requests/modify.py b/app/ldap_protocol/ldap_requests/modify.py index 7ab3333fb..bc55d6403 100644 --- a/app/ldap_protocol/ldap_requests/modify.py +++ b/app/ldap_protocol/ldap_requests/modify.py @@ -8,6 +8,7 @@ from typing import AsyncGenerator, ClassVar from loguru import logger +from pydantic import Field from sqlalchemy import Select, and_, delete, func, or_, select, update from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession @@ -110,7 +111,7 @@ class ModifyRequest(BaseRequest): # NOTE: If the old value was changed (for example, in _delete) # in one method, then you need to have access to the old value # from other methods (for example, from _add) - _old_vals: dict[str, str | None] = {} + old_vals: dict[str, str | None] = Field(default_factory=dict) @classmethod def from_data(cls, data: list[ASN1Row]) -> "ModifyRequest": @@ -143,7 +144,7 @@ async def _update_password_expiration( return if not ( - change.modification.type == "krbpasswordexpiration" + change.l_type == "krbpasswordexpiration" and change.modification.vals[0] == "19700101000000Z" ): return @@ -284,10 +285,10 @@ async def handle( except MODIFY_EXCEPTION_STACK as err: await ctx.session.rollback() - result_code, message = self._match_bad_response(err) + result_code, error_message = self._match_bad_response(err) yield ModifyResponse( result_code=result_code, - message=message, + error_message=error_message, ) return @@ -333,6 +334,9 @@ def _match_bad_response(self, err: BaseException) -> tuple[LDAPCodes, str]: case ModifyForbiddenError(): return LDAPCodes.OPERATIONS_ERROR, str(err) + case KRBAPIRenamePrincipalError(): + return LDAPCodes.UNAVAILABLE, "Kerberos error" + case KRBAPIPrincipalNotFoundError(): return LDAPCodes.UNAVAILABLE, "Kerberos error" @@ -632,8 +636,8 @@ def _need_to_cache_samaccountname_old_value( return bool( directory.entity_type and directory.entity_type.name == EntityTypeNames.COMPUTER - and change.modification.type == "sAMAccountName" - and not self._old_vals.get(change.modification.type), + and change.l_type == "samaccountname" + and not self.old_vals.get(change.modification.type), ) async def _delete( @@ -689,7 +693,7 @@ async def _delete( if self._need_to_cache_samaccountname_old_value(change, directory): vals = directory.attributes_dict.get(change.modification.type) if vals: - self._old_vals[change.modification.type] = vals[0] + self.old_vals[change.modification.type] = vals[0] if attrs: del_query = ( @@ -826,14 +830,13 @@ async def _add( # noqa: C901 password_use_cases: PasswordPolicyUseCases, password_utils: PasswordUtils, ) -> None: + base_dir = None attrs = [] if change.l_type in ("memberof", "member", "primarygroupid"): await self._add_group_attrs(change, directory, session) return - base_dir = await self._get_base_dir(directory, session) - for value in change.modification.vals: if change.l_type == "useraccountcontrol": uac_val = int(value) @@ -923,6 +926,12 @@ async def _add( # noqa: C901 new_user_principal_name = str(new_value) new_sam_account_name = new_user_principal_name.split("@")[0] # noqa: E501 # fmt: skip elif change.l_type == "samaccountname": + if not base_dir: + base_dir = await self._get_base_dir( + directory, + session, + ) + new_sam_account_name = str(new_value) new_user_principal_name = f"{new_sam_account_name}@{base_dir.name}" # noqa: E501 # fmt: skip @@ -946,12 +955,19 @@ async def _add( # noqa: C901 and directory.entity_type and directory.entity_type.name == EntityTypeNames.COMPUTER ): + if not base_dir: + base_dir = await self._get_base_dir( + directory, + session, + ) + await self._modify_computer_samaccountname( change, kadmin, base_dir, value, ) + attrs.append( Attribute( name=change.modification.type, @@ -1019,7 +1035,7 @@ async def _modify_computer_samaccountname( base_dir: Directory, new_sam_account_name: bytes | str, ) -> None: - old_sam_account_name = self._old_vals.get(change.modification.type) + old_sam_account_name = self.old_vals.get(change.modification.type) new_sam_account_name = str(new_sam_account_name) if not old_sam_account_name: