diff --git a/tests/config/__init__.py b/tests/config/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/config/test_db_routing.py b/tests/config/test_db_routing.py new file mode 100644 index 00000000000..1239f438764 --- /dev/null +++ b/tests/config/test_db_routing.py @@ -0,0 +1,193 @@ +import logging +from unittest.mock import MagicMock + +import pytest +from django.db.utils import OperationalError +from rest_framework.response import Response +from rest_framework.test import APIRequestFactory +from rest_framework.views import APIView + +from treeherder.config.db_routing import ( + READ_REPLICA_APP_ALLOW_LIST, + ReadReplicaMixin, + ReadReplicaRouter, + _state, +) + + +@pytest.fixture(autouse=True) +def clear_state(): + """Ensure the thread-local is clean before and after every test.""" + if hasattr(_state, "use_replica"): + del _state.use_replica + yield + if hasattr(_state, "use_replica"): + del _state.use_replica + + +def _model(app_label): + m = MagicMock() + m._meta.app_label = app_label + return m + + +def test_db_for_read_returns_replica_when_state_set_and_app_allowed(): + _state.use_replica = True + router = ReadReplicaRouter() + assert router.db_for_read(_model("perf")) == "read_replica" + assert router.db_for_read(_model("model")) == "read_replica" + + +def test_db_for_read_returns_none_when_state_unset(): + router = ReadReplicaRouter() + assert router.db_for_read(_model("perf")) is None + + +def test_db_for_read_returns_none_for_unlisted_app_even_when_state_set(): + _state.use_replica = True + router = ReadReplicaRouter() + # Auth and sessions must always read from primary. + assert router.db_for_read(_model("auth")) is None + assert router.db_for_read(_model("sessions")) is None + assert router.db_for_read(_model("etl")) is None + + +def test_db_for_write_always_returns_none(): + _state.use_replica = True + router = ReadReplicaRouter() + assert router.db_for_write(_model("perf")) is None + assert router.db_for_write(_model("model")) is None + + +def test_allow_relation_true_between_default_and_replica(): + router = ReadReplicaRouter() + a, b = MagicMock(), MagicMock() + a._state.db = "default" + b._state.db = "read_replica" + assert router.allow_relation(a, b) is True + b._state.db = "default" + a._state.db = "read_replica" + assert router.allow_relation(a, b) is True + + +def test_allow_migrate_blocks_replica(): + router = ReadReplicaRouter() + assert router.allow_migrate("read_replica", "perf") is False + assert router.allow_migrate("default", "perf") is None + assert router.allow_migrate("default", "auth") is None + + +def test_allow_list_is_perf_and_model(): + assert READ_REPLICA_APP_ALLOW_LIST == {"perf", "model"} + + +class _RecordingView(ReadReplicaMixin, APIView): + """Test view that records the thread-local state at the moment it ran.""" + + # Disable auth/permission so CSRF does not interfere with mixin tests. + authentication_classes = [] + permission_classes = [] + + raise_on_call = None # set per-test + call_count = 0 + saw_use_replica = [] + + def get(self, request): + type(self).call_count += 1 + type(self).saw_use_replica.append(getattr(_state, "use_replica", False)) + if type(self).raise_on_call and type(self).call_count <= type(self).raise_on_call: + raise OperationalError("simulated replica failure") + return Response({"ok": True}) + + def post(self, request): + type(self).call_count += 1 + type(self).saw_use_replica.append(getattr(_state, "use_replica", False)) + return Response({"ok": True}) + + +@pytest.fixture +def reset_view(): + _RecordingView.raise_on_call = 0 + _RecordingView.call_count = 0 + _RecordingView.saw_use_replica = [] + yield + + +def test_mixin_flips_state_on_get(reset_view): + factory = APIRequestFactory() + view = _RecordingView.as_view() + response = view(factory.get("/x")) + assert response.status_code == 200 + assert _RecordingView.saw_use_replica == [True] + assert not hasattr(_state, "use_replica") # cleared after dispatch + + +def test_mixin_does_not_flip_state_on_post(reset_view): + factory = APIRequestFactory() + view = _RecordingView.as_view() + response = view(factory.post("/x", data={})) + assert response.status_code == 200 + assert _RecordingView.saw_use_replica == [False] + + +def test_mixin_clears_state_when_view_raises(reset_view): + _RecordingView.raise_on_call = 99 # always raise + factory = APIRequestFactory() + view = _RecordingView.as_view() + # The second dispatch (retry) also raises, so the OperationalError + # propagates. The important thing is that _state is cleared. + with pytest.raises(OperationalError): + view(factory.get("/x")) + assert not hasattr(_state, "use_replica") + + +def test_mixin_retries_once_on_operational_error(reset_view, caplog): + _RecordingView.raise_on_call = 1 # fail first call, succeed second + factory = APIRequestFactory() + view = _RecordingView.as_view() + with caplog.at_level(logging.WARNING): + response = view(factory.get("/x")) + assert response.status_code == 200 + assert _RecordingView.call_count == 2 + # First attempt had the flag, retry did not. + assert _RecordingView.saw_use_replica == [True, False] + # Fallback log emitted exactly once. + assert any("db_routing_fallback" in rec.message for rec in caplog.records) + + +def test_mixin_retries_only_once(reset_view): + _RecordingView.raise_on_call = 2 # fail twice + factory = APIRequestFactory() + view = _RecordingView.as_view() + with pytest.raises(OperationalError): + view(factory.get("/x")) + assert _RecordingView.call_count == 2 # original + 1 retry + + +def test_mixin_is_noop_when_replica_alias_not_configured(reset_view, caplog): + """When the kill switch is off (no replica alias), the mixin must not + flip the thread-local and must not emit fallback logs on primary errors. + """ + + from django.db import connections + + _RecordingView.raise_on_call = 99 # always raise — simulates primary error + factory = APIRequestFactory() + view = _RecordingView.as_view() + + # Temporarily drop the replica alias from the connection handler so the + # mixin sees the kill switch as off. + saved = connections.databases.pop("read_replica") + try: + with caplog.at_level(logging.WARNING): + with pytest.raises(OperationalError): + view(factory.get("/x")) + finally: + connections.databases["read_replica"] = saved + + # Mixin did not flip the thread-local (no replica to route to). + assert _RecordingView.saw_use_replica == [False] + # Mixin did not retry — one call only, not two. + assert _RecordingView.call_count == 1 + # No misleading fallback log. + assert not any("db_routing_fallback" in rec.message for rec in caplog.records) diff --git a/tests/settings.py b/tests/settings.py index 28101b45f19..cda88f6e77e 100644 --- a/tests/settings.py +++ b/tests/settings.py @@ -1,3 +1,4 @@ +import copy import os from treeherder.config.settings import * # noqa: F403 @@ -13,6 +14,13 @@ DATABASES["default"] = environ.Env.db_url_config(TEST_DATABASE_URL) # noqa: F405 DATABASES["default"]["TEST"] = {"NAME": "test_treeherder"} # noqa: F405 + +# Register a read_replica alias pointed at the same test DSN so integration +# tests can capture per-alias queries via CaptureQueriesContext. The router +# is installed unconditionally during tests so its behavior is exercised. +DATABASES["read_replica"] = copy.deepcopy(DATABASES["default"]) # noqa: F405 +DATABASE_ROUTERS = ["treeherder.config.db_routing.ReadReplicaRouter"] + KEY_PREFIX = "test" TREEHERDER_TEST_REPOSITORY_NAME = "mozilla-central" diff --git a/tests/webapp/api/test_perfcompare_api.py b/tests/webapp/api/test_perfcompare_api.py index 98b36d2b453..83387c323d8 100644 --- a/tests/webapp/api/test_perfcompare_api.py +++ b/tests/webapp/api/test_perfcompare_api.py @@ -8,7 +8,12 @@ from treeherder.perf.models import PerformanceDatum, PerformanceDatumReplicate from treeherder.webapp.api import perfcompare_utils -pytestmark = pytest.mark.perf +pytestmark = [ + pytest.mark.perf, + # transaction=True is required because the routed viewset reads through the + # separate ``read_replica`` connection, which only sees committed data. + pytest.mark.django_db(transaction=True, databases=["default", "read_replica"]), +] NOW = datetime.datetime.now() ONE_DAY_AGO = NOW - datetime.timedelta(days=1) diff --git a/tests/webapp/api/test_performance_bug_template_api.py b/tests/webapp/api/test_performance_bug_template_api.py index aa5d34ae701..5c03dd4641a 100644 --- a/tests/webapp/api/test_performance_bug_template_api.py +++ b/tests/webapp/api/test_performance_bug_template_api.py @@ -3,7 +3,12 @@ from treeherder.perf.models import PerformanceBugTemplate, PerformanceFramework -pytestmark = pytest.mark.perf +pytestmark = [ + pytest.mark.perf, + # transaction=True is required because the routed viewset reads through the + # separate ``read_replica`` connection, which only sees committed data. + pytest.mark.django_db(transaction=True, databases=["default", "read_replica"]), +] def test_perf_bug_template_api(client, test_perf_framework): diff --git a/tests/webapp/api/test_performance_data_api.py b/tests/webapp/api/test_performance_data_api.py index 90b13776147..0d707026480 100644 --- a/tests/webapp/api/test_performance_data_api.py +++ b/tests/webapp/api/test_performance_data_api.py @@ -15,7 +15,12 @@ ) from treeherder.webapp.api.performance_data import PerformanceSummary -pytestmark = pytest.mark.perf +pytestmark = [ + pytest.mark.perf, + # transaction=True is required because the routed viewset reads through the + # separate ``read_replica`` connection, which only sees committed data. + pytest.mark.django_db(transaction=True, databases=["default", "read_replica"]), +] NOW = datetime.datetime.now() ONE_DAY_AGO = NOW - datetime.timedelta(days=1) diff --git a/tests/webapp/api/test_read_replica_routing.py b/tests/webapp/api/test_read_replica_routing.py new file mode 100644 index 00000000000..039ffe38c16 --- /dev/null +++ b/tests/webapp/api/test_read_replica_routing.py @@ -0,0 +1,65 @@ +"""End-to-end test that ReadReplicaMixin routes a real GET to the replica alias. + +Both DB aliases point at the same physical Postgres in test settings, but they +are separate connections/sessions. We therefore use ``transaction=True`` so the +``default`` connection commits its fixture writes and the ``read_replica`` +connection (separate session) can see them. +""" + +import pytest +from django.db import connections +from django.test.utils import CaptureQueriesContext +from django.urls import reverse + +from treeherder.perf.models import PerformanceFramework + +pytestmark = pytest.mark.django_db(transaction=True, databases=["default", "read_replica"]) + + +def test_performance_framework_list_hits_replica(client): + PerformanceFramework.objects.create(name="talos", enabled=True) + PerformanceFramework.objects.create(name="awsy", enabled=True) + + with CaptureQueriesContext(connections["read_replica"]) as replica_ctx: + response = client.get(reverse("performance-frameworks-list")) + + assert response.status_code == 200 + assert len(response.json()) == 2 + assert len(replica_ctx.captured_queries) > 0, ( + "Expected PerformanceFramework reads to be routed to the replica alias" + ) + + +def test_performance_alert_summary_list_stays_on_default(client, test_perf_alert_summary): + """A viewset that is *not* opted in must not route to the replica.""" + with CaptureQueriesContext(connections["read_replica"]) as replica_ctx: + response = client.get(reverse("performance-alert-summaries-list")) + + assert response.status_code == 200 + assert len(replica_ctx.captured_queries) == 0, ( + "PerformanceAlertSummaryViewSet must remain on primary" + ) + + +def test_performance_signatures_list_hits_replica(client, test_perf_signature): + with CaptureQueriesContext(connections["read_replica"]) as replica_ctx: + response = client.get( + reverse( + "performance-signatures-list", + kwargs={"project": test_perf_signature.repository.name}, + ) + ) + + assert response.status_code == 200 + assert len(replica_ctx.captured_queries) > 0 + + +def test_performance_summary_hits_replica(client, test_perf_signature): + with CaptureQueriesContext(connections["read_replica"]) as replica_ctx: + response = client.get( + reverse("performance-summary") + + f"?repository={test_perf_signature.repository.name}&interval=86400" + ) + + assert response.status_code == 200 + assert len(replica_ctx.captured_queries) > 0 diff --git a/treeherder/config/db_routing.py b/treeherder/config/db_routing.py new file mode 100644 index 00000000000..94f2e7ab83c --- /dev/null +++ b/treeherder/config/db_routing.py @@ -0,0 +1,110 @@ +"""Read-only replica routing. + +A thread-local flag, set by :class:`ReadReplicaMixin`, opts a single request +into reading from the ``read_replica`` database alias. The router only routes +models whose Django app label is in :data:`READ_REPLICA_APP_ALLOW_LIST`. + +Design: .claude/plans/READ_REPLICA_DESIGN.md +""" + +from __future__ import annotations + +import logging +import threading + +from django.db import connections +from django.db.utils import InterfaceError, OperationalError + +# Apps whose models are eligible for replica reads when the thread-local is +# set. Allow-list (not deny-list) so opting new code in is an explicit choice. +READ_REPLICA_APP_ALLOW_LIST: frozenset[str] = frozenset({"perf", "model"}) + +_state = threading.local() + +logger = logging.getLogger(__name__) + +_SAFE_METHODS = frozenset({"GET", "HEAD", "OPTIONS"}) + + +class ReadReplicaRouter: + """Route reads to ``read_replica`` when the thread-local flag is set.""" + + def db_for_read(self, model, **hints): + if not getattr(_state, "use_replica", False): + return None + if model._meta.app_label not in READ_REPLICA_APP_ALLOW_LIST: + return None + return "read_replica" + + def db_for_write(self, model, **hints): + # Writes never go to the replica. + return None + + def allow_relation(self, obj1, obj2, **hints): + dbs = {obj1._state.db, obj2._state.db} + if dbs <= {"default", "read_replica"}: + return True + return None + + def allow_migrate(self, db, app_label, model_name=None, **hints): + # Never run migrations against the replica. + if db == "read_replica": + return False + return None + + +class ReadReplicaMixin: + """Opt a view's safe HTTP methods into reading from ``read_replica``. + + For GET/HEAD/OPTIONS, the thread-local flag is set before ``dispatch`` + runs and cleared in a ``finally``. On :class:`OperationalError` or + :class:`InterfaceError`, the request is retried once against the primary; + further failures propagate. + """ + + def handle_exception(self, exc): + # DRF's APIView.dispatch() wraps handler calls in a try/except that + # routes exceptions through handle_exception, which would swallow + # OperationalError/InterfaceError as a 500 response. Re-raise them + # here so the outer try/except in our dispatch() can trigger the + # fallback retry. + if isinstance(exc, OperationalError | InterfaceError): + raise exc + return super().handle_exception(exc) + + def dispatch(self, request, *args, **kwargs): + if request.method not in _SAFE_METHODS: + return super().dispatch(request, *args, **kwargs) + + # No-op when the replica alias is not configured (kill switch off). + # Without this guard the fallback log line below would fire on any + # primary failure, misleading on-call into investigating the replica. + if "read_replica" not in connections.databases: + return super().dispatch(request, *args, **kwargs) + + _state.use_replica = True + logger.debug( + "db_routing routed_to=read_replica path=%s method=%s", + request.path, + request.method, + ) + try: + try: + return super().dispatch(request, *args, **kwargs) + except (OperationalError, InterfaceError) as exc: + logger.warning( + "db_routing_fallback path=%s method=%s exception_type=%s", + request.path, + request.method, + type(exc).__name__, + ) + # Drop the (likely bad) replica connection before retrying. + try: + connections["read_replica"].close() + except Exception: + pass + _state.use_replica = False + return super().dispatch(request, *args, **kwargs) + finally: + if hasattr(_state, "use_replica"): + del _state.use_replica diff --git a/treeherder/config/settings.py b/treeherder/config/settings.py index 3f4de36fdaf..168704071dc 100644 --- a/treeherder/config/settings.py +++ b/treeherder/config/settings.py @@ -138,6 +138,16 @@ if UPSTREAM_DATABASE_URL: DATABASES["upstream"] = env.db_url_config(UPSTREAM_DATABASE_URL) +# Optional read-only replica for offloading specific GET endpoints from the +# primary. Both env vars must be set for routing to activate. See +# treeherder/config/db_routing.py and .claude/plans/READ_REPLICA_DESIGN.md. +READ_REPLICA_DATABASE_URL = env("READ_REPLICA_DATABASE_URL", default=None) +READ_REPLICA_ENABLED = env.bool("READ_REPLICA_ENABLED", default=False) + +if READ_REPLICA_DATABASE_URL and READ_REPLICA_ENABLED: + DATABASES["read_replica"] = env.db_url_config(READ_REPLICA_DATABASE_URL) + DATABASE_ROUTERS = ["treeherder.config.db_routing.ReadReplicaRouter"] + # We're intentionally not using django-environ's query string options feature, # since it hides configuration outside of the repository, plus could lead to # drift between environments. diff --git a/treeherder/webapp/api/performance_data.py b/treeherder/webapp/api/performance_data.py index be7b1038535..ac23d1249af 100644 --- a/treeherder/webapp/api/performance_data.py +++ b/treeherder/webapp/api/performance_data.py @@ -18,6 +18,7 @@ from rest_framework.response import Response from rest_framework.status import HTTP_400_BAD_REQUEST +from treeherder.config.db_routing import ReadReplicaMixin from treeherder.etl.common import to_timestamp from treeherder.model import models from treeherder.perf import stats @@ -60,7 +61,7 @@ logger = logging.getLogger(__name__) -class PerformanceSignatureViewSet(viewsets.ViewSet): +class PerformanceSignatureViewSet(ReadReplicaMixin, viewsets.ViewSet): def list(self, request, project): repository = models.Repository.objects.get(name=project) @@ -194,7 +195,7 @@ def list(self, request, project): return Response(signature_map) -class PerformancePlatformViewSet(viewsets.ViewSet): +class PerformancePlatformViewSet(ReadReplicaMixin, viewsets.ViewSet): """ All platforms for a particular branch that have performance data """ @@ -216,14 +217,14 @@ def list(self, request, project): return Response(signature_data.values_list("platform__platform", flat=True).distinct()) -class PerformanceFrameworkViewSet(viewsets.ReadOnlyModelViewSet): +class PerformanceFrameworkViewSet(ReadReplicaMixin, viewsets.ReadOnlyModelViewSet): queryset = PerformanceFramework.objects.filter(enabled=True) serializer_class = PerformanceFrameworkSerializer filter_backends = [filters.OrderingFilter] ordering = "id" -class PerformanceJobViewSet(viewsets.ReadOnlyModelViewSet): +class PerformanceJobViewSet(ReadReplicaMixin, viewsets.ReadOnlyModelViewSet): def list(self, request, project): repository = models.Repository.objects.get(name=project) # Expect exactly one job_id in query params @@ -283,7 +284,7 @@ def list(self, request, project): return Response(results) -class PerformanceDatumViewSet(viewsets.ViewSet): +class PerformanceDatumViewSet(ReadReplicaMixin, viewsets.ViewSet): """ This view serves performance test result data """ @@ -792,21 +793,21 @@ def nudge(self, alert, new_push_id, new_prev_push_id): raise exceptions.APIException("Nudging has been disabled", 400) -class PerformanceBugTemplateViewSet(viewsets.ReadOnlyModelViewSet): +class PerformanceBugTemplateViewSet(ReadReplicaMixin, viewsets.ReadOnlyModelViewSet): queryset = PerformanceBugTemplate.objects.all() serializer_class = PerformanceBugTemplateSerializer filter_backends = (django_filters.rest_framework.DjangoFilterBackend, filters.OrderingFilter) filterset_fields = ["framework"] -class PerformanceIssueTrackerViewSet(viewsets.ReadOnlyModelViewSet): +class PerformanceIssueTrackerViewSet(ReadReplicaMixin, viewsets.ReadOnlyModelViewSet): queryset = IssueTracker.objects.all() serializer_class = IssueTrackerSerializer filter_backends = [filters.OrderingFilter] ordering = "id" -class PerformanceSummary(generics.ListAPIView): +class PerformanceSummary(ReadReplicaMixin, generics.ListAPIView): serializer_class = PerformanceSummarySerializer queryset = None @@ -1054,7 +1055,7 @@ def _filter_out_retriggers(serialized_data): return serialized_data -class PerformanceAlertSummaryTasks(generics.ListAPIView): +class PerformanceAlertSummaryTasks(ReadReplicaMixin, generics.ListAPIView): serializer_class = PerformanceAlertSummaryTasksSerializer queryset = None @@ -1080,7 +1081,7 @@ def list(self, request): return Response(data=serializer.data) -class PerfCompareResults(generics.ListAPIView): +class PerfCompareResults(ReadReplicaMixin, generics.ListAPIView): serializer_class = PerfCompareResultsSerializer queryset = None @@ -1976,7 +1977,7 @@ def _process_stats( return stats_data -class TestSuiteHealthViewSet(viewsets.ViewSet): +class TestSuiteHealthViewSet(ReadReplicaMixin, viewsets.ViewSet): def list(self, request): query_params = TestSuiteHealthParamsSerializer(data=request.query_params) if not query_params.is_valid():