Skip to content
Merged
36 changes: 26 additions & 10 deletions app/ldap_protocol/ldap_requests/modify.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import AsyncGenerator, ClassVar

from loguru import logger
from pydantic import Field
from sqlalchemy import Select, and_, delete, func, or_, select, update
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
Expand Down Expand Up @@ -110,7 +111,7 @@ class ModifyRequest(BaseRequest):
# NOTE: If the old value was changed (for example, in _delete)
# in one method, then you need to have access to the old value
# from other methods (for example, from _add)
_old_vals: dict[str, str | None] = {}
old_vals: dict[str, str | None] = Field(default_factory=dict)

@classmethod
def from_data(cls, data: list[ASN1Row]) -> "ModifyRequest":
Expand Down Expand Up @@ -143,7 +144,7 @@ async def _update_password_expiration(
return

if not (
change.modification.type == "krbpasswordexpiration"
change.l_type == "krbpasswordexpiration"
and change.modification.vals[0] == "19700101000000Z"
):
return
Expand Down Expand Up @@ -284,10 +285,10 @@ async def handle(

except MODIFY_EXCEPTION_STACK as err:
await ctx.session.rollback()
result_code, message = self._match_bad_response(err)
result_code, error_message = self._match_bad_response(err)
yield ModifyResponse(
result_code=result_code,
message=message,
error_message=error_message,
)
return

Expand Down Expand Up @@ -333,6 +334,9 @@ def _match_bad_response(self, err: BaseException) -> tuple[LDAPCodes, str]:
case ModifyForbiddenError():
return LDAPCodes.OPERATIONS_ERROR, str(err)

case KRBAPIRenamePrincipalError():
return LDAPCodes.UNAVAILABLE, "Kerberos error"

case KRBAPIPrincipalNotFoundError():
return LDAPCodes.UNAVAILABLE, "Kerberos error"

Expand Down Expand Up @@ -632,8 +636,8 @@ def _need_to_cache_samaccountname_old_value(
return bool(
directory.entity_type
and directory.entity_type.name == EntityTypeNames.COMPUTER
and change.modification.type == "sAMAccountName"
and not self._old_vals.get(change.modification.type),
and change.l_type == "samaccountname"
and not self.old_vals.get(change.modification.type),
)

async def _delete(
Expand Down Expand Up @@ -689,7 +693,7 @@ async def _delete(
if self._need_to_cache_samaccountname_old_value(change, directory):
vals = directory.attributes_dict.get(change.modification.type)
if vals:
self._old_vals[change.modification.type] = vals[0]
self.old_vals[change.modification.type] = vals[0]

if attrs:
del_query = (
Expand Down Expand Up @@ -826,14 +830,13 @@ async def _add( # noqa: C901
password_use_cases: PasswordPolicyUseCases,
password_utils: PasswordUtils,
) -> None:
base_dir = None
attrs = []

if change.l_type in ("memberof", "member", "primarygroupid"):
await self._add_group_attrs(change, directory, session)
return

base_dir = await self._get_base_dir(directory, session)

for value in change.modification.vals:
if change.l_type == "useraccountcontrol":
uac_val = int(value)
Expand Down Expand Up @@ -923,6 +926,12 @@ async def _add( # noqa: C901
new_user_principal_name = str(new_value)
new_sam_account_name = new_user_principal_name.split("@")[0] # noqa: E501 # fmt: skip
elif change.l_type == "samaccountname":
if not base_dir:
base_dir = await self._get_base_dir(
directory,
session,
)

new_sam_account_name = str(new_value)
new_user_principal_name = f"{new_sam_account_name}@{base_dir.name}" # noqa: E501 # fmt: skip

Expand All @@ -946,12 +955,19 @@ async def _add( # noqa: C901
and directory.entity_type
and directory.entity_type.name == EntityTypeNames.COMPUTER
):
if not base_dir:
base_dir = await self._get_base_dir(
directory,
session,
)

await self._modify_computer_samaccountname(
change,
kadmin,
base_dir,
value,
)

attrs.append(
Attribute(
name=change.modification.type,
Expand Down Expand Up @@ -1019,7 +1035,7 @@ async def _modify_computer_samaccountname(
base_dir: Directory,
new_sam_account_name: bytes | str,
) -> None:
old_sam_account_name = self._old_vals.get(change.modification.type)
old_sam_account_name = self.old_vals.get(change.modification.type)
new_sam_account_name = str(new_sam_account_name)

if not old_sam_account_name:
Expand Down
Loading