diff --git a/posts/migrations/0030_backfill_conditional_categories.py b/posts/migrations/0030_backfill_conditional_categories.py new file mode 100644 index 0000000000..aaba7b5010 --- /dev/null +++ b/posts/migrations/0030_backfill_conditional_categories.py @@ -0,0 +1,59 @@ +from django.db import migrations + + +def backfill_conditional_categories(apps, schema_editor): + """ + Populate categories on conditional posts from their condition + and condition_child posts. + """ + Post = apps.get_model("posts", "Post") + + category_type = "category" + + conditional_posts = Post.objects.filter( + conditional__isnull=False, + ).select_related( + "conditional__condition", + "conditional__condition_child", + ) + + for post in conditional_posts: + conditional = post.conditional + categories_to_add = set() + + # Get categories from condition's post + condition_post_id = conditional.condition.post_id + if condition_post_id: + condition_post = Post.objects.get(pk=condition_post_id) + categories_to_add.update( + condition_post.projects.filter(type=category_type) + ) + + # Get categories from condition_child's post + child_post_id = conditional.condition_child.post_id + if child_post_id: + child_post = Post.objects.get(pk=child_post_id) + categories_to_add.update( + child_post.projects.filter(type=category_type) + ) + + if categories_to_add: + existing = set(post.projects.filter(type=category_type)) + new_categories = categories_to_add - existing + if new_categories: + post.projects.add(*new_categories) + + +class Migration(migrations.Migration): + + dependencies = [ + ("posts", "0029_remove_notebook_markdown_summary_and_more"), + ("projects", "0021_projectindex_project_index_projectindexpost"), + ] + + operations = [ + migrations.RunPython( + backfill_conditional_categories, + migrations.RunPython.noop, + ), + ] diff --git a/posts/services/common.py b/posts/services/common.py index 169c198772..a3b1ff3a46 100644 --- a/posts/services/common.py +++ b/posts/services/common.py @@ -51,6 +51,37 @@ logger = logging.getLogger(__name__) +def get_conditional_categories(conditional) -> list[Project]: + """Get the union of categories from a conditional's condition and condition_child posts.""" + category_type = Project.ProjectTypes.CATEGORY + categories = set() + + condition_post = conditional.condition.get_post() + if condition_post: + categories.update(condition_post.projects.filter(type=category_type)) + + condition_child_post = conditional.condition_child.get_post() + if condition_child_post: + categories.update(condition_child_post.projects.filter(type=category_type)) + + return list(categories) + + +def sync_conditional_categories(post: Post): + """Sync the categories of a conditional post with its parent/child question categories.""" + if not post.conditional_id: + return + + conditional_categories = get_conditional_categories(post.conditional) + if conditional_categories: + existing_categories = set( + post.projects.filter(type=Project.ProjectTypes.CATEGORY) + ) + new_categories = set(conditional_categories) - existing_categories + if new_categories: + post.projects.add(*new_categories) + + def add_categories(categories: list[int], post: Post): existing = [x.pk for x in post.projects.filter(type=Project.ProjectTypes.CATEGORY)] categories = [x for x in categories if x not in existing] @@ -170,6 +201,10 @@ def create_post( obj.projects.add(*categories) + # Propagate categories from condition and condition_child posts + if obj.conditional_id: + sync_conditional_categories(obj) + # Update global leaderboard tags update_global_leaderboard_tags(obj) @@ -285,6 +320,13 @@ def update_post( update_conditional(post.conditional, **conditional) + # Re-sync inherited categories for conditional posts after any update. + # This runs outside the `if conditional:` branch so that category-only + # edits (which replace the category set via post.projects.set above) + # don't silently drop categories inherited from the parent/child posts. + if post.conditional_id: + sync_conditional_categories(post) + if group_of_questions: if not post.group_of_questions: raise ValidationError("Original post does is not a group of questions") diff --git a/tests/unit/test_posts/test_services/test_conditional_categories.py b/tests/unit/test_posts/test_services/test_conditional_categories.py new file mode 100644 index 0000000000..ea2bb47edc --- /dev/null +++ b/tests/unit/test_posts/test_services/test_conditional_categories.py @@ -0,0 +1,156 @@ +from posts.services.common import ( + get_conditional_categories, + sync_conditional_categories, + update_post, +) +from projects.models import Project +from questions.models import Question +from tests.unit.test_posts.factories import factory_post +from tests.unit.test_projects.factories import factory_project +from tests.unit.test_questions.factories import create_conditional, create_question + + +class TestConditionalCategoryPropagation: + def _make_conditional_setup(self): + """Create condition and child questions with their own posts and categories.""" + cat_a = factory_project(type=Project.ProjectTypes.CATEGORY, name="Category A") + cat_b = factory_project(type=Project.ProjectTypes.CATEGORY, name="Category B") + cat_c = factory_project(type=Project.ProjectTypes.CATEGORY, name="Category C") + + condition = create_question( + question_type=Question.QuestionType.BINARY, + ) + condition_child = create_question( + question_type=Question.QuestionType.BINARY, + ) + + factory_post( + question=condition, + projects=[cat_a, cat_b], + ) + factory_post( + question=condition_child, + projects=[cat_b, cat_c], + ) + + conditional = create_conditional( + condition=condition, + condition_child=condition_child, + question_yes=create_question( + question_type=Question.QuestionType.BINARY, + title="If Yes", + ), + question_no=create_question( + question_type=Question.QuestionType.BINARY, + title="If No", + ), + ) + + return conditional, cat_a, cat_b, cat_c + + def test_get_conditional_categories(self): + conditional, cat_a, cat_b, cat_c = self._make_conditional_setup() + + categories = get_conditional_categories(conditional) + category_ids = {c.id for c in categories} + + assert cat_a.id in category_ids + assert cat_b.id in category_ids + assert cat_c.id in category_ids + + def test_sync_conditional_categories(self): + conditional, cat_a, cat_b, cat_c = self._make_conditional_setup() + + post = factory_post(conditional=conditional) + sync_conditional_categories(post) + + post_categories = set(post.projects.filter(type=Project.ProjectTypes.CATEGORY)) + assert cat_a in post_categories + assert cat_b in post_categories + assert cat_c in post_categories + + def test_sync_conditional_categories_no_duplicates(self): + conditional, cat_a, cat_b, cat_c = self._make_conditional_setup() + + post = factory_post(conditional=conditional, projects=[cat_a]) + sync_conditional_categories(post) + + cat_a_count = post.projects.filter( + type=Project.ProjectTypes.CATEGORY, id=cat_a.id + ).count() + assert cat_a_count == 1 + + def test_sync_noop_for_non_conditional(self): + question = create_question( + question_type=Question.QuestionType.BINARY, + ) + post = factory_post(question=question) + + # Should not raise + sync_conditional_categories(post) + + def test_update_post_categories_preserves_inherited(self, mocker): + """ + Regression: calling update_post(..., categories=[...]) on a conditional + post must not drop categories inherited from parent/child questions, + even when no `conditional` payload is passed. + """ + # Avoid downstream side effects that aren't relevant to this test + mocker.patch("posts.tasks.run_post_indexing.send") + + conditional, cat_a, cat_b, cat_c = self._make_conditional_setup() + post = factory_post(conditional=conditional) + sync_conditional_categories(post) + + # Sanity check: inherited categories are present before the update + initial_categories = set( + post.projects.filter(type=Project.ProjectTypes.CATEGORY) + ) + assert {cat_a, cat_b, cat_c} <= initial_categories + + # User edits the post with only cat_a in the categories payload + update_post(post, categories=[cat_a]) + + post_categories = set( + post.projects.filter(type=Project.ProjectTypes.CATEGORY) + ) + # Inherited categories must still be present + assert cat_a in post_categories + assert cat_b in post_categories + assert cat_c in post_categories + + # And cat_a should not have been duplicated + cat_a_count = post.projects.filter( + type=Project.ProjectTypes.CATEGORY, id=cat_a.id + ).count() + assert cat_a_count == 1 + + def test_get_conditional_categories_missing_posts(self): + """Categories should be collected even if one parent has no post.""" + condition = create_question( + question_type=Question.QuestionType.BINARY, + ) + condition_child = create_question( + question_type=Question.QuestionType.BINARY, + ) + + cat = factory_project(type=Project.ProjectTypes.CATEGORY, name="Only Cat") + factory_post(question=condition, projects=[cat]) + # condition_child has no post (post_id is None) + + conditional = create_conditional( + condition=condition, + condition_child=condition_child, + question_yes=create_question( + question_type=Question.QuestionType.BINARY, + title="If Yes", + ), + question_no=create_question( + question_type=Question.QuestionType.BINARY, + title="If No", + ), + ) + + categories = get_conditional_categories(conditional) + assert len(categories) == 1 + assert categories[0].id == cat.id