|
8 | 8 | import logging |
9 | 9 | from typing import Any |
10 | 10 |
|
| 11 | +from django.db.models import Count, Prefetch |
11 | 12 | from drf_spectacular.utils import ( |
12 | 13 | OpenApiExample, |
13 | 14 | OpenApiParameter, |
|
21 | 22 | from rest_framework.exceptions import NotFound |
22 | 23 | from rest_framework.response import Response |
23 | 24 |
|
| 25 | +from core.entity_extraction import ( |
| 26 | + accept_entity_candidate, |
| 27 | + merge_entity_candidate, |
| 28 | + reject_entity_candidate, |
| 29 | +) |
24 | 30 | from core.models import ( |
25 | 31 | BlueskyCredentials, |
26 | 32 | Content, |
27 | 33 | Entity, |
| 34 | + EntityCandidate, |
| 35 | + EntityMention, |
28 | 36 | IngestionRun, |
29 | 37 | Project, |
30 | 38 | ProjectConfig, |
|
35 | 43 | ) |
36 | 44 | from core.serializers import ( |
37 | 45 | ContentSerializer, |
| 46 | + EntityCandidateMergeSerializer, |
| 47 | + EntityCandidateSerializer, |
| 48 | + EntityMentionSummarySerializer, |
38 | 49 | EntitySerializer, |
39 | 50 | IngestionRunSerializer, |
40 | 51 | ProjectConfigSerializer, |
@@ -693,7 +704,111 @@ class EntityViewSet(ProjectOwnedQuerysetMixin, viewsets.ModelViewSet): |
693 | 704 | """Manage tracked entities associated with a project.""" |
694 | 705 |
|
695 | 706 | serializer_class = EntitySerializer |
696 | | - queryset = Entity.objects.select_related("project") |
| 707 | + queryset = ( |
| 708 | + Entity.objects.select_related("project") |
| 709 | + .annotate(mention_count=Count("mentions", distinct=True)) |
| 710 | + .prefetch_related( |
| 711 | + Prefetch( |
| 712 | + "mentions", |
| 713 | + queryset=EntityMention.objects.select_related("content").order_by( |
| 714 | + "-created_at" |
| 715 | + ), |
| 716 | + to_attr="prefetched_mentions", |
| 717 | + ) |
| 718 | + ) |
| 719 | + ) |
| 720 | + |
| 721 | + @extend_schema( |
| 722 | + summary="List entity mentions", |
| 723 | + description="Return the extracted mention history for one tracked entity inside the selected project.", |
| 724 | + request=None, |
| 725 | + responses={200: EntityMentionSummarySerializer(many=True), 403: AUTHENTICATION_REQUIRED_RESPONSE}, |
| 726 | + tags=["Entity Catalog"], |
| 727 | + ) |
| 728 | + @action(detail=True, methods=["get"], url_path="mentions") |
| 729 | + def mentions(self, request, *args, **kwargs): |
| 730 | + """Return the extracted mentions for the selected entity.""" |
| 731 | + |
| 732 | + entity = self.get_object() |
| 733 | + mentions = entity.mentions.select_related("content").order_by("-created_at") |
| 734 | + serializer = EntityMentionSummarySerializer(mentions, many=True) |
| 735 | + return Response(serializer.data) |
| 736 | + |
| 737 | + |
| 738 | +@document_project_owned_viewset( |
| 739 | + resource_plural="entity candidates", |
| 740 | + resource_singular="entity candidate", |
| 741 | + create_description="Entity candidates are created by the pipeline and can be reviewed through dedicated actions.", |
| 742 | + tag="Entity Catalog", |
| 743 | + action_overrides=build_crud_action_overrides( |
| 744 | + EntityCandidateSerializer, |
| 745 | + resource_plural="entity candidates for the selected project", |
| 746 | + resource_singular="entity candidate", |
| 747 | + ), |
| 748 | +) |
| 749 | +class EntityCandidateViewSet(ProjectOwnedQuerysetMixin, viewsets.ReadOnlyModelViewSet): |
| 750 | + """Inspect and resolve entity candidates surfaced by entity extraction.""" |
| 751 | + |
| 752 | + serializer_class = EntityCandidateSerializer |
| 753 | + queryset = EntityCandidate.objects.select_related( |
| 754 | + "project", "first_seen_in", "merged_into" |
| 755 | + ) |
| 756 | + |
| 757 | + @extend_schema( |
| 758 | + summary="Accept entity candidate", |
| 759 | + description="Promote a pending entity candidate into a tracked entity and backfill recent mentions.", |
| 760 | + request=None, |
| 761 | + responses={200: EntityCandidateSerializer, 403: AUTHENTICATION_REQUIRED_RESPONSE}, |
| 762 | + tags=["Entity Catalog"], |
| 763 | + ) |
| 764 | + @action(detail=True, methods=["post"], url_path="accept") |
| 765 | + def accept(self, request, *args, **kwargs): |
| 766 | + """Accept an entity candidate and return its updated representation.""" |
| 767 | + |
| 768 | + candidate = self.get_object() |
| 769 | + accept_entity_candidate(candidate) |
| 770 | + candidate.refresh_from_db() |
| 771 | + serializer = self.get_serializer(candidate) |
| 772 | + return Response(serializer.data) |
| 773 | + |
| 774 | + @extend_schema( |
| 775 | + summary="Reject entity candidate", |
| 776 | + description="Mark a pending entity candidate as rejected without creating a tracked entity.", |
| 777 | + request=None, |
| 778 | + responses={200: EntityCandidateSerializer, 403: AUTHENTICATION_REQUIRED_RESPONSE}, |
| 779 | + tags=["Entity Catalog"], |
| 780 | + ) |
| 781 | + @action(detail=True, methods=["post"], url_path="reject") |
| 782 | + def reject(self, request, *args, **kwargs): |
| 783 | + """Reject an entity candidate and return its updated representation.""" |
| 784 | + |
| 785 | + candidate = self.get_object() |
| 786 | + reject_entity_candidate(candidate) |
| 787 | + candidate.refresh_from_db() |
| 788 | + serializer = self.get_serializer(candidate) |
| 789 | + return Response(serializer.data) |
| 790 | + |
| 791 | + @extend_schema( |
| 792 | + summary="Merge entity candidate", |
| 793 | + description="Merge a pending entity candidate into an existing tracked entity from the same project.", |
| 794 | + request=EntityCandidateMergeSerializer, |
| 795 | + responses={200: EntityCandidateSerializer, 400: EntityCandidateMergeSerializer, 403: AUTHENTICATION_REQUIRED_RESPONSE}, |
| 796 | + tags=["Entity Catalog"], |
| 797 | + ) |
| 798 | + @action(detail=True, methods=["post"], url_path="merge") |
| 799 | + def merge(self, request, *args, **kwargs): |
| 800 | + """Merge an entity candidate into an existing tracked entity.""" |
| 801 | + |
| 802 | + candidate = self.get_object() |
| 803 | + serializer = EntityCandidateMergeSerializer( |
| 804 | + data=request.data, |
| 805 | + context=self.get_serializer_context(), |
| 806 | + ) |
| 807 | + serializer.is_valid(raise_exception=True) |
| 808 | + merge_entity_candidate(candidate, serializer.validated_data["merged_into"]) |
| 809 | + candidate.refresh_from_db() |
| 810 | + response_serializer = self.get_serializer(candidate) |
| 811 | + return Response(response_serializer.data) |
697 | 812 |
|
698 | 813 |
|
699 | 814 | @document_project_owned_viewset( |
|
0 commit comments