diff --git a/scram/route_manager/api/exceptions.py b/scram/route_manager/api/exceptions.py index 0445dbe4..297b4477 100644 --- a/scram/route_manager/api/exceptions.py +++ b/scram/route_manager/api/exceptions.py @@ -28,3 +28,11 @@ class ActiontypeNotAllowed(APIException): status_code = 403 default_detail = "This client is not allowed to use this actiontype" default_code = "actiontype_not_allowed" + + +class NoActiveEntryFound(APIException): + """An active entry was not found.""" + + status_code = 404 + default_detail = "No active entry was found." + default_code = "no_entry_found" diff --git a/scram/route_manager/api/serializers.py b/scram/route_manager/api/serializers.py index 96283335..5c790f54 100644 --- a/scram/route_manager/api/serializers.py +++ b/scram/route_manager/api/serializers.py @@ -6,6 +6,7 @@ from netfields import rest_framework from rest_framework import serializers from rest_framework.fields import CurrentUserDefault +from simple_history.utils import update_change_reason from ..models import ActionType, Client, Entry, IgnoreEntry, Route @@ -67,36 +68,51 @@ class EntrySerializer(serializers.HyperlinkedModelSerializer): else: who = serializers.CharField() comment = serializers.CharField() + originating_scram_instance = serializers.CharField(default="scram_hostname_not_set", read_only=True) + is_active = serializers.BooleanField(default=True, read_only=True) + + def __init__(self, *args, **kwargs): + """Make sure we do not allow changing these fields in our put/patch calls.""" + super().__init__(*args, **kwargs) + if self.instance is not None: + self.fields["route"].read_only = True + self.fields["actiontype"].read_only = True + self.fields["who"].read_only = True class Meta: - """Maps to the Entry model, and specifies the fields exposed by the API.""" + """Map to the Entry model, and specify the fields exposed by the API.""" model = Entry - fields = ["route", "actiontype", "url", "comment", "who"] - - @staticmethod - def get_comment(obj): - """Provide a nicer name for change reason. - - Returns: - string: The change reason that modified the Entry. - """ - return obj.get_change_reason() - - @staticmethod - def create(validated_data): - """Implement custom logic and validates creating a new route.""" - valid_route = validated_data.pop("route") - actiontype = validated_data.pop("actiontype") - comment = validated_data.pop("comment") - - route_instance, _ = Route.objects.get_or_create(route=valid_route) - actiontype_instance = ActionType.objects.get(name=actiontype) - entry_instance, _ = Entry.objects.get_or_create(route=route_instance, actiontype=actiontype_instance) - - logger.debug("Created entry with comment: %s", comment) + fields = [ + "route", + "actiontype", + "url", + "comment", + "who", + "expiration", + "originating_scram_instance", + "is_active", + ] - return entry_instance + # This needs to be an instance method since thats expected by DRF + # ruff: noqa: PLR6301 + def create(self, validated_data): + """Create or update an Entry, handling duplicates gracefully.""" + route_data = validated_data.pop("route") + actiontype_name = validated_data.pop("actiontype") + comment = validated_data.get("comment", "") + + entry, created = Entry.objects.get_or_create( + route=route_data, actiontype=actiontype_name, defaults=validated_data + ) + + if not created: + for key, value in validated_data.items(): + setattr(entry, key, value) + entry.save() + update_change_reason(entry, comment) + + return entry class IgnoreEntrySerializer(serializers.ModelSerializer): diff --git a/scram/route_manager/api/views.py b/scram/route_manager/api/views.py index 16fc59d7..d1da5f8d 100644 --- a/scram/route_manager/api/views.py +++ b/scram/route_manager/api/views.py @@ -8,16 +8,14 @@ from django.conf import settings from django.core.exceptions import PermissionDenied from django.db.models import Q -from django.http import Http404 -from django.utils.dateparse import parse_datetime from drf_spectacular.utils import extend_schema from rest_framework import status, viewsets from rest_framework.permissions import AllowAny, IsAuthenticated from rest_framework.response import Response from simple_history.utils import update_change_reason -from ..models import ActionType, Client, Entry, IgnoreEntry, WebSocketSequenceElement -from .exceptions import ActiontypeNotAllowed, IgnoredRoute, PrefixTooLarge +from ..models import ActionType, Client, Entry, IgnoreEntry, Route, WebSocketSequenceElement +from .exceptions import ActiontypeNotAllowed, IgnoredRoute, NoActiveEntryFound, PrefixTooLarge from .serializers import ActionTypeSerializer, ClientSerializer, EntrySerializer, IgnoreEntrySerializer channel_layer = get_channel_layer() @@ -117,27 +115,25 @@ def perform_create(self, serializer): """Create a new Entry, causing that route to receive the actiontype (i.e. block).""" actiontype = serializer.validated_data["actiontype"] route = serializer.validated_data["route"] - if self.request.user.username: - # This is set if our request comes through the WUI path - who = self.request.user.username - else: + + route_instance, _ = Route.objects.get_or_create(route=route) + actiontype_instance = ActionType.objects.get(name=actiontype) + + if serializer.validated_data.get("who"): # This is set if we pass the "who" through the json data in an API call (like from Zeek) who = serializer.validated_data["who"] - comment = serializer.validated_data["comment"] - tmp_exp = self.request.data.get("expiration", "") + else: + # This is set if our request comes through the WUI path + who = self.request.user.username - try: - expiration = parse_datetime(tmp_exp) - except ValueError: - logger.warning("Could not parse expiration DateTime string: %s", tmp_exp) + comment = serializer.validated_data["comment"] - # Make sure we put in an acceptable sized prefix min_prefix = getattr(settings, f"V{route.version}_MINPREFIX", 0) if route.prefixlen < min_prefix: raise PrefixTooLarge self.check_client_authorization(actiontype) - self.check_ignore_list(route) + self.check_ignore_list(route_instance) elements = WebSocketSequenceElement.objects.filter(action_type__name=actiontype).order_by("order_num") if not elements: @@ -145,25 +141,53 @@ def perform_create(self, serializer): for element in elements: msg = element.websocketmessage - msg.msg_data[msg.msg_data_route_field] = str(route) + msg.msg_data[msg.msg_data_route_field] = str(route_instance) # Must match a channel name defined in asgi.py async_to_sync(channel_layer.group_send)( f"translator_{actiontype}", {"type": msg.msg_type, "message": msg.msg_data}, ) - serializer.save() - - entry = Entry.objects.get(route__route=route, actiontype__name=actiontype) - if expiration: - entry.expiration = expiration - entry.who = who - entry.is_active = True - entry.comment = comment - entry.originating_scram_instance = settings.SCRAM_HOSTNAME - logger.info("Created entry: %s", entry) - entry.save() + serializer.save( + route=route_instance, + actiontype=actiontype_instance, + who=who, + is_active=True, + comment=comment, + originating_scram_instance=settings.SCRAM_HOSTNAME, + ) + entry = serializer.instance update_change_reason(entry, comment) + logger.info("Created entry %s for route %s", actiontype, route) + + def perform_update(self, serializer): + """Update an existing Entry.""" + comment = serializer.validated_data.get("comment", "") + # Determine who is making this request + if serializer.validated_data.get("who"): + requesting_who = serializer.validated_data["who"] + else: + requesting_who = self.request.user.username + + if serializer.instance.who != requesting_who: + msg = "You can only update your own entries" + raise PermissionDenied(msg) + + serializer.save(who=serializer.instance.who, originating_scram_instance=settings.SCRAM_HOSTNAME) + + entry = serializer.instance + update_change_reason(entry, comment) + logger.info("Updated entry %s", entry) + + def get_object(self): + """Override get_object to use our custom find_entries logic.""" + pk = self.kwargs.get("pk") + entries = self.find_entries(pk, active_filter=True) + + if entries.count() != 1: + raise NoActiveEntryFound + + return entries.first() @staticmethod def find_entries(arg, active_filter=None): @@ -192,11 +216,8 @@ def find_entries(arg, active_filter=None): def retrieve(self, request, pk=None, **kwargs): """Retrieve a single route.""" - entries = self.find_entries(pk, active_filter=True) - # TODO: What happens if we get multiple? Is that ok? I think yes, and return them all? - if entries.count() != 1: - raise Http404 - serializer = EntrySerializer(entries, many=True, context={"request": request}) + entry = self.get_object() + serializer = EntrySerializer(entry, context={"request": request}) return Response(serializer.data) def destroy(self, request, pk=None, *args, **kwargs): diff --git a/scram/route_manager/tests/acceptance/features/add_automated_block_entry.feature b/scram/route_manager/tests/acceptance/features/add_automated_block_entry.feature index 83af7dc5..c9e5ad82 100644 --- a/scram/route_manager/tests/acceptance/features/add_automated_block_entry.feature +++ b/scram/route_manager/tests/acceptance/features/add_automated_block_entry.feature @@ -33,7 +33,15 @@ Feature: an automated source adds a block entry When we're logged in And we add the entry 192.0.2.133 with comment it's coming from inside the house Then we get a 201 status code - And the change entry for 192.0.2.133 is it's coming from inside the house + And the comment for entry 192.0.2.133 is it's coming from inside the house + + @history: + Scenario: Update comment on a block entry + Given a client with block authorization + When we're logged in + And we add the entry 192.0.2.10 with comment it's coming from inside the house + Then we get a 201 status code + And we update the entry 192.0.2.10 with comment it's coming from outside the house Scenario Outline: add a block entry multiple times and it's accepted Given a client with block authorization diff --git a/scram/route_manager/tests/acceptance/steps/common.py b/scram/route_manager/tests/acceptance/steps/common.py index cb803ba6..cec0692d 100644 --- a/scram/route_manager/tests/acceptance/steps/common.py +++ b/scram/route_manager/tests/acceptance/steps/common.py @@ -43,6 +43,7 @@ def create_authed_client(context, name): is_authorized=True, ) authorized_client.authorized_actiontypes.set([at]) + context.client = authorized_client @given("a client without {name} authorization") diff --git a/scram/route_manager/tests/acceptance/steps/ip.py b/scram/route_manager/tests/acceptance/steps/ip.py index b8fda111..5492f96f 100644 --- a/scram/route_manager/tests/acceptance/steps/ip.py +++ b/scram/route_manager/tests/acceptance/steps/ip.py @@ -1,6 +1,7 @@ """Define steps used for IP-related logic by the Behave tests.""" import ipaddress +import json from behave import then, when from django.urls import reverse @@ -39,17 +40,34 @@ def check_error(context): assert isinstance(context.queryException, ValueError) -@then("the change entry for {value:S} is {comment}") +@then("the comment for entry {value:S} is {comment}") def check_comment(context, value, comment): """Verify the comment for the Entry.""" try: objs = context.test.client.get(reverse("api:v1:entry-detail", args=[value])) - context.test.assertEqual(objs.json()[0]["comment"], comment) + context.test.assertEqual(objs.json()["comment"], comment) except ValueError as e: context.response = None context.queryException = e +@then("we update the entry {value:S} with comment {comment}") +def update_entry_comment(context, value, comment): + """Update the entry with a new comment.""" + data = {"comment": comment, "who": context.client.hostname} + + context.response = context.test.client.put( + reverse("api:v1:entry-detail", args=[value]), data=json.dumps(data), content_type="application/json" + ) + + +@then("the entry {value:S} comment is {comment}") +def check_entry_comment_not_equal(context, value, comment): + """Verify the comment was updated.""" + objs = context.test.client.get(reverse("api:v1:entry-detail", args=[value])) + context.test.assertEqual(objs.json()["comment"], comment) + + @when("we search for {ip}") def search_ip(context, ip): """Search our main search bar for an IP.""" diff --git a/scram/route_manager/views.py b/scram/route_manager/views.py index d6e341db..13871f9f 100644 --- a/scram/route_manager/views.py +++ b/scram/route_manager/views.py @@ -143,8 +143,6 @@ def add_entry(request): messages.add_message(request, messages.ERROR, "Permission Denied") else: messages.add_message(request, messages.WARNING, f"Something went wrong: {res.status_code}") - with transaction.atomic(): - home_page(request) return redirect("route_manager:home")