diff --git a/questions/services/forecasts.py b/questions/services/forecasts.py index c5caa6abd2..d7a140f0ab 100644 --- a/questions/services/forecasts.py +++ b/questions/services/forecasts.py @@ -155,7 +155,7 @@ def after_forecast_actions(question: Question, user: User): ) # Run async tasks - from ..tasks import run_build_question_forecasts + from questions.tasks import run_build_question_forecasts run_build_question_forecasts.send(question.id) diff --git a/questions/urls.py b/questions/urls.py index 3b05b0fd17..af77f16b2a 100644 --- a/questions/urls.py +++ b/questions/urls.py @@ -13,6 +13,11 @@ views.bulk_withdraw_forecasts_api_view, name="create-withdraw", ), + path( + "questions/bulk-forecast-comment/", + views.bulk_forecast_and_comment_api_view, + name="bulk-forecast-comment", + ), path( "questions//", views.question_detail_api_view, name="question-details" ), diff --git a/questions/views.py b/questions/views.py index b04ab88292..d2b0fa7a66 100644 --- a/questions/views.py +++ b/questions/views.py @@ -1,23 +1,28 @@ -import numpy as np +from django.db import transaction from django.http import Http404 from django.utils import timezone -from rest_framework import status +import numpy as np +from rest_framework import serializers, status from rest_framework.decorators import api_view, permission_classes -from rest_framework.exceptions import ValidationError +from rest_framework.exceptions import PermissionDenied, ValidationError from rest_framework.generics import get_object_or_404 -from rest_framework.permissions import AllowAny, IsAdminUser +from rest_framework.permissions import AllowAny, IsAuthenticated, IsAdminUser from rest_framework.response import Response from rest_framework.serializers import DateTimeField + +from comments.serializers.common import CommentWriteSerializer +from comments.services.common import create_comment from posts.models import Post from posts.services.common import get_post_permission_for_user from posts.utils import get_post_slug from projects.permissions import ObjectPermission +from users.models import User from utils.requests import is_internal_request from utils.the_math.aggregations import get_aggregations_at_time -from .constants import QuestionStatus -from .models import Forecast, Question -from .serializers.common import ( +from questions.constants import QuestionStatus +from questions.models import Forecast, Question +from questions.serializers.common import ( validate_question_resolution, QuestionsCommunityPredictionsSerializer, OldForecastWriteSerializer, @@ -25,8 +30,11 @@ ForecastWithdrawSerializer, serialize_question, ) -from .services.forecasts import create_forecast_bulk, withdraw_forecast_bulk -from .services.lifecycle import resolve_question, unresolve_question +from questions.services.forecasts import ( + create_forecast_bulk, + withdraw_forecast_bulk, +) +from questions.services.lifecycle import resolve_question, unresolve_question @api_view(["GET"]) @@ -274,6 +282,154 @@ def legacy_question_api_view(request, pk: int): ) +class BulkForecastAndCommentSerializer(serializers.Serializer): + user_id = serializers.IntegerField(required=False, allow_null=True) + username = serializers.CharField(required=False, allow_null=True) + is_staff_override = serializers.BooleanField(required=False, default=False) + forecasts = ForecastWriteSerializer(many=True, required=False, default=list) + comments = CommentWriteSerializer(many=True, required=False, default=list) + + def validate(self, attrs): + if not attrs.get("user_id") and not attrs.get("username"): + raise serializers.ValidationError( + "Either user_id or username must be provided." + ) + return attrs + + +@api_view(["POST"]) +@permission_classes([IsAuthenticated]) +def bulk_forecast_and_comment_api_view(request): + """ + Submits forecasts and comments in a single atomic transaction. + + Superusers may submit on behalf of any user by providing user_id or username + and flag `is_staff_override`. + Non-superusers may submit as themselves or as one of their bots (identified + by user_id or username). + """ + serializer = BulkForecastAndCommentSerializer(data=request.data) + serializer.is_valid(raise_exception=True) + data = serializer.validated_data + + user_id = data.get("user_id") + username = data.get("username") + forecasts_data = data["forecasts"] + comments_data = data.get("comments", []) + is_staff_override = data.get("is_staff_override", False) + + request_user = request.user + if is_staff_override and not request_user.is_superuser: + raise PermissionDenied("Only superusers can use the is_staff_override flag.") + + if is_staff_override: + if user_id: + user = get_object_or_404(User, id=user_id) + else: + user = get_object_or_404(User, username=username) + else: + user = ( + User.objects.filter(id=user_id).first() + if user_id + else User.objects.filter(username=username).first() + ) + is_self = user is not None and user.id == request_user.id + is_own_bot = ( + user is not None + and user.is_bot + and user.bot_owner_id is not None + and user.bot_owner_id == request_user.id + ) + if not is_self and not is_own_bot: + raise PermissionDenied( + "Non-superusers can only submit forecasts and comments as themselves " + "or their bots." + ) + + now = timezone.now() + errors = [] + + # Validate forecasts and resolve question IDs to Question objects + questions_map = { + q.pk: q + for q in Question.objects.filter( + pk__in=[f["question"] for f in forecasts_data] + ).select_related("post") + } + + for forecast in forecasts_data: + question_id = forecast["question"] + question = questions_map.get(question_id) + if not question: + errors.append(f"Question {question_id} does not exist.") + continue + forecast["question"] = question + + post: Post = question.post + permission = get_post_permission_for_user(post, user=user) + if not ObjectPermission.can_forecast(permission): + errors.append(f"Question {question.id}: forecasting not permitted.") + continue + + if ( + post.curation_status != Post.CurationStatus.APPROVED + or not question.open_time + or not question.scheduled_close_time + ): + errors.append(f"Question {question.id} is not open for forecasting yet.") + elif (question.scheduled_close_time < now) or ( + question.actual_close_time and question.actual_close_time < now + ): + errors.append(f"Question {question.id} is already closed to forecasting.") + + # Validate comments + for i, comment in enumerate(comments_data): + on_post = comment["on_post"] + if not comment.get("is_private"): + errors.append( + f"Comment {i}: only private comments are allowed in bulk submissions." + ) + continue + if comment.get("key_factors"): + errors.append( + f"Comment {i}: key_factors are not supported in bulk submissions." + ) + continue + parent = comment.get("parent") + permission = get_post_permission_for_user( + parent.on_post if parent else on_post, user=user + ) + if not ObjectPermission.can_comment(permission): + errors.append( + f"Comment {i}: commenting not permitted on post {on_post.id}." + ) + + if errors: + raise ValidationError(errors) + + with transaction.atomic(): + create_forecast_bulk(user=user, forecasts=forecasts_data) + + for comment_data in comments_data: + on_post = comment_data["on_post"] + included_forecast_flag = comment_data.pop("included_forecast", False) + comment_data.pop("key_factors", None) + + included_forecast = ( + on_post.question.user_forecasts.filter(author_id=user.id) + .order_by("-start_time") + .first() + if included_forecast_flag and on_post.question_id + else None + ) + + create_comment( + **comment_data, included_forecast=included_forecast, user=user + ) + + return Response({}, status=status.HTTP_201_CREATED) + + @api_view(["GET", "POST"]) @permission_classes([IsAdminUser]) def questions_community_predictions(request) -> Response: diff --git a/tests/unit/test_questions/test_bulk_forecast_and_comment.py b/tests/unit/test_questions/test_bulk_forecast_and_comment.py new file mode 100644 index 0000000000..cdfca0d856 --- /dev/null +++ b/tests/unit/test_questions/test_bulk_forecast_and_comment.py @@ -0,0 +1,273 @@ +import json +from datetime import datetime, timezone as dt_timezone + +import pytest +from rest_framework.reverse import reverse + +from questions.models import Forecast, Question +from tests.unit.test_posts.conftest import * # noqa +from tests.unit.test_posts.factories import factory_post +from tests.unit.test_questions.conftest import * # noqa +from tests.unit.test_questions.factories import create_question +from users.models import User + +URL = reverse("bulk-forecast-comment") + + +def forecast_payload(question, **kwargs): + return {"question": question.id, "probability_yes": 0.6, **kwargs} + + +@pytest.fixture() +def open_question(): + question = create_question( + question_type=Question.QuestionType.BINARY, + open_time=datetime(2000, 1, 1, tzinfo=dt_timezone.utc), + scheduled_close_time=datetime(3000, 1, 1, tzinfo=dt_timezone.utc), + ) + factory_post(question=question) + return question + + +@pytest.fixture() +def user_bot(user1: User) -> User: + return User.objects.create( + email="bot@metaculus.com", + username="bot_user", + is_bot=True, + bot_owner=user1, + ) + + +@pytest.fixture() +def user_bot_no_owner() -> User: + return User.objects.create( + email="orphan_bot@metaculus.com", + username="orphan_bot", + is_bot=True, + bot_owner=None, + ) + + +class TestBulkForecastAndComment: + def test_requires_user_id_or_username(self, user1_client, open_question): + response = user1_client.post( + URL, + data=json.dumps({"forecasts": [forecast_payload(open_question)]}), + content_type="application/json", + ) + assert response.status_code == 400 + + def test_unauthenticated(self, anon_client, user1, open_question): + response = anon_client.post( + URL, + data=json.dumps( + {"user_id": user1.id, "forecasts": [forecast_payload(open_question)]} + ), + content_type="application/json", + ) + assert response.status_code == 403 + + def test_submit_as_self_by_user_id(self, user1, user1_client, open_question): + response = user1_client.post( + URL, + data=json.dumps( + {"user_id": user1.id, "forecasts": [forecast_payload(open_question)]} + ), + content_type="application/json", + ) + assert response.status_code == 201 + assert Forecast.objects.filter(question=open_question, author=user1).exists() + + def test_submit_as_self_by_username(self, user1, user1_client, open_question): + response = user1_client.post( + URL, + data=json.dumps( + { + "username": user1.username, + "forecasts": [forecast_payload(open_question)], + } + ), + content_type="application/json", + ) + assert response.status_code == 201 + assert Forecast.objects.filter(question=open_question, author=user1).exists() + + def test_submit_as_other_user_denied(self, user1_client, user2, open_question): + response = user1_client.post( + URL, + data=json.dumps( + {"user_id": user2.id, "forecasts": [forecast_payload(open_question)]} + ), + content_type="application/json", + ) + assert response.status_code == 403 + + def test_submit_as_own_bot_by_user_id(self, user1_client, user_bot, open_question): + response = user1_client.post( + URL, + data=json.dumps( + {"user_id": user_bot.id, "forecasts": [forecast_payload(open_question)]} + ), + content_type="application/json", + ) + assert response.status_code == 201 + assert Forecast.objects.filter(question=open_question, author=user_bot).exists() + + def test_submit_as_own_bot_by_username(self, user1_client, user_bot, open_question): + response = user1_client.post( + URL, + data=json.dumps( + { + "username": user_bot.username, + "forecasts": [forecast_payload(open_question)], + } + ), + content_type="application/json", + ) + assert response.status_code == 201 + assert Forecast.objects.filter(question=open_question, author=user_bot).exists() + + def test_submit_as_other_users_bot_denied( + self, user2_client, user_bot, open_question + ): + # user_bot is owned by user1, not user2 + response = user2_client.post( + URL, + data=json.dumps( + {"user_id": user_bot.id, "forecasts": [forecast_payload(open_question)]} + ), + content_type="application/json", + ) + assert response.status_code == 403 + + def test_submit_as_bot_with_no_owner_denied( + self, user1_client, user_bot_no_owner, open_question + ): + response = user1_client.post( + URL, + data=json.dumps( + { + "user_id": user_bot_no_owner.id, + "forecasts": [forecast_payload(open_question)], + } + ), + content_type="application/json", + ) + assert response.status_code == 403 + + def test_superuser_override_by_user_id( + self, create_client_for_user, user_admin, user2, open_question + ): + staff_client = create_client_for_user(user_admin) + response = staff_client.post( + URL, + data=json.dumps( + { + "user_id": user2.id, + "is_staff_override": True, + "forecasts": [forecast_payload(open_question)], + } + ), + content_type="application/json", + ) + assert response.status_code == 201 + assert Forecast.objects.filter(question=open_question, author=user2).exists() + + def test_superuser_override_by_username( + self, create_client_for_user, user_admin, user2, open_question + ): + staff_client = create_client_for_user(user_admin) + response = staff_client.post( + URL, + data=json.dumps( + { + "username": user2.username, + "is_staff_override": True, + "forecasts": [forecast_payload(open_question)], + } + ), + content_type="application/json", + ) + assert response.status_code == 201 + assert Forecast.objects.filter(question=open_question, author=user2).exists() + + def test_non_superuser_cannot_use_staff_override( + self, user1_client, user1, user2, open_question + ): + response = user1_client.post( + URL, + data=json.dumps( + { + "user_id": user2.id, + "is_staff_override": True, + "forecasts": [forecast_payload(open_question)], + } + ), + content_type="application/json", + ) + assert response.status_code == 403 + + def test_unknown_user_id_returns_403(self, user1_client, open_question): + response = user1_client.post( + URL, + data=json.dumps( + {"user_id": 999999, "forecasts": [forecast_payload(open_question)]} + ), + content_type="application/json", + ) + assert response.status_code == 403 + + def test_unknown_username_returns_403(self, user1_client, open_question): + response = user1_client.post( + URL, + data=json.dumps( + { + "username": "does_not_exist", + "forecasts": [forecast_payload(open_question)], + } + ), + content_type="application/json", + ) + assert response.status_code == 403 + + def test_superuser_override_unknown_user_id_returns_404( + self, create_client_for_user, user_admin, open_question + ): + staff_client = create_client_for_user(user_admin) + response = staff_client.post( + URL, + data=json.dumps( + { + "user_id": 999999, + "is_staff_override": True, + "forecasts": [forecast_payload(open_question)], + } + ), + content_type="application/json", + ) + assert response.status_code == 404 + + def test_key_factors_in_bulk_comment_returns_400( + self, user1, user1_client, open_question + ): + response = user1_client.post( + URL, + data=json.dumps( + { + "user_id": user1.id, + "comments": [ + { + "on_post": open_question.get_post().id, + "text": "test comment", + "is_private": True, + "key_factors": [ + {"text": "some factor", "is_positive": True} + ], + } + ], + } + ), + content_type="application/json", + ) + assert response.status_code == 400