diff --git a/core/users/tests/test_views.py b/core/users/tests/test_views.py index b28a18981..c55576c50 100755 --- a/core/users/tests/test_views.py +++ b/core/users/tests/test_views.py @@ -4,6 +4,7 @@ from django.contrib.auth.models import AnonymousUser from django.contrib.messages.middleware import MessageMiddleware from django.contrib.sessions.middleware import SessionMiddleware +from django.db.models import Prefetch from django.http import HttpRequest, HttpResponseRedirect from django.test import RequestFactory from django.urls import reverse @@ -11,7 +12,7 @@ from core.users.forms import UserAdminChangeForm from core.users.models import User from core.users.tests.factories import UserFactory -from core.users.views import UserRedirectView, UserUpdateView, user_detail_view +from core.users.views import CustomUserEditView, UserRedirectView, UserUpdateView, user_detail_view pytestmark = pytest.mark.django_db @@ -96,3 +97,36 @@ def test_not_authenticated(self, user: User, rf: RequestFactory): assert isinstance(response, HttpResponseRedirect) assert response.status_code == 302 assert response.url == f"{login_url}?next=/fake-url/" + + +class TestCustomUserEditView: + def test_get_queryset_prefetches_relations(self, user: User, rf: RequestFactory): + view = CustomUserEditView() + request = rf.get("/fake-url/") + request.user = user + view.request = request + + queryset = view.get_queryset() + + prefetch_lookups = queryset._prefetch_related_lookups + lookup_names = [] + for lookup in prefetch_lookups: + if isinstance(lookup, Prefetch): + lookup_names.append(lookup.prefetch_through) + else: + lookup_names.append(lookup) + + assert "journal" in lookup_names + assert "collection" in lookup_names + assert "groups" in lookup_names + assert "user_permissions" in lookup_names + + def test_get_queryset_returns_all_users(self, user: User, rf: RequestFactory): + view = CustomUserEditView() + request = rf.get("/fake-url/") + request.user = user + view.request = request + + queryset = view.get_queryset() + + assert user in queryset diff --git a/core/users/views.py b/core/users/views.py index b762b13d7..abe380218 100755 --- a/core/users/views.py +++ b/core/users/views.py @@ -7,11 +7,26 @@ from django.utils.translation import gettext_lazy as _ from django.views.generic import DetailView, RedirectView, UpdateView +from wagtail.users.views.users import EditView as WagtailUserEditView + from journal.models import Journal, SciELOJournal User = get_user_model() +class CustomUserEditView(WagtailUserEditView): + def get_queryset(self): + return User.objects.prefetch_related( + Prefetch( + "journal", + queryset=Journal.objects.select_related("official"), + ), + "collection", + "groups", + "user_permissions", + ) + + class UserDetailView(LoginRequiredMixin, DetailView): model = User slug_field = "username" diff --git a/core/users/viewsets.py b/core/users/viewsets.py index 8f265e1c8..5eedb702f 100644 --- a/core/users/viewsets.py +++ b/core/users/viewsets.py @@ -1,11 +1,13 @@ from wagtail.users.views.users import UserViewSet as WagtailUserViewSet from .forms import CustomUserCreationForm, CustomUserEditForm +from .views import CustomUserEditView class UserViewSet(WagtailUserViewSet): create_template_name = "wagtailusers/users/create.html" edit_template_name = "wagtailusers/users/edit.html" + edit_view_class = CustomUserEditView def get_form_class(self, for_update=False): if for_update: