diff --git a/config/api_router.py b/config/api_router.py index a57a12ba..e7f0ffbf 100644 --- a/config/api_router.py +++ b/config/api_router.py @@ -2,7 +2,13 @@ from rest_framework.routers import DefaultRouter -from scram.route_manager.api.views import ActionTypeViewSet, ClientViewSet, EntryViewSet, IgnoreEntryViewSet +from scram.route_manager.api.views import ( + ActionTypeViewSet, + ClientViewSet, + EntryViewSet, + IgnoreEntryViewSet, + IsActiveViewSet, +) from scram.users.api.views import UserViewSet router = DefaultRouter() @@ -12,7 +18,7 @@ router.register("register_client", ClientViewSet) router.register("entries", EntryViewSet) router.register("ignore_entries", IgnoreEntryViewSet) - +router.register("is_active", IsActiveViewSet, "is_active") app_name = "api" urlpatterns = router.urls diff --git a/scram/route_manager/api/serializers.py b/scram/route_manager/api/serializers.py index 5c790f54..2896b73c 100644 --- a/scram/route_manager/api/serializers.py +++ b/scram/route_manager/api/serializers.py @@ -52,6 +52,18 @@ class Meta: fields = ["hostname", "uuid"] +class IsActiveSerializer(serializers.ModelSerializer): + """Map the serializer to the Entry model.""" + + route = serializers.StringRelatedField(source="route.route") + + class Meta: + """Maps to the Entry model, but limits to the the appropriate fields.""" + + model = Entry + fields = ["is_active", "route"] + + class EntrySerializer(serializers.HyperlinkedModelSerializer): """Due to the use of ForeignKeys, this follows some relationships to make sense via the API.""" diff --git a/scram/route_manager/api/views.py b/scram/route_manager/api/views.py index d1da5f8d..b5d3d6f8 100644 --- a/scram/route_manager/api/views.py +++ b/scram/route_manager/api/views.py @@ -10,13 +10,20 @@ from django.db.models import Q from drf_spectacular.utils import extend_schema from rest_framework import status, viewsets +from rest_framework.exceptions import ValidationError from rest_framework.permissions import AllowAny, IsAuthenticated from rest_framework.response import Response from simple_history.utils import update_change_reason from ..models import ActionType, Client, Entry, IgnoreEntry, Route, WebSocketSequenceElement from .exceptions import ActiontypeNotAllowed, IgnoredRoute, NoActiveEntryFound, PrefixTooLarge -from .serializers import ActionTypeSerializer, ClientSerializer, EntrySerializer, IgnoreEntrySerializer +from .serializers import ( + ActionTypeSerializer, + ClientSerializer, + EntrySerializer, + IgnoreEntrySerializer, + IsActiveSerializer, +) channel_layer = get_channel_layer() logger = logging.getLogger(__name__) @@ -63,6 +70,51 @@ class ClientViewSet(viewsets.ModelViewSet): http_method_names = ["post"] +class IsActiveViewSet(viewsets.ReadOnlyModelViewSet): + """Look up a route to see if SCRAM considers it active or deactivated.""" + + serializer_class = IsActiveSerializer + permission_classes = (AllowAny,) + http_method_names = ["get"] + + normalization_warning: str | None + normalized_cidr_for_response: ipaddress.IPv4Network | ipaddress.IPv6Network | None + + def get_queryset(self): + """Focus queryset on active routes.""" + cidr = self.request.query_params.get("cidr") + if not cidr: + raise ValidationError(detail={"error": "cidr parameter is required"}) + try: + normalized_cidr = ipaddress.ip_network(cidr, strict=False) + except ValueError: + raise ValidationError(detail={"error": "invalid ip address or network"}) from None + + self.normalization_warning = None + self.normalized_cidr_for_response = normalized_cidr + + if str(cidr) != str(normalized_cidr): + # save the warning so we can use it in the list response + self.normalization_warning = ( + f"Input CIDR '{cidr}' was not canonical and was normalized to '{normalized_cidr!s}' for the search." + ) + + return Entry.objects.filter(route__route__net_contained_or_equal=normalized_cidr, is_active=True) + + def list(self, request): + """Override the list function to just return a boolean instead of other metadata.""" + queryset = self.get_queryset() + + if not queryset.exists() and hasattr(self, "normalized_cidr_for_response"): + response_data = {"results": [{"is_active": False, "route": str(self.normalized_cidr_for_response)}]} + else: + serializer = self.get_serializer(queryset, many=True) + response_data = {"results": serializer.data} + response_data["warning"] = self.normalization_warning + + return Response(response_data) + + @extend_schema( description="API endpoint for entries", responses={200: EntrySerializer}, diff --git a/scram/route_manager/tests/acceptance/steps/common.py b/scram/route_manager/tests/acceptance/steps/common.py index cec0692d..2d86fefc 100644 --- a/scram/route_manager/tests/acceptance/steps/common.py +++ b/scram/route_manager/tests/acceptance/steps/common.py @@ -9,12 +9,7 @@ from django import conf from django.urls import reverse -from scram.route_manager.models import ( - ActionType, - Client, - WebSocketMessage, - WebSocketSequenceElement, -) +from scram.route_manager.models import ActionType, Client, WebSocketMessage, WebSocketSequenceElement @given("a {name} actiontype is defined") diff --git a/scram/route_manager/tests/test_api.py b/scram/route_manager/tests/test_api.py index 1a81e907..dd980924 100644 --- a/scram/route_manager/tests/test_api.py +++ b/scram/route_manager/tests/test_api.py @@ -5,7 +5,7 @@ from rest_framework import status from rest_framework.test import APITestCase -from scram.route_manager.models import Client +from scram.route_manager.models import ActionType, Client, Entry, Route class TestAddRemoveIP(APITestCase): @@ -125,3 +125,84 @@ def test_unauthenticated_users_have_no_list_access(self): """Ensure an unauthenticated client can't list Entries.""" response = self.client.get(self.entry_url, format="json") self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + + +class TestIsActive(APITestCase): + """Test the is_active endpoint.""" + + def setUp(self): + """Set up test data.""" + self.url = reverse("api:v1:is_active-list") + self.authorized_client = Client.objects.create( + hostname="authorized_client.es.net", + uuid="0e7e1cbd-7d73-4968-bc4b-ce3265dc2fd3", + is_authorized=True, + ) + self.authorized_client.authorized_actiontypes.set([1]) + self.actiontype, _ = ActionType.objects.get_or_create(pk=1, defaults={"name": "block"}) + + # Create some active entries + + # Active IPv4 + route_v4 = Route.objects.create(route="192.0.2.100") + Entry.objects.create( + route=route_v4, is_active=True, comment="test active", who="test", actiontype=self.actiontype + ) + + # Active IPv6 + route_v6 = Route.objects.create(route="2001:db8::1") + Entry.objects.create( + route=route_v6, is_active=True, comment="test active v6", who="test", actiontype=self.actiontype + ) + + # Deactivated IPv4 entry + route_inactive = Route.objects.create(route="192.0.2.200") + Entry.objects.create( + route=route_inactive, is_active=False, comment="inactive", who="test", actiontype=self.actiontype + ) + + # Deactived IPv6 entry + route_inactive = Route.objects.create(route="2001:db8::5") + Entry.objects.create( + route=route_inactive, is_active=False, comment="inactive", who="test", actiontype=self.actiontype + ) + + def test_active_ipv4_returns_true(self): + """Check that an active IPv4 returns is_active=true.""" + response = self.client.get(self.url, {"cidr": "192.0.2.100"}) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(len(response.data["results"]), 1) + self.assertEqual(response.data["results"][0]["is_active"], True) + self.assertEqual(response.data["results"][0]["route"], "192.0.2.100/32") + + def test_active_ipv6_returns_true(self): + """Check that an active IPv6 returns is_active=true.""" + response = self.client.get(self.url, {"cidr": "2001:db8::1"}) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(len(response.data["results"]), 1) + self.assertEqual(response.data["results"][0]["is_active"], True) + self.assertEqual(response.data["results"][0]["route"], "2001:db8::1/128") + + def test_inactive_entry_ipv4_returns_false(self): + """Check that an inactive entry returns is_active=false.""" + response = self.client.get(self.url, {"cidr": "192.0.2.200"}) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(len(response.data["results"]), 1) + self.assertEqual(response.data["results"][0]["is_active"], False) + self.assertEqual(response.data["results"][0]["route"], "192.0.2.200/32") + + def test_inactive_entry_ipv6_returns_false(self): + """Check that an inactive entry returns is_active=false.""" + response = self.client.get(self.url, {"cidr": "2001:db8::5"}) + self.assertEqual(len(response.data["results"]), 1) + self.assertEqual(response.data["results"][0]["is_active"], False) + self.assertEqual(response.data["results"][0]["route"], "2001:db8::5/128") + + def test_unauthenticated_access_allowed(self): + """Ensure unauthenticated clients can check if IPs are active.""" + # Logout any authenticated user + self.client.logout() + response = self.client.get(self.url, {"cidr": "192.0.2.100"}) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(len(response.data["results"]), 1) + self.assertEqual(response.data["results"][0]["is_active"], True)