From f7a2063c37c26b19ec1ce6dcd1e1840eb2b7a7b9 Mon Sep 17 00:00:00 2001 From: lsabor Date: Sat, 28 Mar 2026 09:51:33 -0700 Subject: [PATCH 1/7] add bulk forecast and comment api endpoint --- questions/urls.py | 5 ++ questions/views.py | 138 +++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 139 insertions(+), 4 deletions(-) diff --git a/questions/urls.py b/questions/urls.py index a497495c05..5d46e90d73 100644 --- a/questions/urls.py +++ b/questions/urls.py @@ -13,6 +13,11 @@ views.bulk_withdraw_forecasts_api_view, name="create-withdraw", ), + path( + "questions/bulk-forecast-comment/", + views.bulk_forecast_and_comment_api_view, + name="bulk-forecast-comment", + ), path( "questions//", views.question_detail_api_view, name="question-details" ), diff --git a/questions/views.py b/questions/views.py index 370ce30299..4f3c867011 100644 --- a/questions/views.py +++ b/questions/views.py @@ -1,17 +1,23 @@ +from django.db import transaction from django.http import Http404 from django.utils import timezone -from rest_framework import status +from rest_framework import serializers, status from rest_framework.decorators import api_view, permission_classes -from rest_framework.exceptions import ValidationError +from rest_framework.exceptions import PermissionDenied, ValidationError from rest_framework.generics import get_object_or_404 -from rest_framework.permissions import AllowAny +from rest_framework.permissions import AllowAny, IsAuthenticated from rest_framework.response import Response from rest_framework.serializers import DateTimeField +from comments.serializers.common import CommentWriteSerializer +from comments.services.common import create_comment +from comments.services.key_factors.common import create_key_factors from posts.models import Post from posts.services.common import get_post_permission_for_user +from posts.tasks import run_on_post_forecast from posts.utils import get_post_slug from projects.permissions import ObjectPermission +from users.models import User from .constants import QuestionStatus from .models import Question from .serializers.common import ( @@ -21,7 +27,13 @@ ForecastWithdrawSerializer, serialize_question, ) -from .services.forecasts import create_forecast_bulk, withdraw_forecast_bulk +from .services.forecasts import ( + after_forecast_actions, + create_forecast, + create_forecast_bulk, + update_forecast_notification, + withdraw_forecast_bulk, +) from .services.lifecycle import resolve_question, unresolve_question @@ -246,3 +258,121 @@ def legacy_question_api_view(request, pk: int): return Response( {"question_id": pk, "post_id": post.pk, "post_slug": get_post_slug(post)} ) + + +class BulkForecastAndCommentSerializer(serializers.Serializer): + user_id = serializers.IntegerField(required=True) + is_staff_override = serializers.BooleanField(required=False, default=False) + forecasts = ForecastWriteSerializer(many=True, required=False, default=list) + comments = CommentWriteSerializer(many=True, required=False, default=list) + + +@api_view(["POST"]) +@permission_classes([IsAuthenticated]) +def bulk_forecast_and_comment_api_view(request): + """ + Submits forecasts and comments in a single atomic transaction. + + Staff users may submit on behalf of any user by providing user_id and + flag `is_staff_override`. + Non-staff users may only submit as themselves (user_id must match + the authenticated user's ID). + """ + serializer = BulkForecastAndCommentSerializer(data=request.data) + serializer.is_valid(raise_exception=True) + data = serializer.validated_data + + user_id = data.get("user_id") + forecasts_data = data["forecasts"] + comments_data = data.get("comments", []) + is_staff_override = data.get("is_staff_override", False) + + request_user = request.user + if is_staff_override and not request_user.is_staff: + raise PermissionDenied("Non-staff users cannot use the is_staff_override flag.") + if not is_staff_override and user_id != request_user.id: + raise PermissionDenied( + "Non-staff users can only submit forecasts and comments as themselves." + ) + + if is_staff_override: + user = get_object_or_404(User, id=user_id) + else: + user = request_user + + now = timezone.now() + + # Validate forecasts and resolve question IDs to Question objects + questions_map = { + q.pk: q + for q in Question.objects.filter( + pk__in=[f["question"] for f in forecasts_data] + ).select_related("post") + } + + for forecast in forecasts_data: + question = questions_map.get(forecast["question"]) + if not question: + raise ValidationError(f"Wrong question id {forecast['question']}") + forecast["question"] = question + + permission = get_post_permission_for_user(question.get_post(), user=user) + ObjectPermission.can_forecast(permission, raise_exception=True) + + if not question.open_time or question.open_time > now: + raise ValidationError( + f"Question {question.id} is not open for forecasting yet" + ) + if (question.scheduled_close_time < now) or ( + question.actual_close_time and question.actual_close_time < now + ): + raise ValidationError( + f"Question {question.id} is already closed to forecasting" + ) + + # Validate comment permissions + for comment in comments_data: + on_post = comment["on_post"] + parent = comment.get("parent") + permission = get_post_permission_for_user( + parent.on_post if parent else on_post, user=user + ) + ObjectPermission.can_comment(permission, raise_exception=True) + + posts = set() + created_forecasts = [] + + with transaction.atomic(): + for forecast_data in forecasts_data: + question = forecast_data.pop("question") + posts.add(question.get_post()) + forecast = create_forecast(question=question, user=user, **forecast_data) + created_forecasts.append((forecast, question)) + + for comment_data in comments_data: + on_post = comment_data["on_post"] + included_forecast_flag = comment_data.pop("included_forecast", False) + key_factors = comment_data.pop("key_factors", None) + + included_forecast = ( + on_post.question.user_forecasts.filter(author_id=user.id) + .order_by("-start_time") + .first() + if included_forecast_flag and on_post.question_id + else None + ) + + new_comment = create_comment( + **comment_data, included_forecast=included_forecast, user=user + ) + if key_factors: + create_key_factors(new_comment, key_factors) + + for forecast, question in created_forecasts: + update_forecast_notification(forecast=forecast, created=True) + after_forecast_actions(question, user) + + for post in posts: + run_on_post_forecast.send_with_options(args=(post.id,), delay=10_000) + + return Response({}, status=status.HTTP_201_CREATED) From f9f95f35674d966a644d7f499f53ab01f639702d Mon Sep 17 00:00:00 2001 From: lsabor Date: Thu, 7 May 2026 10:38:36 -0700 Subject: [PATCH 2/7] unit tests --- questions/views.py | 39 +++- .../test_bulk_forecast_and_comment.py | 191 ++++++++++++++++++ 2 files changed, 219 insertions(+), 11 deletions(-) create mode 100644 tests/unit/test_questions/test_bulk_forecast_and_comment.py diff --git a/questions/views.py b/questions/views.py index e00811429c..a1c158c4b9 100644 --- a/questions/views.py +++ b/questions/views.py @@ -268,11 +268,19 @@ def legacy_question_api_view(request, pk: int): class BulkForecastAndCommentSerializer(serializers.Serializer): - user_id = serializers.IntegerField(required=True) + user_id = serializers.IntegerField(required=False, allow_null=True) + username = serializers.CharField(required=False, allow_null=True) is_staff_override = serializers.BooleanField(required=False, default=False) forecasts = ForecastWriteSerializer(many=True, required=False, default=list) comments = CommentWriteSerializer(many=True, required=False, default=list) + def validate(self, attrs): + if not attrs.get("user_id") and not attrs.get("username"): + raise serializers.ValidationError( + "Either user_id or username must be provided." + ) + return attrs + @api_view(["POST"]) @permission_classes([IsAuthenticated]) @@ -280,16 +288,17 @@ def bulk_forecast_and_comment_api_view(request): """ Submits forecasts and comments in a single atomic transaction. - Staff users may submit on behalf of any user by providing user_id and - flag `is_staff_override`. - Non-staff users may only submit as themselves (user_id must match - the authenticated user's ID). + Staff users may submit on behalf of any user by providing user_id or username + and flag `is_staff_override`. + Non-staff users may submit as themselves or as one of their bots (identified + by user_id or username). """ serializer = BulkForecastAndCommentSerializer(data=request.data) serializer.is_valid(raise_exception=True) data = serializer.validated_data user_id = data.get("user_id") + username = data.get("username") forecasts_data = data["forecasts"] comments_data = data.get("comments", []) is_staff_override = data.get("is_staff_override", False) @@ -297,15 +306,23 @@ def bulk_forecast_and_comment_api_view(request): request_user = request.user if is_staff_override and not request_user.is_staff: raise PermissionDenied("Non-staff users cannot use the is_staff_override flag.") - if not is_staff_override and user_id != request_user.id: - raise PermissionDenied( - "Non-staff users can only submit forecasts and comments as themselves." - ) - if is_staff_override: + if user_id: user = get_object_or_404(User, id=user_id) else: - user = request_user + user = get_object_or_404(User, username=username) + + if not is_staff_override: + is_self = user.id == request_user.id + is_own_bot = ( + user.is_bot + and user.bot_owner_id is not None + and user.bot_owner_id == request_user.id + ) + if not is_self and not is_own_bot: + raise PermissionDenied( + "Non-staff users can only submit forecasts and comments as themselves or their bots." + ) now = timezone.now() diff --git a/tests/unit/test_questions/test_bulk_forecast_and_comment.py b/tests/unit/test_questions/test_bulk_forecast_and_comment.py new file mode 100644 index 0000000000..c44301dbf8 --- /dev/null +++ b/tests/unit/test_questions/test_bulk_forecast_and_comment.py @@ -0,0 +1,191 @@ +import json +from datetime import datetime, timezone as dt_timezone + +import pytest +from rest_framework.reverse import reverse + +from questions.models import Forecast, Question +from tests.unit.test_posts.conftest import * # noqa +from tests.unit.test_posts.factories import factory_post +from tests.unit.test_questions.conftest import * # noqa +from tests.unit.test_questions.factories import create_question +from users.models import User + +URL = reverse("bulk-forecast-comment") + + +def forecast_payload(question, **kwargs): + return {"question": question.id, "probability_yes": 0.6, **kwargs} + + +@pytest.fixture() +def open_question(): + question = create_question( + question_type=Question.QuestionType.BINARY, + open_time=datetime(2000, 1, 1, tzinfo=dt_timezone.utc), + scheduled_close_time=datetime(3000, 1, 1, tzinfo=dt_timezone.utc), + ) + factory_post(question=question) + return question + + +@pytest.fixture() +def user_staff(db) -> User: + return User.objects.create( + email="staff@metaculus.com", username="staff_user", is_staff=True + ) + + +@pytest.fixture() +def user_bot(user1: User) -> User: + return User.objects.create( + email="bot@metaculus.com", + username="bot_user", + is_bot=True, + bot_owner=user1, + ) + + +@pytest.fixture() +def user_bot_no_owner() -> User: + return User.objects.create( + email="orphan_bot@metaculus.com", + username="orphan_bot", + is_bot=True, + bot_owner=None, + ) + + +class TestBulkForecastAndComment: + def test_requires_user_id_or_username(self, user1_client, open_question): + response = user1_client.post( + URL, + data=json.dumps({"forecasts": [forecast_payload(open_question)]}), + content_type="application/json", + ) + assert response.status_code == 400 + + def test_unauthenticated(self, anon_client, user1, open_question): + response = anon_client.post( + URL, + data=json.dumps({"user_id": user1.id, "forecasts": [forecast_payload(open_question)]}), + content_type="application/json", + ) + assert response.status_code == 403 + + def test_submit_as_self_by_user_id(self, user1, user1_client, open_question): + response = user1_client.post( + URL, + data=json.dumps({"user_id": user1.id, "forecasts": [forecast_payload(open_question)]}), + content_type="application/json", + ) + assert response.status_code == 201 + assert Forecast.objects.filter(question=open_question, author=user1).exists() + + def test_submit_as_self_by_username(self, user1, user1_client, open_question): + response = user1_client.post( + URL, + data=json.dumps({"username": user1.username, "forecasts": [forecast_payload(open_question)]}), + content_type="application/json", + ) + assert response.status_code == 201 + assert Forecast.objects.filter(question=open_question, author=user1).exists() + + def test_submit_as_other_user_denied(self, user1_client, user2, open_question): + response = user1_client.post( + URL, + data=json.dumps({"user_id": user2.id, "forecasts": [forecast_payload(open_question)]}), + content_type="application/json", + ) + assert response.status_code == 403 + + def test_submit_as_own_bot_by_user_id(self, user1_client, user_bot, open_question): + response = user1_client.post( + URL, + data=json.dumps({"user_id": user_bot.id, "forecasts": [forecast_payload(open_question)]}), + content_type="application/json", + ) + assert response.status_code == 201 + assert Forecast.objects.filter(question=open_question, author=user_bot).exists() + + def test_submit_as_own_bot_by_username(self, user1_client, user_bot, open_question): + response = user1_client.post( + URL, + data=json.dumps({"username": user_bot.username, "forecasts": [forecast_payload(open_question)]}), + content_type="application/json", + ) + assert response.status_code == 201 + assert Forecast.objects.filter(question=open_question, author=user_bot).exists() + + def test_submit_as_other_users_bot_denied(self, user2_client, user_bot, open_question): + # user_bot is owned by user1, not user2 + response = user2_client.post( + URL, + data=json.dumps({"user_id": user_bot.id, "forecasts": [forecast_payload(open_question)]}), + content_type="application/json", + ) + assert response.status_code == 403 + + def test_submit_as_bot_with_no_owner_denied(self, user1_client, user_bot_no_owner, open_question): + response = user1_client.post( + URL, + data=json.dumps({"user_id": user_bot_no_owner.id, "forecasts": [forecast_payload(open_question)]}), + content_type="application/json", + ) + assert response.status_code == 403 + + def test_staff_override_by_user_id(self, create_client_for_user, user_staff, user2, open_question): + staff_client = create_client_for_user(user_staff) + response = staff_client.post( + URL, + data=json.dumps({ + "user_id": user2.id, + "is_staff_override": True, + "forecasts": [forecast_payload(open_question)], + }), + content_type="application/json", + ) + assert response.status_code == 201 + assert Forecast.objects.filter(question=open_question, author=user2).exists() + + def test_staff_override_by_username(self, create_client_for_user, user_staff, user2, open_question): + staff_client = create_client_for_user(user_staff) + response = staff_client.post( + URL, + data=json.dumps({ + "username": user2.username, + "is_staff_override": True, + "forecasts": [forecast_payload(open_question)], + }), + content_type="application/json", + ) + assert response.status_code == 201 + assert Forecast.objects.filter(question=open_question, author=user2).exists() + + def test_non_staff_cannot_use_staff_override(self, user1_client, user1, user2, open_question): + response = user1_client.post( + URL, + data=json.dumps({ + "user_id": user2.id, + "is_staff_override": True, + "forecasts": [forecast_payload(open_question)], + }), + content_type="application/json", + ) + assert response.status_code == 403 + + def test_unknown_user_id_returns_404(self, user1_client, open_question): + response = user1_client.post( + URL, + data=json.dumps({"user_id": 999999, "forecasts": [forecast_payload(open_question)]}), + content_type="application/json", + ) + assert response.status_code == 404 + + def test_unknown_username_returns_404(self, user1_client, open_question): + response = user1_client.post( + URL, + data=json.dumps({"username": "does_not_exist", "forecasts": [forecast_payload(open_question)]}), + content_type="application/json", + ) + assert response.status_code == 404 From 001db864079ee49a33a58cf2bb13b358f7a59222 Mon Sep 17 00:00:00 2001 From: lsabor Date: Thu, 7 May 2026 10:47:40 -0700 Subject: [PATCH 3/7] ruff --- .../test_bulk_forecast_and_comment.py | 108 +++++++++++++----- 1 file changed, 78 insertions(+), 30 deletions(-) diff --git a/tests/unit/test_questions/test_bulk_forecast_and_comment.py b/tests/unit/test_questions/test_bulk_forecast_and_comment.py index c44301dbf8..509376bff5 100644 --- a/tests/unit/test_questions/test_bulk_forecast_and_comment.py +++ b/tests/unit/test_questions/test_bulk_forecast_and_comment.py @@ -68,7 +68,9 @@ def test_requires_user_id_or_username(self, user1_client, open_question): def test_unauthenticated(self, anon_client, user1, open_question): response = anon_client.post( URL, - data=json.dumps({"user_id": user1.id, "forecasts": [forecast_payload(open_question)]}), + data=json.dumps( + {"user_id": user1.id, "forecasts": [forecast_payload(open_question)]} + ), content_type="application/json", ) assert response.status_code == 403 @@ -76,7 +78,9 @@ def test_unauthenticated(self, anon_client, user1, open_question): def test_submit_as_self_by_user_id(self, user1, user1_client, open_question): response = user1_client.post( URL, - data=json.dumps({"user_id": user1.id, "forecasts": [forecast_payload(open_question)]}), + data=json.dumps( + {"user_id": user1.id, "forecasts": [forecast_payload(open_question)]} + ), content_type="application/json", ) assert response.status_code == 201 @@ -85,7 +89,12 @@ def test_submit_as_self_by_user_id(self, user1, user1_client, open_question): def test_submit_as_self_by_username(self, user1, user1_client, open_question): response = user1_client.post( URL, - data=json.dumps({"username": user1.username, "forecasts": [forecast_payload(open_question)]}), + data=json.dumps( + { + "username": user1.username, + "forecasts": [forecast_payload(open_question)], + } + ), content_type="application/json", ) assert response.status_code == 201 @@ -94,7 +103,9 @@ def test_submit_as_self_by_username(self, user1, user1_client, open_question): def test_submit_as_other_user_denied(self, user1_client, user2, open_question): response = user1_client.post( URL, - data=json.dumps({"user_id": user2.id, "forecasts": [forecast_payload(open_question)]}), + data=json.dumps( + {"user_id": user2.id, "forecasts": [forecast_payload(open_question)]} + ), content_type="application/json", ) assert response.status_code == 403 @@ -102,7 +113,9 @@ def test_submit_as_other_user_denied(self, user1_client, user2, open_question): def test_submit_as_own_bot_by_user_id(self, user1_client, user_bot, open_question): response = user1_client.post( URL, - data=json.dumps({"user_id": user_bot.id, "forecasts": [forecast_payload(open_question)]}), + data=json.dumps( + {"user_id": user_bot.id, "forecasts": [forecast_payload(open_question)]} + ), content_type="application/json", ) assert response.status_code == 201 @@ -111,65 +124,93 @@ def test_submit_as_own_bot_by_user_id(self, user1_client, user_bot, open_questio def test_submit_as_own_bot_by_username(self, user1_client, user_bot, open_question): response = user1_client.post( URL, - data=json.dumps({"username": user_bot.username, "forecasts": [forecast_payload(open_question)]}), + data=json.dumps( + { + "username": user_bot.username, + "forecasts": [forecast_payload(open_question)], + } + ), content_type="application/json", ) assert response.status_code == 201 assert Forecast.objects.filter(question=open_question, author=user_bot).exists() - def test_submit_as_other_users_bot_denied(self, user2_client, user_bot, open_question): + def test_submit_as_other_users_bot_denied( + self, user2_client, user_bot, open_question + ): # user_bot is owned by user1, not user2 response = user2_client.post( URL, - data=json.dumps({"user_id": user_bot.id, "forecasts": [forecast_payload(open_question)]}), + data=json.dumps( + {"user_id": user_bot.id, "forecasts": [forecast_payload(open_question)]} + ), content_type="application/json", ) assert response.status_code == 403 - def test_submit_as_bot_with_no_owner_denied(self, user1_client, user_bot_no_owner, open_question): + def test_submit_as_bot_with_no_owner_denied( + self, user1_client, user_bot_no_owner, open_question + ): response = user1_client.post( URL, - data=json.dumps({"user_id": user_bot_no_owner.id, "forecasts": [forecast_payload(open_question)]}), + data=json.dumps( + { + "user_id": user_bot_no_owner.id, + "forecasts": [forecast_payload(open_question)], + } + ), content_type="application/json", ) assert response.status_code == 403 - def test_staff_override_by_user_id(self, create_client_for_user, user_staff, user2, open_question): + def test_staff_override_by_user_id( + self, create_client_for_user, user_staff, user2, open_question + ): staff_client = create_client_for_user(user_staff) response = staff_client.post( URL, - data=json.dumps({ - "user_id": user2.id, - "is_staff_override": True, - "forecasts": [forecast_payload(open_question)], - }), + data=json.dumps( + { + "user_id": user2.id, + "is_staff_override": True, + "forecasts": [forecast_payload(open_question)], + } + ), content_type="application/json", ) assert response.status_code == 201 assert Forecast.objects.filter(question=open_question, author=user2).exists() - def test_staff_override_by_username(self, create_client_for_user, user_staff, user2, open_question): + def test_staff_override_by_username( + self, create_client_for_user, user_staff, user2, open_question + ): staff_client = create_client_for_user(user_staff) response = staff_client.post( URL, - data=json.dumps({ - "username": user2.username, - "is_staff_override": True, - "forecasts": [forecast_payload(open_question)], - }), + data=json.dumps( + { + "username": user2.username, + "is_staff_override": True, + "forecasts": [forecast_payload(open_question)], + } + ), content_type="application/json", ) assert response.status_code == 201 assert Forecast.objects.filter(question=open_question, author=user2).exists() - def test_non_staff_cannot_use_staff_override(self, user1_client, user1, user2, open_question): + def test_non_staff_cannot_use_staff_override( + self, user1_client, user1, user2, open_question + ): response = user1_client.post( URL, - data=json.dumps({ - "user_id": user2.id, - "is_staff_override": True, - "forecasts": [forecast_payload(open_question)], - }), + data=json.dumps( + { + "user_id": user2.id, + "is_staff_override": True, + "forecasts": [forecast_payload(open_question)], + } + ), content_type="application/json", ) assert response.status_code == 403 @@ -177,7 +218,9 @@ def test_non_staff_cannot_use_staff_override(self, user1_client, user1, user2, o def test_unknown_user_id_returns_404(self, user1_client, open_question): response = user1_client.post( URL, - data=json.dumps({"user_id": 999999, "forecasts": [forecast_payload(open_question)]}), + data=json.dumps( + {"user_id": 999999, "forecasts": [forecast_payload(open_question)]} + ), content_type="application/json", ) assert response.status_code == 404 @@ -185,7 +228,12 @@ def test_unknown_user_id_returns_404(self, user1_client, open_question): def test_unknown_username_returns_404(self, user1_client, open_question): response = user1_client.post( URL, - data=json.dumps({"username": "does_not_exist", "forecasts": [forecast_payload(open_question)]}), + data=json.dumps( + { + "username": "does_not_exist", + "forecasts": [forecast_payload(open_question)], + } + ), content_type="application/json", ) assert response.status_code == 404 From 565a72d35047080ac1fc188ad24c64ef04893d8b Mon Sep 17 00:00:00 2001 From: lsabor Date: Thu, 7 May 2026 10:51:21 -0700 Subject: [PATCH 4/7] protect against iterating user ids/usernames --- questions/views.py | 23 ++++++++++++------- .../test_bulk_forecast_and_comment.py | 23 ++++++++++++++++--- 2 files changed, 35 insertions(+), 11 deletions(-) diff --git a/questions/views.py b/questions/views.py index a1c158c4b9..5534e32c5b 100644 --- a/questions/views.py +++ b/questions/views.py @@ -307,21 +307,28 @@ def bulk_forecast_and_comment_api_view(request): if is_staff_override and not request_user.is_staff: raise PermissionDenied("Non-staff users cannot use the is_staff_override flag.") - if user_id: - user = get_object_or_404(User, id=user_id) + if is_staff_override: + if user_id: + user = get_object_or_404(User, id=user_id) + else: + user = get_object_or_404(User, username=username) else: - user = get_object_or_404(User, username=username) - - if not is_staff_override: - is_self = user.id == request_user.id + user = ( + User.objects.filter(id=user_id).first() + if user_id + else User.objects.filter(username=username).first() + ) + is_self = user is not None and user.id == request_user.id is_own_bot = ( - user.is_bot + user is not None + and user.is_bot and user.bot_owner_id is not None and user.bot_owner_id == request_user.id ) if not is_self and not is_own_bot: raise PermissionDenied( - "Non-staff users can only submit forecasts and comments as themselves or their bots." + "Non-staff users can only submit forecasts and comments as themselves " + "or their bots." ) now = timezone.now() diff --git a/tests/unit/test_questions/test_bulk_forecast_and_comment.py b/tests/unit/test_questions/test_bulk_forecast_and_comment.py index 509376bff5..9f469289e9 100644 --- a/tests/unit/test_questions/test_bulk_forecast_and_comment.py +++ b/tests/unit/test_questions/test_bulk_forecast_and_comment.py @@ -215,7 +215,7 @@ def test_non_staff_cannot_use_staff_override( ) assert response.status_code == 403 - def test_unknown_user_id_returns_404(self, user1_client, open_question): + def test_unknown_user_id_returns_403(self, user1_client, open_question): response = user1_client.post( URL, data=json.dumps( @@ -223,9 +223,9 @@ def test_unknown_user_id_returns_404(self, user1_client, open_question): ), content_type="application/json", ) - assert response.status_code == 404 + assert response.status_code == 403 - def test_unknown_username_returns_404(self, user1_client, open_question): + def test_unknown_username_returns_403(self, user1_client, open_question): response = user1_client.post( URL, data=json.dumps( @@ -236,4 +236,21 @@ def test_unknown_username_returns_404(self, user1_client, open_question): ), content_type="application/json", ) + assert response.status_code == 403 + + def test_staff_override_unknown_user_id_returns_404( + self, create_client_for_user, user_staff, open_question + ): + staff_client = create_client_for_user(user_staff) + response = staff_client.post( + URL, + data=json.dumps( + { + "user_id": 999999, + "is_staff_override": True, + "forecasts": [forecast_payload(open_question)], + } + ), + content_type="application/json", + ) assert response.status_code == 404 From 5a32474f67ac0ed1cab5a93736556eda5cb661c6 Mon Sep 17 00:00:00 2001 From: lsabor Date: Fri, 8 May 2026 09:28:04 -0700 Subject: [PATCH 5/7] send back all errors, switch to superuser override, block public commenting, allow pre-registering forecasts --- questions/services/forecasts.py | 2 +- questions/views.py | 87 ++++++++++--------- .../test_bulk_forecast_and_comment.py | 48 ++++++---- 3 files changed, 77 insertions(+), 60 deletions(-) diff --git a/questions/services/forecasts.py b/questions/services/forecasts.py index c5caa6abd2..d7a140f0ab 100644 --- a/questions/services/forecasts.py +++ b/questions/services/forecasts.py @@ -155,7 +155,7 @@ def after_forecast_actions(question: Question, user: User): ) # Run async tasks - from ..tasks import run_build_question_forecasts + from questions.tasks import run_build_question_forecasts run_build_question_forecasts.send(question.id) diff --git a/questions/views.py b/questions/views.py index 5534e32c5b..12f1d8d7e0 100644 --- a/questions/views.py +++ b/questions/views.py @@ -14,10 +14,8 @@ from comments.serializers.common import CommentWriteSerializer from comments.services.common import create_comment -from comments.services.key_factors.common import create_key_factors from posts.models import Post from posts.services.common import get_post_permission_for_user -from posts.tasks import run_on_post_forecast from posts.utils import get_post_slug from projects.permissions import ObjectPermission from users.models import User @@ -33,10 +31,7 @@ serialize_question, ) from .services.forecasts import ( - after_forecast_actions, - create_forecast, create_forecast_bulk, - update_forecast_notification, withdraw_forecast_bulk, ) from .services.lifecycle import resolve_question, unresolve_question @@ -288,9 +283,9 @@ def bulk_forecast_and_comment_api_view(request): """ Submits forecasts and comments in a single atomic transaction. - Staff users may submit on behalf of any user by providing user_id or username + Superusers may submit on behalf of any user by providing user_id or username and flag `is_staff_override`. - Non-staff users may submit as themselves or as one of their bots (identified + Non-superusers may submit as themselves or as one of their bots (identified by user_id or username). """ serializer = BulkForecastAndCommentSerializer(data=request.data) @@ -304,8 +299,8 @@ def bulk_forecast_and_comment_api_view(request): is_staff_override = data.get("is_staff_override", False) request_user = request.user - if is_staff_override and not request_user.is_staff: - raise PermissionDenied("Non-staff users cannot use the is_staff_override flag.") + if is_staff_override and not request_user.is_superuser: + raise PermissionDenied("Only superusers can use the is_staff_override flag.") if is_staff_override: if user_id: @@ -327,11 +322,12 @@ def bulk_forecast_and_comment_api_view(request): ) if not is_self and not is_own_bot: raise PermissionDenied( - "Non-staff users can only submit forecasts and comments as themselves " + "Non-superusers can only submit forecasts and comments as themselves " "or their bots." ) now = timezone.now() + errors = [] # Validate forecasts and resolve question IDs to Question objects questions_map = { @@ -342,48 +338,62 @@ def bulk_forecast_and_comment_api_view(request): } for forecast in forecasts_data: - question = questions_map.get(forecast["question"]) + question_id = forecast["question"] + question = questions_map.get(question_id) if not question: - raise ValidationError(f"Wrong question id {forecast['question']}") + errors.append(f"Question {question_id} does not exist.") + continue forecast["question"] = question - permission = get_post_permission_for_user(question.get_post(), user=user) - ObjectPermission.can_forecast(permission, raise_exception=True) + post: Post = question.post + permission = get_post_permission_for_user(post, user=user) + if not ObjectPermission.can_forecast(permission): + errors.append(f"Question {question.id}: forecasting not permitted.") + continue - if not question.open_time or question.open_time > now: - raise ValidationError( - f"Question {question.id} is not open for forecasting yet" - ) - if (question.scheduled_close_time < now) or ( + if ( + not post.curation_status != Post.CurationStatus.APPROVED + or not question.open_time + or not question.scheduled_close_time + ): + errors.append(f"Question {question.id} is not open for forecasting yet.") + elif (question.scheduled_close_time < now) or ( question.actual_close_time and question.actual_close_time < now ): - raise ValidationError( - f"Question {question.id} is already closed to forecasting" - ) + errors.append(f"Question {question.id} is already closed to forecasting.") - # Validate comment permissions - for comment in comments_data: + # Validate comments + for i, comment in enumerate(comments_data): on_post = comment["on_post"] + if not comment.get("is_private"): + errors.append( + f"Comment {i}: only private comments are allowed in bulk submissions." + ) + continue + if comment.get("key_factors"): + errors.append( + f"Comment {i}: key_factors are not supported in bulk submissions." + ) + continue parent = comment.get("parent") permission = get_post_permission_for_user( parent.on_post if parent else on_post, user=user ) - ObjectPermission.can_comment(permission, raise_exception=True) + if not ObjectPermission.can_comment(permission): + errors.append( + f"Comment {i}: commenting not permitted on post {on_post.id}." + ) - posts = set() - created_forecasts = [] + if errors: + raise ValidationError(errors) with transaction.atomic(): - for forecast_data in forecasts_data: - question = forecast_data.pop("question") - posts.add(question.get_post()) - forecast = create_forecast(question=question, user=user, **forecast_data) - created_forecasts.append((forecast, question)) + create_forecast_bulk(user=user, forecasts=forecasts_data) for comment_data in comments_data: on_post = comment_data["on_post"] included_forecast_flag = comment_data.pop("included_forecast", False) - key_factors = comment_data.pop("key_factors", None) + comment_data.pop("key_factors", None) included_forecast = ( on_post.question.user_forecasts.filter(author_id=user.id) @@ -393,18 +403,9 @@ def bulk_forecast_and_comment_api_view(request): else None ) - new_comment = create_comment( + create_comment( **comment_data, included_forecast=included_forecast, user=user ) - if key_factors: - create_key_factors(new_comment, key_factors) - - for forecast, question in created_forecasts: - update_forecast_notification(forecast=forecast, created=True) - after_forecast_actions(question, user) - - for post in posts: - run_on_post_forecast.send_with_options(args=(post.id,), delay=10_000) return Response({}, status=status.HTTP_201_CREATED) diff --git a/tests/unit/test_questions/test_bulk_forecast_and_comment.py b/tests/unit/test_questions/test_bulk_forecast_and_comment.py index 9f469289e9..624c7024ff 100644 --- a/tests/unit/test_questions/test_bulk_forecast_and_comment.py +++ b/tests/unit/test_questions/test_bulk_forecast_and_comment.py @@ -29,12 +29,6 @@ def open_question(): return question -@pytest.fixture() -def user_staff(db) -> User: - return User.objects.create( - email="staff@metaculus.com", username="staff_user", is_staff=True - ) - @pytest.fixture() def user_bot(user1: User) -> User: @@ -163,10 +157,10 @@ def test_submit_as_bot_with_no_owner_denied( ) assert response.status_code == 403 - def test_staff_override_by_user_id( - self, create_client_for_user, user_staff, user2, open_question + def test_superuser_override_by_user_id( + self, create_client_for_user, user_admin, user2, open_question ): - staff_client = create_client_for_user(user_staff) + staff_client = create_client_for_user(user_admin) response = staff_client.post( URL, data=json.dumps( @@ -181,10 +175,10 @@ def test_staff_override_by_user_id( assert response.status_code == 201 assert Forecast.objects.filter(question=open_question, author=user2).exists() - def test_staff_override_by_username( - self, create_client_for_user, user_staff, user2, open_question + def test_superuser_override_by_username( + self, create_client_for_user, user_admin, user2, open_question ): - staff_client = create_client_for_user(user_staff) + staff_client = create_client_for_user(user_admin) response = staff_client.post( URL, data=json.dumps( @@ -199,7 +193,7 @@ def test_staff_override_by_username( assert response.status_code == 201 assert Forecast.objects.filter(question=open_question, author=user2).exists() - def test_non_staff_cannot_use_staff_override( + def test_non_superuser_cannot_use_staff_override( self, user1_client, user1, user2, open_question ): response = user1_client.post( @@ -238,10 +232,10 @@ def test_unknown_username_returns_403(self, user1_client, open_question): ) assert response.status_code == 403 - def test_staff_override_unknown_user_id_returns_404( - self, create_client_for_user, user_staff, open_question + def test_superuser_override_unknown_user_id_returns_404( + self, create_client_for_user, user_admin, open_question ): - staff_client = create_client_for_user(user_staff) + staff_client = create_client_for_user(user_admin) response = staff_client.post( URL, data=json.dumps( @@ -254,3 +248,25 @@ def test_staff_override_unknown_user_id_returns_404( content_type="application/json", ) assert response.status_code == 404 + + def test_key_factors_in_bulk_comment_returns_400( + self, user1, user1_client, open_question + ): + response = user1_client.post( + URL, + data=json.dumps( + { + "user_id": user1.id, + "comments": [ + { + "on_post": open_question.get_post().id, + "text": "test comment", + "is_private": True, + "key_factors": [{"text": "some factor", "is_positive": True}], + } + ], + } + ), + content_type="application/json", + ) + assert response.status_code == 400 From e9da585d670e5e25f725b7d215ded15d0f3dbb07 Mon Sep 17 00:00:00 2001 From: lsabor Date: Fri, 8 May 2026 09:59:17 -0700 Subject: [PATCH 6/7] ruff --- tests/unit/test_questions/test_bulk_forecast_and_comment.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/unit/test_questions/test_bulk_forecast_and_comment.py b/tests/unit/test_questions/test_bulk_forecast_and_comment.py index 624c7024ff..cdfca0d856 100644 --- a/tests/unit/test_questions/test_bulk_forecast_and_comment.py +++ b/tests/unit/test_questions/test_bulk_forecast_and_comment.py @@ -29,7 +29,6 @@ def open_question(): return question - @pytest.fixture() def user_bot(user1: User) -> User: return User.objects.create( @@ -262,7 +261,9 @@ def test_key_factors_in_bulk_comment_returns_400( "on_post": open_question.get_post().id, "text": "test comment", "is_private": True, - "key_factors": [{"text": "some factor", "is_positive": True}], + "key_factors": [ + {"text": "some factor", "is_positive": True} + ], } ], } From 1552e67654885d936ada4e40f5227da36d50d413 Mon Sep 17 00:00:00 2001 From: lsabor Date: Fri, 8 May 2026 11:27:19 -0700 Subject: [PATCH 7/7] fix tests --- questions/views.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/questions/views.py b/questions/views.py index da4d94d904..7d3c987c06 100644 --- a/questions/views.py +++ b/questions/views.py @@ -366,7 +366,7 @@ def bulk_forecast_and_comment_api_view(request): continue if ( - not post.curation_status != Post.CurationStatus.APPROVED + post.curation_status != Post.CurationStatus.APPROVED or not question.open_time or not question.scheduled_close_time ):