From ad5a5a2344cacc74c14d263c7ba6d83cdf00e1c5 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Fri, 24 Apr 2026 04:56:43 -0700 Subject: [PATCH 1/7] Move more globals to RegistrationContext contextual management of Config, decorated pages, and memo components --- .../src/reflex_base/components/component.py | 9 +- .../src/reflex_base/components/dynamic.py | 22 +-- .../reflex-base/src/reflex_base/config.py | 89 +++++---- .../reflex-base/src/reflex_base/registry.py | 48 ++++- pyi_hashes.json | 2 +- reflex/app.py | 15 +- reflex/compiler/compiler.py | 5 +- reflex/experimental/memo.py | 16 +- reflex/page.py | 9 +- reflex/reflex.py | 8 +- reflex/testing.py | 13 +- reflex/utils/prerequisites.py | 5 +- reflex/utils/templates.py | 4 +- tests/units/components/test_component.py | 4 +- tests/units/conftest.py | 20 -- tests/units/experimental/test_memo.py | 66 ++++--- tests/units/reflex_base/test_registry.py | 178 ++++++++++++++++++ tests/units/test_app.py | 6 + tests/units/test_page.py | 39 ++-- tests/units/test_state.py | 10 +- tests/units/test_testing.py | 54 ++++-- 21 files changed, 439 insertions(+), 183 deletions(-) diff --git a/packages/reflex-base/src/reflex_base/components/component.py b/packages/reflex-base/src/reflex_base/components/component.py index 8f47f447c7d..f88d488206f 100644 --- a/packages/reflex-base/src/reflex_base/components/component.py +++ b/packages/reflex-base/src/reflex_base/components/component.py @@ -2236,9 +2236,6 @@ def _get_all_app_wrap_components( return self.get_component()._get_all_app_wrap_components(ignore_ids=ignore_ids) -CUSTOM_COMPONENTS: dict[str, CustomComponent] = {} - - def _register_custom_component( component_fn: Callable[..., Component], ): @@ -2253,6 +2250,8 @@ def _register_custom_component( Raises: TypeError: If the tag name cannot be determined. """ + from reflex_base.registry import RegistrationContext + dummy_props = { prop: ( Var( @@ -2273,7 +2272,9 @@ def _register_custom_component( if dummy_component.tag is None: msg = f"Could not determine the tag name for {component_fn!r}" raise TypeError(msg) - CUSTOM_COMPONENTS[dummy_component.tag] = dummy_component + RegistrationContext.ensure_context().custom_components[dummy_component.tag] = ( + dummy_component + ) return dummy_component diff --git a/packages/reflex-base/src/reflex_base/components/dynamic.py b/packages/reflex-base/src/reflex_base/components/dynamic.py index 6c2100a40e8..fa709d5db47 100644 --- a/packages/reflex-base/src/reflex_base/components/dynamic.py +++ b/packages/reflex-base/src/reflex_base/components/dynamic.py @@ -26,16 +26,6 @@ def get_cdn_url(lib: str) -> str: return f"https://cdn.jsdelivr.net/npm/{lib}" + "/+esm" -bundled_libraries = [ - "react", - "@radix-ui/themes", - "@emotion/react", - f"$/{constants.Dirs.UTILS}/context", - f"$/{constants.Dirs.UTILS}/state", - f"$/{constants.Dirs.UTILS}/components", -] - - def bundle_library(component: Union["Component", str]): """Bundle a library with the component. @@ -45,13 +35,16 @@ def bundle_library(component: Union["Component", str]): Raises: DynamicComponentMissingLibraryError: Raised when a dynamic component is missing a library. """ + from reflex_base.registry import RegistrationContext + + bundled = RegistrationContext.ensure_context().bundled_libraries if isinstance(component, str): - bundled_libraries.append(component) + bundled.append(component) return if component.library is None: msg = "Component must have a library to bundle." raise DynamicComponentMissingLibraryError(msg) - bundled_libraries.append(format_library_name(component.library)) + bundled.append(format_library_name(component.library)) def load_dynamic_serializer(): @@ -73,6 +66,9 @@ def make_component(component: Component) -> str: from reflex_components_core.base.bare import Bare from reflex.compiler import compiler, templates, utils + from reflex_base.registry import RegistrationContext + + libs_in_window = RegistrationContext.ensure_context().bundled_libraries component = Bare.create(Var.create(component)) @@ -93,8 +89,6 @@ def make_component(component: Component) -> str: ) ] = None - libs_in_window = bundled_libraries - component_imports = component._get_all_imports() compiler._apply_common_imports(component_imports) diff --git a/packages/reflex-base/src/reflex_base/config.py b/packages/reflex-base/src/reflex_base/config.py index 9a522628e6a..a7923afc76a 100644 --- a/packages/reflex-base/src/reflex_base/config.py +++ b/packages/reflex-base/src/reflex_base/config.py @@ -4,7 +4,6 @@ import importlib import os import sys -import threading import urllib.parse from collections.abc import Sequence from importlib.util import find_spec @@ -639,7 +638,7 @@ def _set_persistent(self, **kwargs): def _get_config() -> Config: - """Get the app config. + """Import rxconfig.py fresh and return its config object. Returns: The app config. @@ -651,48 +650,68 @@ def _get_config() -> Config: # we need this condition to ensure that a ModuleNotFound error is not thrown when # running unit/integration tests or during `reflex init`. return Config(app_name="", _skip_plugins_checks=True) + # Never cache rxconfig — each load goes to disk so different + # RegistrationContexts can hold independent Config instances. + sys.modules.pop(constants.Config.MODULE, None) rxconfig = importlib.import_module(constants.Config.MODULE) return rxconfig.config -# Protect sys.path from concurrent modification -_config_lock = threading.RLock() +def _load_config() -> Config: + """Load the config from rxconfig.py with cwd on sys.path. + Returns: + The app config. + """ + orig_sys_path = sys.path.copy() + sys.path.clear() + sys.path.append(str(Path.cwd())) + try: + return _get_config() + except Exception: + sys.path.extend(orig_sys_path) + return _get_config() + finally: + extra_paths = [ + p for p in sys.path if p not in orig_sys_path and p != str(Path.cwd()) + ] + sys.path.clear() + sys.path.extend(extra_paths) + sys.path.extend(orig_sys_path) -def get_config(reload: bool = False) -> Config: - """Get the app config. - Args: - reload: Re-import the rxconfig module from disk +def get_config() -> Config: + """Get the app config from the current RegistrationContext. + + The config is loaded from rxconfig.py once per RegistrationContext and + cached on the context thereafter. If no context is currently attached, + one is created and attached automatically. Returns: The app config. """ - cached_rxconfig = sys.modules.get(constants.Config.MODULE, None) - if cached_rxconfig is not None: - if reload: - # Remove any cached module when `reload` is requested. - del sys.modules[constants.Config.MODULE] - else: - return cached_rxconfig.config + from reflex_base.registry import RegistrationContext - with _config_lock: - orig_sys_path = sys.path.copy() - sys.path.clear() - sys.path.append(str(Path.cwd())) - try: - # Try to import the module with only the current directory in the path. - return _get_config() - except Exception: - # If the module import fails, try to import with the original sys.path. - sys.path.extend(orig_sys_path) - return _get_config() - finally: - # Find any entries added to sys.path by rxconfig.py itself. - extra_paths = [ - p for p in sys.path if p not in orig_sys_path and p != str(Path.cwd()) - ] - # Restore the original sys.path. - sys.path.clear() - sys.path.extend(extra_paths) - sys.path.extend(orig_sys_path) + ctx = RegistrationContext.ensure_context() + config = ctx.config + if config is None: + config = _load_config() + ctx._set_config(config) + return config + + +def reload_config() -> Config: + """Force a fresh load of the config into the current RegistrationContext. + + Clears any cached config on the current context and reloads rxconfig.py + from disk. + + Returns: + The freshly loaded app config. + """ + from reflex_base.registry import RegistrationContext + + ctx = RegistrationContext.ensure_context() + config = _load_config() + ctx._set_config(config) + return config diff --git a/packages/reflex-base/src/reflex_base/registry.py b/packages/reflex-base/src/reflex_base/registry.py index 8caa1d2b2c3..fb7b4a089d2 100644 --- a/packages/reflex-base/src/reflex_base/registry.py +++ b/packages/reflex-base/src/reflex_base/registry.py @@ -3,7 +3,7 @@ from __future__ import annotations import dataclasses -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from typing_extensions import Self @@ -11,11 +11,32 @@ from reflex_base.utils.exceptions import StateValueError if TYPE_CHECKING: + from collections.abc import Callable + from reflex.state import BaseState from reflex_base.components.component import StatefulComponent + from reflex_base.config import Config from reflex_base.event import EventHandler +def _default_bundled_libraries() -> list[str]: + """Return the initial set of bundled libraries for a new context. + + Returns: + The default list of libraries bundled into every app build. + """ + from reflex_base import constants + + return [ + "react", + "@radix-ui/themes", + "@emotion/react", + f"$/{constants.Dirs.UTILS}/context", + f"$/{constants.Dirs.UTILS}/state", + f"$/{constants.Dirs.UTILS}/components", + ] + + @dataclasses.dataclass(frozen=True, kw_only=True, slots=True) class RegisteredEventHandler: """A registered event handler, which includes the handler and its full name.""" @@ -44,6 +65,31 @@ class RegistrationContext(BaseContext): default_factory=dict, repr=False, ) + config: Config | None = dataclasses.field(default=None, repr=False) + decorated_pages: list[tuple[Callable, dict[str, Any]]] = dataclasses.field( + default_factory=list, + repr=False, + ) + custom_components: dict[str, Any] = dataclasses.field( + default_factory=dict, + repr=False, + ) + memo_definitions: dict[str, Any] = dataclasses.field( + default_factory=dict, + repr=False, + ) + bundled_libraries: list[str] = dataclasses.field( + default_factory=_default_bundled_libraries, + repr=False, + ) + + def _set_config(self, config: Config) -> None: + """Set the config for this context. + + Args: + config: The config to associate with this context. + """ + object.__setattr__(self, "config", config) @classmethod def ensure_context(cls) -> Self: diff --git a/pyi_hashes.json b/pyi_hashes.json index ba57aa5a477..4e5b8af7a20 100644 --- a/pyi_hashes.json +++ b/pyi_hashes.json @@ -120,5 +120,5 @@ "packages/reflex-components-sonner/src/reflex_components_sonner/toast.pyi": "2c5fadcc014056f041cd4d916137d9e7", "reflex/__init__.pyi": "3a9bb8544cbc338ffaf0a5927d9156df", "reflex/components/__init__.pyi": "f39a2af77f438fa243c58c965f19d42e", - "reflex/experimental/memo.pyi": "2c119a0dfea362dcd8193786363cbc02" + "reflex/experimental/memo.pyi": "69479b0af7bd47b642337d116a19b1f8" } diff --git a/reflex/app.py b/reflex/app.py index 240da421ef5..c2b2c79cef8 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -26,12 +26,11 @@ from reflex_base import constants from reflex_base.components.component import ( - CUSTOM_COMPONENTS, Component, ComponentStyle, evaluate_style_namespaces, ) -from reflex_base.config import get_config +from reflex_base.config import get_config, reload_config from reflex_base.context.base import BaseContext from reflex_base.environment import ExecutorType, environment from reflex_base.event import ( @@ -82,10 +81,8 @@ compile_theme, readable_name_from_component, ) -from reflex.experimental.memo import EXPERIMENTAL_MEMOS from reflex.istate.manager import StateManager, StateModificationContext from reflex.istate.manager.token import BaseStateToken -from reflex.page import DECORATED_PAGES from reflex.route import ( get_route_args, replace_brackets_with_keywords, @@ -455,7 +452,7 @@ def __post_init__(self): msg = "rx.BaseState cannot be subclassed directly. Use rx.State instead" raise ValueError(msg) - get_config(reload=True) + reload_config() if "breakpoints" in self.style: set_breakpoints(self.style.pop("breakpoints")) @@ -1157,8 +1154,7 @@ def memoized_badge(): def _apply_decorated_pages(self): """Add @rx.page decorated pages to the app.""" - app_name = get_config().app_name - for render, kwargs in DECORATED_PAGES[app_name]: + for render, kwargs in RegistrationContext.ensure_context().decorated_pages: self.add_page(render, **kwargs) def _validate_var_dependencies(self, state: type[BaseState] | None = None) -> None: @@ -1371,13 +1367,14 @@ def memoized_toast_provider(): app_wrappers[key] = component # Compile custom components. + ctx = RegistrationContext.ensure_context() ( memo_components_output, memo_components_result, memo_components_imports, ) = compiler.compile_memo_components( - dict.fromkeys(CUSTOM_COMPONENTS.values()), - tuple(EXPERIMENTAL_MEMOS.values()), + dict.fromkeys(ctx.custom_components.values()), + tuple(ctx.memo_definitions.values()), ) compile_results.append((memo_components_output, memo_components_result)) all_imports.update(memo_components_imports) diff --git a/reflex/compiler/compiler.py b/reflex/compiler/compiler.py index feec49ee69a..a05e8acc04e 100644 --- a/reflex/compiler/compiler.py +++ b/reflex/compiler/compiler.py @@ -88,10 +88,11 @@ def _compile_app(app_root: Component) -> str: Returns: The compiled app. """ - from reflex_base.components.dynamic import bundled_libraries + from reflex_base.registry import RegistrationContext window_libraries = [ - (_normalize_library_name(name), name) for name in bundled_libraries + (_normalize_library_name(name), name) + for name in RegistrationContext.ensure_context().bundled_libraries ] window_libraries_deduped = list(dict.fromkeys(window_libraries)) diff --git a/reflex/experimental/memo.py b/reflex/experimental/memo.py index 7dee0c72eea..8b583829f9d 100644 --- a/reflex/experimental/memo.py +++ b/reflex/experimental/memo.py @@ -10,7 +10,6 @@ from reflex_base import constants from reflex_base.components.component import Component -from reflex_base.components.dynamic import bundled_libraries from reflex_base.constants.compiler import SpecialAttributes from reflex_base.constants.state import CAMEL_CASE_MEMO_MARKER from reflex_base.utils import format @@ -138,9 +137,6 @@ def _get_experimental_memo_component_class( ) -EXPERIMENTAL_MEMOS: dict[str, ExperimentalMemoDefinition] = {} - - def _memo_registry_key(definition: ExperimentalMemoDefinition) -> str: """Get the registry key for an experimental memo. @@ -185,8 +181,11 @@ def _register_memo_definition(definition: ExperimentalMemoDefinition) -> None: Raises: ValueError: If another memo already compiles to the same exported name. """ + from reflex_base.registry import RegistrationContext + + memos = RegistrationContext.ensure_context().memo_definitions key = _memo_registry_key(definition) - if (existing := EXPERIMENTAL_MEMOS.get(key)) is not None and ( + if (existing := memos.get(key)) is not None and ( not _is_memo_reregistration(existing, definition) ): msg = ( @@ -197,7 +196,7 @@ def _register_memo_definition(definition: ExperimentalMemoDefinition) -> None: ) raise ValueError(msg) - EXPERIMENTAL_MEMOS[key] = definition + memos[key] = definition def _annotation_inner_type(annotation: Any) -> Any: @@ -355,12 +354,15 @@ def _validate_var_return_expr(return_expr: Var, func_name: str) -> None: ) raise TypeError(msg) + from reflex_base.registry import RegistrationContext + + bundled = RegistrationContext.ensure_context().bundled_libraries for lib in dict(var_data.imports): if not lib: continue if lib.startswith((".", "/", "$/", "http")): continue - if format.format_library_name(lib) in bundled_libraries: + if format.format_library_name(lib) in bundled: continue msg = ( f"Var-returning `@rx._x.memo` `{func_name}` cannot import `{lib}` because " diff --git a/reflex/page.py b/reflex/page.py index 2e3f8fc613e..637073ea560 100644 --- a/reflex/page.py +++ b/reflex/page.py @@ -3,7 +3,6 @@ from __future__ import annotations import sys -from collections import defaultdict from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -12,8 +11,6 @@ from reflex_base.event import EventType -DECORATED_PAGES: dict[str, list[tuple[Callable, dict[str, Any]]]] = defaultdict(list) - def page( route: str | None = None, @@ -45,7 +42,7 @@ def page( Returns: The decorated function. """ - from reflex_base.config import get_config + from reflex_base.registry import RegistrationContext def decorator(render_fn: Callable): kwargs: dict[str, Any] = {} @@ -64,7 +61,7 @@ def decorator(render_fn: Callable): if on_load: kwargs["on_load"] = on_load - DECORATED_PAGES[get_config().app_name].append((render_fn, kwargs)) + RegistrationContext.ensure_context().decorated_pages.append((render_fn, kwargs)) return render_fn @@ -74,8 +71,6 @@ def decorator(render_fn: Callable): class PageNamespace: """A namespace for page names.""" - DECORATED_PAGES = DECORATED_PAGES - def __new__( cls, route: str | None = None, diff --git a/reflex/reflex.py b/reflex/reflex.py index bf8a66afecd..f0de67072ee 100644 --- a/reflex/reflex.py +++ b/reflex/reflex.py @@ -8,7 +8,7 @@ import click from reflex_base import constants -from reflex_base.config import get_config +from reflex_base.config import get_config, reload_config from reflex_base.environment import environment from reflex_base.utils import console from reflex_cli.v2.deployments import hosting_cli @@ -303,7 +303,7 @@ def _run( config._set_persistent(backend_port=backend_port) # Reload the config to make sure the env vars are persistent. - get_config(reload=True) + reload_config() console.rule("[bold]Starting Reflex App") @@ -481,7 +481,7 @@ def compile(dry: bool, rich: bool): # Check the app. if prerequisites.needs_reinit(): _init(name=get_config().app_name) - get_config(reload=True) + reload_config() starting_time = time.monotonic() prerequisites.get_compiled_app(dry_run=dry, use_rich=rich) elapsed_time = time.monotonic() - starting_time @@ -896,6 +896,8 @@ def rename(new_name: str): from reflex.utils.rename import rename_app prerequisites.validate_app_name(new_name) + # Reload so we read rxconfig.py from the current directory, not a cached one. + reload_config() rename_app(new_name, get_config().loglevel) diff --git a/reflex/testing.py b/reflex/testing.py index 9a3ff023049..57239aa8e69 100644 --- a/reflex/testing.py +++ b/reflex/testing.py @@ -28,8 +28,8 @@ from typing import TYPE_CHECKING, Any, ClassVar, Literal, TypeVar import uvicorn -from reflex_base.components.component import CUSTOM_COMPONENTS, CustomComponent -from reflex_base.config import get_config +from reflex_base.components.component import CustomComponent +from reflex_base.config import get_config, reload_config from reflex_base.environment import environment from reflex_base.registry import RegistrationContext from reflex_base.utils.types import ASGIApp @@ -41,7 +41,6 @@ import reflex.utils.format import reflex.utils.prerequisites import reflex.utils.processes -from reflex.experimental.memo import EXPERIMENTAL_MEMOS from reflex.istate.shared import SharedState as SharedState # To register it. from reflex.state import reload_state_module from reflex.utils import console, js_runtimes @@ -240,10 +239,8 @@ def _get_source_from_app_source(self, app_source: Any) -> str: def _initialize_app(self): # disable telemetry reporting for tests os.environ["REFLEX_TELEMETRY_ENABLED"] = "false" - # Reset global memo registries so previous AppHarness apps do not - # leak compiled component definitions into the next test app. - CUSTOM_COMPONENTS.clear() - EXPERIMENTAL_MEMOS.clear() + # Memo/custom-component registries live on the new RegistrationContext + # below, so previous AppHarness apps cannot leak definitions here. CustomComponent.create().get_component.cache_clear() self.app_path.mkdir(parents=True, exist_ok=True) if self.app_source is not None: @@ -278,7 +275,7 @@ def _initialize_app(self): new_registration_context = deepcopy(AppHarness._base_registration_context) self._registry_token = RegistrationContext.set(new_registration_context) # ensure config and app are reloaded when testing different app - config = get_config(reload=True) + config = reload_config() # Ensure the AppHarness test does not skip State assignment due to running via pytest os.environ.pop(reflex.constants.PYTEST_CURRENT_TEST, None) os.environ[reflex.constants.APP_HARNESS_FLAG] = "true" diff --git a/reflex/utils/prerequisites.py b/reflex/utils/prerequisites.py index 06662e53aa0..eebc310f1b7 100644 --- a/reflex/utils/prerequisites.py +++ b/reflex/utils/prerequisites.py @@ -199,13 +199,14 @@ def get_app(reload: bool = False) -> ModuleType: else config.app_module ) if reload: - from reflex.page import DECORATED_PAGES + from reflex_base.registry import RegistrationContext + from reflex.state import reload_state_module # Reset rx.State subclasses to avoid conflict when reloading. reload_state_module(module=module) - DECORATED_PAGES.clear() + RegistrationContext.ensure_context().decorated_pages.clear() # Reload the app module. importlib.reload(app) diff --git a/reflex/utils/templates.py b/reflex/utils/templates.py index 2751f9e21a4..43643e17273 100644 --- a/reflex/utils/templates.py +++ b/reflex/utils/templates.py @@ -8,7 +8,7 @@ from urllib.parse import urlparse from reflex_base import constants -from reflex_base.config import get_config +from reflex_base.config import reload_config from reflex.utils import console, net, path_ops, redir from reflex.utils.rename import rename_imports_and_app_name @@ -170,7 +170,7 @@ def create_config_init_app_from_remote_template(app_name: str, template_url: str # Move the rxconfig file here first. path_ops.mv(str(template_dir / constants.Config.FILE), constants.Config.FILE) - new_config = get_config(reload=True) + new_config = reload_config() # Get the template app's name from rxconfig in case it is different than # the source code repo name on github. diff --git a/tests/units/components/test_component.py b/tests/units/components/test_component.py index de397442311..b15ac526ab0 100644 --- a/tests/units/components/test_component.py +++ b/tests/units/components/test_component.py @@ -4,7 +4,6 @@ import pytest from reflex_base.components.component import ( - CUSTOM_COMPONENTS, Component, CustomComponent, StatefulComponent, @@ -21,6 +20,7 @@ parse_args_spec, passthrough_event_spec, ) +from reflex_base.registry import RegistrationContext from reflex_base.style import Style from reflex_base.utils.exceptions import ( ChildrenTypeError, @@ -874,7 +874,7 @@ def test_create_custom_component(my_component): component = rx.memo(my_component)(prop1="test", prop2=1) assert component.tag == "MyComponent" assert set(component.get_props()) == {"prop1", "prop2"} - assert component.tag in CUSTOM_COMPONENTS + assert component.tag in RegistrationContext.get().custom_components def test_custom_component_hash(my_component): diff --git a/tests/units/conftest.py b/tests/units/conftest.py index 36baee0ec8e..7f50406dbd4 100644 --- a/tests/units/conftest.py +++ b/tests/units/conftest.py @@ -10,14 +10,12 @@ import pytest import pytest_asyncio -from reflex_base.components.component import CUSTOM_COMPONENTS from reflex_base.event import Event, EventSpec from reflex_base.event.context import EventContext from reflex_base.event.processor import BaseStateEventProcessor, EventProcessor from reflex_base.registry import RegistrationContext from reflex.app import App -from reflex.experimental.memo import EXPERIMENTAL_MEMOS from reflex.istate.manager import StateManager from reflex.istate.manager.disk import StateManagerDisk from reflex.istate.manager.memory import StateManagerMemory @@ -487,21 +485,3 @@ def clean_registration_context() -> Generator[RegistrationContext, None, None]: """ with RegistrationContext() as ctx: yield ctx - - -@pytest.fixture -def preserve_memo_registries(): - """Save and restore global memo registries around a test. - - Yields: - None - """ - custom_components = dict(CUSTOM_COMPONENTS) - experimental_memos = dict(EXPERIMENTAL_MEMOS) - try: - yield - finally: - CUSTOM_COMPONENTS.clear() - CUSTOM_COMPONENTS.update(custom_components) - EXPERIMENTAL_MEMOS.clear() - EXPERIMENTAL_MEMOS.update(experimental_memos) diff --git a/tests/units/experimental/test_memo.py b/tests/units/experimental/test_memo.py index f202ecf05d8..7988d8c1c00 100644 --- a/tests/units/experimental/test_memo.py +++ b/tests/units/experimental/test_memo.py @@ -6,7 +6,8 @@ from typing import Any import pytest -from reflex_base.components.component import CUSTOM_COMPONENTS, Component +from reflex_base.components.component import Component +from reflex_base.registry import RegistrationContext from reflex_base.style import Style from reflex_base.utils.imports import ImportVar from reflex_base.vars import VarData @@ -17,7 +18,6 @@ from reflex.compiler import compiler from reflex.compiler import utils as compiler_utils from reflex.experimental.memo import ( - EXPERIMENTAL_MEMOS, ExperimentalMemoComponent, ExperimentalMemoComponentDefinition, ExperimentalMemoFunctionDefinition, @@ -25,8 +25,30 @@ @pytest.fixture(autouse=True) -def _restore_memo_registries(preserve_memo_registries): - """Autouse wrapper around the shared preserve_memo_registries fixture.""" +def _fresh_registration_context(clean_registration_context: RegistrationContext): + """Isolate each test behind a fresh RegistrationContext. + + Args: + clean_registration_context: A fresh registration context fixture. + """ + + +def _memos() -> dict: + """Get the current context's memo definitions. + + Returns: + The memo_definitions dict on the current RegistrationContext. + """ + return RegistrationContext.get().memo_definitions + + +def _custom_components() -> dict: + """Get the current context's custom components. + + Returns: + The custom_components dict on the current RegistrationContext. + """ + return RegistrationContext.get().custom_components def test_var_returning_memo(): @@ -49,7 +71,7 @@ def format_price(amount: rx.Var[int], currency: rx.Var[str]) -> rx.Var[str]: ) assert isinstance(format_price._as_var(), FunctionVar) - definition = EXPERIMENTAL_MEMOS["format_price"] + definition = _memos()["format_price"] assert isinstance(definition, ExperimentalMemoFunctionDefinition) assert ( str(definition.function) == '((amount, currency) => ((currency+": $")+amount))' @@ -97,13 +119,11 @@ def my_card( assert 'foo:"extra"' in rendered["props"] assert 'className:"extra"' in rendered["props"] - definition = EXPERIMENTAL_MEMOS["MyCard"] + definition = _memos()["MyCard"] assert isinstance(definition, ExperimentalMemoComponentDefinition) assert any(str(prop) == "rest" for prop in definition.component.special_props) - _, code, _ = compiler.compile_memo_components( - (), tuple(EXPERIMENTAL_MEMOS.values()) - ) + _, code, _ = compiler.compile_memo_components((), tuple(_memos().values())) assert "export const MyCard = memo(({children, title:title" in code assert "...rest" in code assert "jsx(RadixThemesBox,{...rest}" in code @@ -120,15 +140,13 @@ def conditional_slot( ) -> rx.Var[rx.Component]: return rx.cond(show, first, second) - definition = EXPERIMENTAL_MEMOS["ConditionalSlot"] + definition = _memos()["ConditionalSlot"] assert isinstance(definition, ExperimentalMemoComponentDefinition) assert definition.component.render() == { "contents": "(showRxMemo ? firstRxMemo : secondRxMemo)" } - _, code, _ = compiler.compile_memo_components( - (), tuple(EXPERIMENTAL_MEMOS.values()) - ) + _, code, _ = compiler.compile_memo_components((), tuple(_memos().values())) assert "export const ConditionalSlot = memo(({show:showRxMemo" in code assert "(showRxMemo ? firstRxMemo : secondRxMemo)" in code @@ -151,9 +169,7 @@ def merge_styles( assert '["color"] : "red"' in str(merged) assert '["className"] : "primary"' in str(merged) - _, code, _ = compiler.compile_memo_components( - (), tuple(EXPERIMENTAL_MEMOS.values()) - ) + _, code, _ = compiler.compile_memo_components((), tuple(_memos().values())) assert ( "export const merge_styles = (({base, ...overrides}) => ({...base, ...overrides}));" in code @@ -185,9 +201,7 @@ def label_slot( assert '["children"]' in str(rendered) assert '["className"] : "slot"' in str(rendered) - _, code, _ = compiler.compile_memo_components( - (), tuple(EXPERIMENTAL_MEMOS.values()) - ) + _, code, _ = compiler.compile_memo_components((), tuple(_memos().values())) assert "export const label_slot = (({children, label, ...rest}) => label);" in code @@ -234,7 +248,7 @@ def test_memo_rejects_component_and_function_name_collision(): def foo_bar() -> rx.Component: return rx.box() - assert "FooBar" in EXPERIMENTAL_MEMOS + assert "FooBar" in _memos() with pytest.raises(ValueError, match=r"name collision.*FooBar"): @@ -357,8 +371,8 @@ def my_card(children: rx.Var[rx.Component], *, title: rx.Var[str]) -> rx.Compone return rx.box(rx.heading(title), children) _, code, _ = compiler.compile_memo_components( - dict.fromkeys(CUSTOM_COMPONENTS.values()), - tuple(EXPERIMENTAL_MEMOS.values()), + dict.fromkeys(_custom_components().values()), + tuple(_memos().values()), ) assert "export const OldWrapper = memo(" in code @@ -381,7 +395,7 @@ def wrapper() -> rx.Component: assert "inner" not in experimental_component._get_all_imports() - definition = EXPERIMENTAL_MEMOS["Wrapper"] + definition = _memos()["Wrapper"] assert isinstance(definition, ExperimentalMemoComponentDefinition) _, imports = compiler_utils.compile_experimental_component_memo(definition) assert "inner" in imports @@ -396,7 +410,7 @@ def test_compile_experimental_component_memo_does_not_mutate_definition( def wrapper() -> rx.Component: return rx.box("hi") - definition = EXPERIMENTAL_MEMOS["Wrapper"] + definition = _memos()["Wrapper"] assert isinstance(definition, ExperimentalMemoComponentDefinition) assert definition.component.style == Style() @@ -428,8 +442,6 @@ def add_custom_code(self) -> list[str]: def foo_component(label: rx.Var[str]) -> rx.Component: return FooComponent.create(label, rx.Var("foo")) - _, code, _ = compiler.compile_memo_components( - (), tuple(EXPERIMENTAL_MEMOS.values()) - ) + _, code, _ = compiler.compile_memo_components((), tuple(_memos().values())) assert "const foo = 'bar'" in code diff --git a/tests/units/reflex_base/test_registry.py b/tests/units/reflex_base/test_registry.py index 474acf874c8..fced0d35d4e 100644 --- a/tests/units/reflex_base/test_registry.py +++ b/tests/units/reflex_base/test_registry.py @@ -1,9 +1,14 @@ """Tests for RegistrationContext.""" +from textwrap import dedent + import pytest +from reflex_base.config import Config, get_config, reload_config from reflex_base.registry import RegisteredEventHandler, RegistrationContext from reflex_base.utils.exceptions import StateValueError +from reflex.testing import chdir + def test_ensure_context_creates_if_missing(): """ensure_context() returns existing context or creates a new one.""" @@ -131,3 +136,176 @@ async def _tmp(): handler = EventHandler(fn=_tmp) RegistrationContext.register_event_handler(handler) assert len(forked_registration_context.event_handlers) > 0 + + +def test_clean_context_has_no_config(clean_registration_context: RegistrationContext): + """A fresh RegistrationContext starts with config=None. + + Args: + clean_registration_context: A fresh, empty registration context. + """ + assert clean_registration_context.config is None + + +def _write_rxconfig(path, app_name: str) -> None: + (path / "rxconfig.py").write_text( + dedent( + f""" + import reflex as rx + config = rx.Config(app_name="{app_name}") + """ + ) + ) + + +def test_get_config_caches_on_context( + tmp_path, clean_registration_context: RegistrationContext +): + """get_config loads rxconfig once and caches the result on the context. + + Args: + tmp_path: Pytest tmp dir fixture. + clean_registration_context: A fresh, empty registration context. + """ + _write_rxconfig(tmp_path, "ctx_app") + with chdir(tmp_path): + assert clean_registration_context.config is None + first = get_config() + assert first is clean_registration_context.config + assert first.app_name == "ctx_app" + assert get_config() is first + + +def test_reload_config_forces_fresh_load( + tmp_path, clean_registration_context: RegistrationContext +): + """reload_config re-reads rxconfig.py and replaces the cached instance. + + Args: + tmp_path: Pytest tmp dir fixture. + clean_registration_context: A fresh, empty registration context. + """ + _write_rxconfig(tmp_path, "before") + with chdir(tmp_path): + first = get_config() + assert first.app_name == "before" + + _write_rxconfig(tmp_path, "after") + second = reload_config() + assert second is not first + assert second.app_name == "after" + assert clean_registration_context.config is second + assert get_config() is second + + +def test_two_contexts_hold_independent_configs(tmp_path): + """Different RegistrationContexts can cache different configs in one process. + + Args: + tmp_path: Pytest tmp dir fixture. + """ + app_a = tmp_path / "app_a" + app_a.mkdir() + _write_rxconfig(app_a, "app_a") + + app_b = tmp_path / "app_b" + app_b.mkdir() + _write_rxconfig(app_b, "app_b") + + with RegistrationContext() as ctx_a, chdir(app_a): + config_a = get_config() + + with RegistrationContext() as ctx_b, chdir(app_b): + config_b = get_config() + + assert config_a.app_name == "app_a" + assert config_b.app_name == "app_b" + assert config_a is not config_b + assert ctx_a.config is config_a + assert ctx_b.config is config_b + + +def test_get_config_outside_context_auto_attaches(): + """Calling get_config with no active context attaches one automatically.""" + import contextvars + + def _run() -> Config: + with pytest.raises(LookupError): + RegistrationContext.get() + cfg = get_config() + assert RegistrationContext.get().config is cfg + return cfg + + # Run in a fresh Context so the ContextVar starts unset. + config = contextvars.Context().run(_run) + assert isinstance(config, Config) + + +def test_decorated_pages_isolated_between_contexts(): + """@page registrations in one context do not leak to another.""" + from reflex.page import page + + def a_page(): + return None + + def b_page(): + return None + + with RegistrationContext() as ctx_a: + page(route="/a")(a_page) + assert len(ctx_a.decorated_pages) == 1 + assert ctx_a.decorated_pages[0][0] is a_page + + with RegistrationContext() as ctx_b: + page(route="/b")(b_page) + assert len(ctx_b.decorated_pages) == 1 + assert ctx_b.decorated_pages[0][0] is b_page + + assert ctx_a.decorated_pages != ctx_b.decorated_pages + + +def test_custom_components_isolated_between_contexts(): + """@custom_component registrations in one context do not leak to another.""" + from reflex_base.components.component import custom_component + + import reflex as rx + + def _tag_component_fn(prop1: str, prop2: int) -> rx.Component: + return rx.text(prop1) + + with RegistrationContext() as ctx_a: + custom_component(_tag_component_fn) + assert "TagComponentFn" in ctx_a.custom_components + + with RegistrationContext() as ctx_b: + assert ctx_b.custom_components == {} + + +def test_memo_definitions_isolated_between_contexts(): + """@rx._x.memo registrations in one context do not leak to another.""" + import reflex as rx + + with RegistrationContext() as ctx_a: + + @rx._x.memo + def greet(name: rx.Var[str]) -> rx.Var[str]: + return name.to(str) + + assert "greet" in ctx_a.memo_definitions + + with RegistrationContext() as ctx_b: + assert ctx_b.memo_definitions == {} + + +def test_bundled_libraries_isolated_between_contexts(): + """bundle_library appends to the current context only.""" + from reflex_base.components.dynamic import bundle_library + + with RegistrationContext() as ctx_a: + initial_len = len(ctx_a.bundled_libraries) + bundle_library("some-extra-lib") + assert "some-extra-lib" in ctx_a.bundled_libraries + assert len(ctx_a.bundled_libraries) == initial_len + 1 + + with RegistrationContext() as ctx_b: + assert "some-extra-lib" not in ctx_b.bundled_libraries diff --git a/tests/units/test_app.py b/tests/units/test_app.py index 077b7babb07..31810b19948 100644 --- a/tests/units/test_app.py +++ b/tests/units/test_app.py @@ -2128,6 +2128,7 @@ def test_app_wrap_compile_theme( react_strict_mode: bool, compilable_app: tuple[App, Path], mocker: MockerFixture, + clean_registration_context, ): """Test that the radix theme component wraps the app. @@ -2135,6 +2136,8 @@ def test_app_wrap_compile_theme( react_strict_mode: Whether to use React Strict Mode. compilable_app: compilable_app fixture. mocker: pytest mocker object. + clean_registration_context: Fresh registration context so the + `_get_config` mock below is not masked by a cached config. """ conf = rx.Config(app_name="testing", react_strict_mode=react_strict_mode) mocker.patch("reflex_base.config._get_config", return_value=conf) @@ -2182,6 +2185,7 @@ def test_app_wrap_priority( react_strict_mode: bool, compilable_app: tuple[App, Path], mocker: MockerFixture, + clean_registration_context, ): """Test that the app wrap components are wrapped in the correct order. @@ -2189,6 +2193,8 @@ def test_app_wrap_priority( react_strict_mode: Whether to use React Strict Mode. compilable_app: compilable_app fixture. mocker: pytest mocker object. + clean_registration_context: Fresh registration context so the + `_get_config` mock below is not masked by a cached config. """ conf = rx.Config(app_name="testing", react_strict_mode=react_strict_mode) mocker.patch("reflex_base.config._get_config", return_value=conf) diff --git a/tests/units/test_page.py b/tests/units/test_page.py index 07cb5c4e2a4..61ba5c886e0 100644 --- a/tests/units/test_page.py +++ b/tests/units/test_page.py @@ -1,32 +1,43 @@ -from reflex_base.config import get_config +from reflex_base.registry import RegistrationContext from reflex import text -from reflex.page import DECORATED_PAGES, page +from reflex.page import page -def test_page_decorator(): +def test_page_decorator(clean_registration_context: RegistrationContext): + """@page stores the decorated function on the current registration context. + + Args: + clean_registration_context: A fresh registration context. + """ + def foo_(): return text("foo") - DECORATED_PAGES.clear() - assert len(DECORATED_PAGES) == 0 + assert clean_registration_context.decorated_pages == [] decorated_foo_ = page()(foo_) assert decorated_foo_ == foo_ - assert len(DECORATED_PAGES) == 1 - page_data = DECORATED_PAGES.get(get_config().app_name, [])[0][1] + assert len(clean_registration_context.decorated_pages) == 1 + _, page_data = clean_registration_context.decorated_pages[0] assert page_data == {} - DECORATED_PAGES.clear() -def test_page_decorator_with_kwargs(): +def test_page_decorator_with_kwargs( + clean_registration_context: RegistrationContext, +): + """@page preserves all kwargs on the current registration context. + + Args: + clean_registration_context: A fresh registration context. + """ + def foo_(): return text("foo") def load_foo(): return [] - DECORATED_PAGES.clear() - assert len(DECORATED_PAGES) == 0 + assert clean_registration_context.decorated_pages == [] decorated_foo_ = page( route="foo", title="Foo", @@ -37,8 +48,8 @@ def load_foo(): on_load=load_foo, )(foo_) assert decorated_foo_ == foo_ - assert len(DECORATED_PAGES) == 1 - page_data = DECORATED_PAGES.get(get_config().app_name, [])[0][1] + assert len(clean_registration_context.decorated_pages) == 1 + _, page_data = clean_registration_context.decorated_pages[0] assert page_data == { "description": "Foo description", "image": "foo.png", @@ -48,5 +59,3 @@ def load_foo(): "script_tags": ["foo-script"], "title": "Foo", } - - DECORATED_PAGES.clear() diff --git a/tests/units/test_state.py b/tests/units/test_state.py index cb2bf87bdd2..678caeaafe4 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -3701,7 +3701,7 @@ def test_redis_state_manager_config_knobs(tmp_path, expiration_kwargs, expected_ with chdir(proj_root): # reload config for each parameter to avoid stale values - reflex_base.config.get_config(reload=True) + reflex_base.config.reload_config() state_manager = StateManagerRedis(redis=mock_redis()) assert state_manager.lock_expiration == expected_values[0] # pyright: ignore [reportAttributeAccessIssue] @@ -3737,7 +3737,7 @@ def test_redis_state_manager_config_knobs_invalid_lock_warning_threshold( with chdir(proj_root): # reload config for each parameter to avoid stale values - reflex_base.config.get_config(reload=True) + reflex_base.config.reload_config() with pytest.raises(InvalidLockWarningThresholdError): StateManagerRedis(redis=mock_redis()) @@ -3762,7 +3762,7 @@ def test_state_manager_create_respects_explicit_memory_mode_with_redis_url( monkeypatch.setenv("REFLEX_REDIS_URL", "redis://localhost:6379") with chdir(proj_root): - reflex_base.config.get_config(reload=True) + reflex_base.config.reload_config() monkeypatch.setattr(prerequisites, "get_redis", mock_redis) state_manager = StateManager.create() assert isinstance(state_manager, StateManagerMemory) @@ -3786,7 +3786,7 @@ def test_auto_setters_off(tmp_path): with chdir(proj_root): # reload config for each parameter to avoid stale values - reflex_base.config.get_config(reload=True) + reflex_base.config.reload_config() from reflex.state import State class TestState(State): @@ -3813,7 +3813,7 @@ def test_auto_setters_on(tmp_path): with chdir(proj_root): # reload config for each parameter to avoid stale values - reflex_base.config.get_config(reload=True) + reflex_base.config.reload_config() from reflex.state import State class TestState(State): diff --git a/tests/units/test_testing.py b/tests/units/test_testing.py index e1f6c20576b..0e74d2211f9 100644 --- a/tests/units/test_testing.py +++ b/tests/units/test_testing.py @@ -6,13 +6,12 @@ import pytest import reflex_base.config -from reflex_base.components.component import CUSTOM_COMPONENTS from reflex_base.constants import IS_WINDOWS +from reflex_base.registry import RegistrationContext import reflex.reflex as reflex_cli import reflex.testing as reflex_testing import reflex.utils.prerequisites -from reflex.experimental.memo import EXPERIMENTAL_MEMOS from reflex.testing import AppHarness @@ -69,10 +68,10 @@ def harness_mocks(monkeypatch): ) ) - monkeypatch.setattr(reflex_testing, "get_config", lambda reload=False: fake_config) - monkeypatch.setattr( - reflex_base.config, "get_config", lambda reload=False: fake_config - ) + monkeypatch.setattr(reflex_testing, "get_config", lambda: fake_config) + monkeypatch.setattr(reflex_testing, "reload_config", lambda: fake_config) + monkeypatch.setattr(reflex_base.config, "get_config", lambda: fake_config) + monkeypatch.setattr(reflex_base.config, "reload_config", lambda: fake_config) monkeypatch.setattr( reflex.utils.prerequisites, "get_and_validate_app", @@ -85,21 +84,28 @@ def harness_mocks(monkeypatch): ) -def test_app_harness_initialize_clears_memo_registries( - tmp_path, preserve_memo_registries, harness_mocks, monkeypatch +def test_app_harness_initialize_isolates_memo_registries( + tmp_path, harness_mocks, monkeypatch ): - """Ensure app initialization clears leaked memo registries. + """Each AppHarness initialization yields a fresh registration context. + + Entries registered in a prior context do not leak into the new harness's + registrations. Args: tmp_path: pytest tmp_path fixture - preserve_memo_registries: restores global memo registries after the test harness_mocks: shared AppHarness mock setup monkeypatch: pytest monkeypatch fixture """ monkeypatch.setattr(reflex_cli, "_init", lambda **kwargs: None) - CUSTOM_COMPONENTS["FooComponent"] = mock.sentinel.component - EXPERIMENTAL_MEMOS["format_value"] = mock.sentinel.memo + outer = RegistrationContext.ensure_context() + # Pin a clean base so pollution on the outer context does not seed new harnesses. + base = RegistrationContext() + monkeypatch.setattr(AppHarness, "_base_registration_context", base) + + outer.custom_components["FooComponent"] = mock.sentinel.component + outer.memo_definitions["format_value"] = mock.sentinel.memo harness = AppHarness.create( root=tmp_path / "memo_app", @@ -107,21 +113,31 @@ def test_app_harness_initialize_clears_memo_registries( app_name="memo_app", ) harness.app_module_path.parent.mkdir(parents=True, exist_ok=True) - harness._initialize_app() - - assert "FooComponent" not in CUSTOM_COMPONENTS - assert "format_value" not in EXPERIMENTAL_MEMOS - harness_mocks.get_and_validate_app.assert_called_once_with(reload=True) + try: + harness._initialize_app() + + new_ctx = RegistrationContext.get() + assert new_ctx is not outer + assert "FooComponent" not in new_ctx.custom_components + assert "format_value" not in new_ctx.memo_definitions + harness_mocks.get_and_validate_app.assert_called_once_with(reload=True) + finally: + # `_initialize_app` attaches a new context without a matching __exit__. + # Restore the outer context so other tests do not observe the leaked one. + if harness._registry_token is not None: + RegistrationContext.reset(harness._registry_token) + # Clean up the sentinels we added to `outer`. + outer.custom_components.pop("FooComponent", None) + outer.memo_definitions.pop("format_value", None) def test_app_harness_initialize_reloads_existing_imported_app( - tmp_path, preserve_memo_registries, harness_mocks, monkeypatch + tmp_path, harness_mocks, monkeypatch ): """Ensure pre-existing imported apps are reloaded after memo registry reset. Args: tmp_path: pytest tmp_path fixture - preserve_memo_registries: restores global memo registries after the test harness_mocks: shared AppHarness mock setup monkeypatch: pytest monkeypatch fixture """ From c3441afd4ddf41ee8cea63c289e7868998f2243b Mon Sep 17 00:00:00 2001 From: Claude Date: Fri, 24 Apr 2026 19:51:23 +0000 Subject: [PATCH 2/7] Restore threading lock around _load_config sys.path mutation Addresses Greptile P2 review comment on #6382. The previous get_config guarded sys.path clear/restore with a module-level RLock; this restores the same protection around _load_config so concurrent calls (e.g. from multiple AppHarness instances or reload_config on background threads) cannot leave sys.path in a corrupt state. --- .../reflex-base/src/reflex_base/config.py | 34 +++++++++++-------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/packages/reflex-base/src/reflex_base/config.py b/packages/reflex-base/src/reflex_base/config.py index a7923afc76a..584fb382938 100644 --- a/packages/reflex-base/src/reflex_base/config.py +++ b/packages/reflex-base/src/reflex_base/config.py @@ -4,6 +4,7 @@ import importlib import os import sys +import threading import urllib.parse from collections.abc import Sequence from importlib.util import find_spec @@ -657,27 +658,32 @@ def _get_config() -> Config: return rxconfig.config +# Protect sys.path from concurrent modification during config loading. +_load_config_lock = threading.RLock() + + def _load_config() -> Config: """Load the config from rxconfig.py with cwd on sys.path. Returns: The app config. """ - orig_sys_path = sys.path.copy() - sys.path.clear() - sys.path.append(str(Path.cwd())) - try: - return _get_config() - except Exception: - sys.path.extend(orig_sys_path) - return _get_config() - finally: - extra_paths = [ - p for p in sys.path if p not in orig_sys_path and p != str(Path.cwd()) - ] + with _load_config_lock: + orig_sys_path = sys.path.copy() sys.path.clear() - sys.path.extend(extra_paths) - sys.path.extend(orig_sys_path) + sys.path.append(str(Path.cwd())) + try: + return _get_config() + except Exception: + sys.path.extend(orig_sys_path) + return _get_config() + finally: + extra_paths = [ + p for p in sys.path if p not in orig_sys_path and p != str(Path.cwd()) + ] + sys.path.clear() + sys.path.extend(extra_paths) + sys.path.extend(orig_sys_path) def get_config() -> Config: From 9dc0aa6c735a3e48a5b3cbf03b1a80f63093a805 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Mon, 27 Apr 2026 14:22:04 -0700 Subject: [PATCH 3/7] Register App() in RegistrationContext --- .../src/reflex_base/context/base.py | 4 +- .../reflex-base/src/reflex_base/registry.py | 57 ++++++++++++- reflex/app.py | 2 + reflex/testing.py | 3 +- tests/units/conftest.py | 20 ++++- tests/units/reflex_base/test_registry.py | 84 ++++++++++++++++++- tests/units/test_app.py | 3 +- 7 files changed, 164 insertions(+), 9 deletions(-) diff --git a/packages/reflex-base/src/reflex_base/context/base.py b/packages/reflex-base/src/reflex_base/context/base.py index 7bb28d4864c..f773c2bd0eb 100644 --- a/packages/reflex-base/src/reflex_base/context/base.py +++ b/packages/reflex-base/src/reflex_base/context/base.py @@ -9,7 +9,7 @@ from typing_extensions import Self -@dataclasses.dataclass(frozen=True, slots=True, kw_only=True) +@dataclasses.dataclass(frozen=True, slots=True, kw_only=True, eq=False) class BaseContext: """Base context class that acts as an async context manager to set the context var.""" @@ -67,7 +67,7 @@ def __enter__(self) -> Self: def __exit__(self, *exc_info): """Exit the context.""" - if (token := self._attached_context_token.pop(self)) is not None: + if (token := self._attached_context_token.pop(self, None)) is not None: self._context_var.reset(token) def ensure_context_attached(self): diff --git a/packages/reflex-base/src/reflex_base/registry.py b/packages/reflex-base/src/reflex_base/registry.py index fb7b4a089d2..7f868e47c3f 100644 --- a/packages/reflex-base/src/reflex_base/registry.py +++ b/packages/reflex-base/src/reflex_base/registry.py @@ -8,11 +8,12 @@ from typing_extensions import Self from reflex_base.context.base import BaseContext -from reflex_base.utils.exceptions import StateValueError +from reflex_base.utils.exceptions import ReflexRuntimeError, StateValueError if TYPE_CHECKING: from collections.abc import Callable + from reflex.app import App from reflex.state import BaseState from reflex_base.components.component import StatefulComponent from reflex_base.config import Config @@ -82,6 +83,60 @@ class RegistrationContext(BaseContext): default_factory=_default_bundled_libraries, repr=False, ) + _app: App | None = dataclasses.field(default=None, repr=False) + + @property + def app(self) -> App | None: + """Get the App instance associated with this context. + + Returns: + The App instance, or None if no app has been registered. + """ + return self._app + + def _set_app(self, app: App) -> None: + """Associate an App instance with this context. + + Args: + app: The App instance to register. + + Raises: + ReflexRuntimeError: If an App is already registered with this context. + """ + if self._app is not None: + msg = ( + "A RegistrationContext can only be associated with a single App " + "instance. To create another App, call `.fork()` on the current " + "RegistrationContext to obtain a fresh context that preserves " + "existing registrations, or instantiate a new RegistrationContext " + "and set it as the current context before instantiating the new App." + ) + raise ReflexRuntimeError(msg) + object.__setattr__(self, "_app", app) + + def fork(self) -> Self: + """Create a copy of this context with `_app` reset to None. + + Existing registrations (event handlers, base states, decorated pages, etc.) + are shallow-copied so the fork can evolve independently while preserving + already-registered classes. + + Returns: + A new RegistrationContext with the same registrations but no app. + """ + return type(self)( + event_handlers=dict(self.event_handlers), + base_states=dict(self.base_states), + base_state_substates={ + k: set(v) for k, v in self.base_state_substates.items() + }, + tag_to_stateful_component=dict(self.tag_to_stateful_component), + config=self.config, + decorated_pages=list(self.decorated_pages), + custom_components=dict(self.custom_components), + memo_definitions=dict(self.memo_definitions), + bundled_libraries=list(self.bundled_libraries), + ) def _set_config(self, config: Config) -> None: """Set the config for this context. diff --git a/reflex/app.py b/reflex/app.py index c2b2c79cef8..48f0832f9c2 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -452,6 +452,8 @@ def __post_init__(self): msg = "rx.BaseState cannot be subclassed directly. Use rx.State instead" raise ValueError(msg) + self._registration_context._set_app(self) + reload_config() if "breakpoints" in self.style: diff --git a/reflex/testing.py b/reflex/testing.py index 57239aa8e69..dfaac8c4e40 100644 --- a/reflex/testing.py +++ b/reflex/testing.py @@ -21,7 +21,6 @@ import time import types from collections.abc import Callable, Coroutine, Sequence -from copy import deepcopy from http.server import SimpleHTTPRequestHandler from importlib.util import find_spec from pathlib import Path @@ -272,7 +271,7 @@ def _initialize_app(self): AppHarness._base_registration_context = ( RegistrationContext.ensure_context() ) - new_registration_context = deepcopy(AppHarness._base_registration_context) + new_registration_context = AppHarness._base_registration_context.fork() self._registry_token = RegistrationContext.set(new_registration_context) # ensure config and app are reloaded when testing different app config = reload_config() diff --git a/tests/units/conftest.py b/tests/units/conftest.py index 7f50406dbd4..cd16fcc4a6f 100644 --- a/tests/units/conftest.py +++ b/tests/units/conftest.py @@ -4,7 +4,6 @@ import traceback import uuid from collections.abc import AsyncGenerator, Generator, Mapping -from copy import deepcopy from typing import Any from unittest import mock @@ -28,6 +27,23 @@ from .states.upload import SubUploadState, UploadState +@pytest.fixture(autouse=True) +def _isolate_app_in_context() -> Generator[None, None, None]: + """Reset the App slot on the active RegistrationContext between tests. + + A RegistrationContext can only host one App instance, but unit tests + repeatedly instantiate `rx.App`, so we clear `_app` around each test + while keeping other registrations shared (matching prior behavior). + + Yields: + None. + """ + ctx = RegistrationContext.ensure_context() + object.__setattr__(ctx, "_app", None) + yield + object.__setattr__(ctx, "_app", None) + + @pytest.fixture def app() -> App: """A base app. @@ -469,7 +485,7 @@ def forked_registration_context() -> Generator[RegistrationContext, None, None]: Yields: The forked RegistrationContext. """ - with deepcopy(RegistrationContext.get()) as ctx: + with RegistrationContext.get().fork() as ctx: yield ctx diff --git a/tests/units/reflex_base/test_registry.py b/tests/units/reflex_base/test_registry.py index fced0d35d4e..0eb2cc14571 100644 --- a/tests/units/reflex_base/test_registry.py +++ b/tests/units/reflex_base/test_registry.py @@ -5,7 +5,7 @@ import pytest from reflex_base.config import Config, get_config, reload_config from reflex_base.registry import RegisteredEventHandler, RegistrationContext -from reflex_base.utils.exceptions import StateValueError +from reflex_base.utils.exceptions import ReflexRuntimeError, StateValueError from reflex.testing import chdir @@ -297,6 +297,88 @@ def greet(name: rx.Var[str]) -> rx.Var[str]: assert ctx_b.memo_definitions == {} +def test_app_registers_on_instantiation( + clean_registration_context: RegistrationContext, +): + """Instantiating an rx.App stores it on the active RegistrationContext. + + Args: + clean_registration_context: A fresh, empty registration context. + """ + import reflex as rx + + assert clean_registration_context.app is None + app = rx.App() + assert clean_registration_context.app is app + assert clean_registration_context._app is app + + +def test_second_app_in_same_context_raises( + clean_registration_context: RegistrationContext, +): + """A second rx.App() in the same context raises ReflexRuntimeError. + + Args: + clean_registration_context: A fresh, empty registration context. + """ + import reflex as rx + + rx.App() + with pytest.raises(ReflexRuntimeError, match=r"\.fork\(\)"): + rx.App() + + +def test_fork_clears_app_and_preserves_registrations( + clean_registration_context: RegistrationContext, +): + """fork() returns a new context with _app=None but the same registrations. + + Args: + clean_registration_context: A fresh, empty registration context. + """ + import reflex as rx + from reflex.event import EventHandler + from reflex.state import BaseState + + class ForkState(BaseState): + x: int = 0 + + async def _fn(): + pass + + handler = EventHandler(fn=_fn) + RegistrationContext.register_event_handler(handler) + app = rx.App() + assert clean_registration_context.app is app + + forked = clean_registration_context.fork() + assert forked is not clean_registration_context + assert forked.app is None + assert forked.event_handlers == clean_registration_context.event_handlers + assert forked.event_handlers is not clean_registration_context.event_handlers + assert forked.base_states == clean_registration_context.base_states + assert forked.base_states is not clean_registration_context.base_states + + +def test_fork_allows_new_app(clean_registration_context: RegistrationContext): + """A forked context permits a new App to be instantiated. + + Args: + clean_registration_context: A fresh, empty registration context. + """ + import reflex as rx + + rx.App() + forked = clean_registration_context.fork() + token = RegistrationContext.set(forked) + try: + new_app = rx.App() + finally: + RegistrationContext.reset(token) + assert forked.app is new_app + assert clean_registration_context.app is not new_app + + def test_bundled_libraries_isolated_between_contexts(): """bundle_library appends to the current context only.""" from reflex_base.components.dynamic import bundle_library diff --git a/tests/units/test_app.py b/tests/units/test_app.py index 31810b19948..7985a95ed07 100644 --- a/tests/units/test_app.py +++ b/tests/units/test_app.py @@ -2262,7 +2262,8 @@ def test_app_state_determination(): a1 = App() assert a1._state is not None - a2 = App(enable_state=False) + with RegistrationContext.get().fork(): + a2 = App(enable_state=False) assert a2._state is None From 6648d1779fe1aa0209c68066f24e407774b343a5 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Mon, 27 Apr 2026 15:01:52 -0700 Subject: [PATCH 4/7] Use RegistrationContext.app instead of get_app internally --- .../src/reflex_base/components/component.py | 16 +++----- .../src/reflex_base/components/dynamic.py | 4 +- .../reflex-base/src/reflex_base/config.py | 5 +-- .../src/reflex_base/event/__init__.py | 3 +- .../reflex_components_core/core/_upload.py | 2 +- pyi_hashes.json | 2 +- reflex/compiler/compiler.py | 3 +- reflex/compiler/utils.py | 12 +++--- reflex/experimental/memo.py | 5 +-- reflex/istate/shared.py | 11 ++++-- reflex/page.py | 3 +- reflex/state.py | 37 ++++--------------- reflex/utils/prerequisites.py | 3 +- tests/integration/test_lifespan.py | 14 ++++--- tests/units/experimental/test_memo.py | 16 +++----- .../processor/test_base_state_processor.py | 1 - tests/units/test_app.py | 4 +- tests/units/test_state.py | 2 - 18 files changed, 53 insertions(+), 90 deletions(-) diff --git a/packages/reflex-base/src/reflex_base/components/component.py b/packages/reflex-base/src/reflex_base/components/component.py index f88d488206f..a043d790d21 100644 --- a/packages/reflex-base/src/reflex_base/components/component.py +++ b/packages/reflex-base/src/reflex_base/components/component.py @@ -49,6 +49,7 @@ run_script, unwrap_var_annotation, ) +from reflex_base.registry import RegistrationContext from reflex_base.style import Style, format_as_emotion from reflex_base.utils import console, format, imports, types from reflex_base.utils.imports import ImportDict, ImportVar, ParsedImportDict @@ -2207,12 +2208,11 @@ def get_component(self) -> Component: """ component = self.component_fn(*self.get_prop_vars()) - try: - from reflex.utils.prerequisites import get_and_validate_app - - style = get_and_validate_app().app.style - except Exception: - style = {} + style = ( + app.style + if (app := RegistrationContext.ensure_context().app) is not None + else {} + ) component._add_style_recursive(style) return component @@ -2250,8 +2250,6 @@ def _register_custom_component( Raises: TypeError: If the tag name cannot be determined. """ - from reflex_base.registry import RegistrationContext - dummy_props = { prop: ( Var( @@ -2427,8 +2425,6 @@ def create(cls, component: Component) -> StatefulComponent | None: """ from reflex_components_core.core.foreach import Foreach - from reflex_base.registry import RegistrationContext - if component._memoization_mode.disposition == MemoizationDisposition.NEVER: # Never memoize this component. return None diff --git a/packages/reflex-base/src/reflex_base/components/dynamic.py b/packages/reflex-base/src/reflex_base/components/dynamic.py index fa709d5db47..0a0d9d6baf6 100644 --- a/packages/reflex-base/src/reflex_base/components/dynamic.py +++ b/packages/reflex-base/src/reflex_base/components/dynamic.py @@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, Union from reflex_base import constants +from reflex_base.registry import RegistrationContext from reflex_base.utils import imports from reflex_base.utils.exceptions import DynamicComponentMissingLibraryError from reflex_base.utils.format import format_library_name @@ -35,8 +36,6 @@ def bundle_library(component: Union["Component", str]): Raises: DynamicComponentMissingLibraryError: Raised when a dynamic component is missing a library. """ - from reflex_base.registry import RegistrationContext - bundled = RegistrationContext.ensure_context().bundled_libraries if isinstance(component, str): bundled.append(component) @@ -66,7 +65,6 @@ def make_component(component: Component) -> str: from reflex_components_core.base.bare import Bare from reflex.compiler import compiler, templates, utils - from reflex_base.registry import RegistrationContext libs_in_window = RegistrationContext.ensure_context().bundled_libraries diff --git a/packages/reflex-base/src/reflex_base/config.py b/packages/reflex-base/src/reflex_base/config.py index 584fb382938..3097913b5a9 100644 --- a/packages/reflex-base/src/reflex_base/config.py +++ b/packages/reflex-base/src/reflex_base/config.py @@ -27,6 +27,7 @@ from reflex_base.environment import environment as environment from reflex_base.plugins import Plugin from reflex_base.plugins.sitemap import SitemapPlugin +from reflex_base.registry import RegistrationContext from reflex_base.utils import console from reflex_base.utils.exceptions import ConfigError @@ -696,8 +697,6 @@ def get_config() -> Config: Returns: The app config. """ - from reflex_base.registry import RegistrationContext - ctx = RegistrationContext.ensure_context() config = ctx.config if config is None: @@ -715,8 +714,6 @@ def reload_config() -> Config: Returns: The freshly loaded app config. """ - from reflex_base.registry import RegistrationContext - ctx = RegistrationContext.ensure_context() config = _load_config() ctx._set_config(config) diff --git a/packages/reflex-base/src/reflex_base/event/__init__.py b/packages/reflex-base/src/reflex_base/event/__init__.py index 8762e694d18..4cd438a4af1 100644 --- a/packages/reflex-base/src/reflex_base/event/__init__.py +++ b/packages/reflex-base/src/reflex_base/event/__init__.py @@ -28,6 +28,7 @@ from reflex_base import constants from reflex_base.components.field import BaseField from reflex_base.constants.compiler import CompileVars, Hooks, Imports +from reflex_base.registry import RegistrationContext from reflex_base.utils import format from reflex_base.utils.decorator import once from reflex_base.utils.exceptions import ( @@ -81,8 +82,6 @@ class Event: @property def state_cls(self) -> "type[BaseState]": """The state class for the event.""" - from reflex_base.registry import RegistrationContext - substate_name = self.name.rpartition(".")[0] return RegistrationContext.get().base_states[substate_name] diff --git a/packages/reflex-components-core/src/reflex_components_core/core/_upload.py b/packages/reflex-components-core/src/reflex_components_core/core/_upload.py index 0674e071fbd..037bc665bc1 100644 --- a/packages/reflex-components-core/src/reflex_components_core/core/_upload.py +++ b/packages/reflex-components-core/src/reflex_components_core/core/_upload.py @@ -11,6 +11,7 @@ from typing import TYPE_CHECKING, Any, BinaryIO, cast from python_multipart.multipart import MultipartParser, parse_options_header +from reflex_base.registry import RegistrationContext from reflex_base.utils import exceptions from reflex_base.utils.format import json_dumps from reflex_base.utils.streaming_response import DisconnectAwareStreamingResponse @@ -661,7 +662,6 @@ async def upload_file(request: Request): resolve_upload_chunk_handler_param, resolve_upload_handler_param, ) - from reflex_base.registry import RegistrationContext token, handler_name = _require_upload_headers(request) registered_event_handler = RegistrationContext.get().event_handlers[ diff --git a/pyi_hashes.json b/pyi_hashes.json index 4e5b8af7a20..c9f18d1ba7f 100644 --- a/pyi_hashes.json +++ b/pyi_hashes.json @@ -120,5 +120,5 @@ "packages/reflex-components-sonner/src/reflex_components_sonner/toast.pyi": "2c5fadcc014056f041cd4d916137d9e7", "reflex/__init__.pyi": "3a9bb8544cbc338ffaf0a5927d9156df", "reflex/components/__init__.pyi": "f39a2af77f438fa243c58c965f19d42e", - "reflex/experimental/memo.pyi": "69479b0af7bd47b642337d116a19b1f8" + "reflex/experimental/memo.pyi": "f1f2df654212f1dad6e81fa2071ed12b" } diff --git a/reflex/compiler/compiler.py b/reflex/compiler/compiler.py index a05e8acc04e..fcd209a8f14 100644 --- a/reflex/compiler/compiler.py +++ b/reflex/compiler/compiler.py @@ -20,6 +20,7 @@ from reflex_base.constants.compiler import PageNames, ResetStylesheet from reflex_base.constants.state import FIELD_MARKER from reflex_base.environment import environment +from reflex_base.registry import RegistrationContext from reflex_base.style import SYSTEM_COLOR_MODE from reflex_base.utils.exceptions import ReflexError from reflex_base.utils.format import to_title_case @@ -88,8 +89,6 @@ def _compile_app(app_root: Component) -> str: Returns: The compiled app. """ - from reflex_base.registry import RegistrationContext - window_libraries = [ (_normalize_library_name(name), name) for name in RegistrationContext.ensure_context().bundled_libraries diff --git a/reflex/compiler/utils.py b/reflex/compiler/utils.py index bb812c67c5f..c0eef445725 100644 --- a/reflex/compiler/utils.py +++ b/reflex/compiler/utils.py @@ -16,6 +16,7 @@ from reflex_base import constants from reflex_base.components.component import Component, ComponentStyle, CustomComponent from reflex_base.constants.state import CAMEL_CASE_MEMO_MARKER, FIELD_MARKER +from reflex_base.registry import RegistrationContext from reflex_base.style import Style from reflex_base.utils import format, imports from reflex_base.utils.imports import ImportVar, ParsedImportDict @@ -373,12 +374,11 @@ def _apply_component_style_for_compile(component: Component) -> Component: Returns: The styled component tree. """ - try: - from reflex.utils.prerequisites import get_and_validate_app - - style = get_and_validate_app().app.style - except Exception: - style = {} + style = ( + app.style + if (app := RegistrationContext.ensure_context().app) is not None + else {} + ) component._add_style_recursive(style) return component diff --git a/reflex/experimental/memo.py b/reflex/experimental/memo.py index 8b583829f9d..772e9cd8dcf 100644 --- a/reflex/experimental/memo.py +++ b/reflex/experimental/memo.py @@ -12,6 +12,7 @@ from reflex_base.components.component import Component from reflex_base.constants.compiler import SpecialAttributes from reflex_base.constants.state import CAMEL_CASE_MEMO_MARKER +from reflex_base.registry import RegistrationContext from reflex_base.utils import format from reflex_base.utils.imports import ImportVar from reflex_base.utils.types import safe_issubclass @@ -181,8 +182,6 @@ def _register_memo_definition(definition: ExperimentalMemoDefinition) -> None: Raises: ValueError: If another memo already compiles to the same exported name. """ - from reflex_base.registry import RegistrationContext - memos = RegistrationContext.ensure_context().memo_definitions key = _memo_registry_key(definition) if (existing := memos.get(key)) is not None and ( @@ -354,8 +353,6 @@ def _validate_var_return_expr(return_expr: Var, func_name: str) -> None: ) raise TypeError(msg) - from reflex_base.registry import RegistrationContext - bundled = RegistrationContext.ensure_context().bundled_libraries for lib in dict(var_data.imports): if not lib: diff --git a/reflex/istate/shared.py b/reflex/istate/shared.py index 41b1f519cd3..37e753987b0 100644 --- a/reflex/istate/shared.py +++ b/reflex/istate/shared.py @@ -7,6 +7,7 @@ from reflex_base.constants import ROUTER_DATA from reflex_base.event import Event, get_hydrate_event +from reflex_base.registry import RegistrationContext from reflex_base.utils import console from reflex_base.utils.exceptions import ReflexRuntimeError from typing_extensions import Self @@ -49,9 +50,9 @@ def _do_update_other_tokens( Returns: The list of asyncio tasks created to perform the updates. """ - from reflex.utils.prerequisites import get_app - - app = get_app().app + if (app := RegistrationContext.get().app) is None: + msg = "No App is registered with the active RegistrationContext." + raise RuntimeError(msg) async def _update_client(token: str): async with app.modify_state( @@ -61,9 +62,11 @@ async def _update_client(token: str): pass tasks = [] + if (event_namespace := app.event_namespace) is None: + return tasks for affected_token in affected_tokens: # Don't send updates for disconnected clients. - if affected_token not in app.event_namespace._token_manager.token_to_socket: + if affected_token not in event_namespace._token_manager.token_to_socket: continue # TODO: remove disconnected clients after some time. t = asyncio.create_task(_update_client(affected_token)) diff --git a/reflex/page.py b/reflex/page.py index 637073ea560..5b636bc0598 100644 --- a/reflex/page.py +++ b/reflex/page.py @@ -5,6 +5,8 @@ import sys from typing import TYPE_CHECKING +from reflex_base.registry import RegistrationContext + if TYPE_CHECKING: from collections.abc import Callable from typing import Any @@ -42,7 +44,6 @@ def page( Returns: The decorated function. """ - from reflex_base.registry import RegistrationContext def decorator(render_fn: Callable): kwargs: dict[str, Any] = {} diff --git a/reflex/state.py b/reflex/state.py index c09d68658f8..9f99983d4ac 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -37,6 +37,7 @@ EventSpec, call_script, ) +from reflex_base.registry import RegistrationContext from reflex_base.utils.exceptions import ( ComputedVarShadowsBaseVarsError, ComputedVarShadowsStateVarError, @@ -74,7 +75,7 @@ from reflex.istate.proxy import ImmutableMutableProxy as ImmutableMutableProxy from reflex.istate.proxy import MutableProxy, is_mutable_type from reflex.istate.storage import ClientStorageBase -from reflex.utils import console, format, prerequisites, types +from reflex.utils import console, format, types from reflex.utils.exec import is_testing_env if TYPE_CHECKING: @@ -489,7 +490,6 @@ def __init_subclass__(cls, mixin: bool = False, **kwargs): Raises: StateValueError: If a substate class shadows another. """ - from reflex_base.registry import RegistrationContext from reflex_base.utils.exceptions import StateValueError super().__init_subclass__(**kwargs) @@ -953,8 +953,6 @@ def get_substates(cls) -> set[type[BaseState]]: Returns: The substates of the state. """ - from reflex_base.registry import RegistrationContext - return RegistrationContext.get().get_substates(cls) @classmethod @@ -1138,8 +1136,6 @@ def _create_event_handler( Returns: The event handler. """ - from reflex_base.registry import RegistrationContext - # Check if function has stored event_actions from decorator event_actions = getattr(fn, EVENT_ACTIONS_MARKER, {}) @@ -2291,9 +2287,8 @@ def handle_frontend_exception( "window.location.reload();" "}" ) - prerequisites.get_and_validate_app().app.frontend_exception_handler( - Exception(info) - ) + if (app := RegistrationContext.get().app) is not None: + app.frontend_exception_handler(Exception(info)) class UpdateVarsInternalState(State): @@ -2326,9 +2321,6 @@ class OnLoadInternalState(State): This is a separate substate to avoid deserializing the entire state tree for every page navigation. """ - # Cannot properly annotate this as `App` due to circular import issues. - _app_ref: ClassVar[Any] = None - def on_load_internal(self) -> list[Event | EventSpec | event.EventCallback] | None: """Queue on_load handlers for the current page. @@ -2336,19 +2328,11 @@ def on_load_internal(self) -> list[Event | EventSpec | event.EventCallback] | No The list of events to queue for on load handling. Raises: - TypeError: If the app reference is not of type App. + RuntimeError: If no App is registered with the active RegistrationContext. """ - from reflex.app import App - - app = type(self)._app_ref or prerequisites.get_and_validate_app().app - if not isinstance(app, App): - msg = ( - f"Expected app to be of type {App.__name__}, got {type(app).__name__}." - ) - raise TypeError(msg) - # Cache the app reference for subsequent calls. - if type(self)._app_ref is None: - type(self)._app_ref = app + if (app := RegistrationContext.get().app) is None: + msg = "No App is registered with the active RegistrationContext." + raise RuntimeError(msg) load_events = app.get_load_events(self.router.url.path) if not load_events: self.is_hydrated = True @@ -2539,11 +2523,6 @@ def reload_state_module( state: Recursive argument for the state class to reload. """ - from reflex_base.registry import RegistrationContext - - # Reset the _app_ref of OnLoadInternalState to avoid stale references. - if state is OnLoadInternalState: - state._app_ref = None # Clean out all potentially dirty states of reloaded modules. for pd_state in tuple(state._potentially_dirty_states): with contextlib.suppress(ValueError): diff --git a/reflex/utils/prerequisites.py b/reflex/utils/prerequisites.py index eebc310f1b7..83aa3226724 100644 --- a/reflex/utils/prerequisites.py +++ b/reflex/utils/prerequisites.py @@ -22,6 +22,7 @@ from reflex_base.config import Config, get_config from reflex_base.constants.base import RunningMode from reflex_base.environment import environment +from reflex_base.registry import RegistrationContext from reflex_base.utils.decorator import once from reflex import model @@ -199,8 +200,6 @@ def get_app(reload: bool = False) -> ModuleType: else config.app_module ) if reload: - from reflex_base.registry import RegistrationContext - from reflex.state import reload_state_module # Reset rx.State subclasses to avoid conflict when reloading. diff --git a/tests/integration/test_lifespan.py b/tests/integration/test_lifespan.py index 1f3185db86c..6b99f0b212e 100644 --- a/tests/integration/test_lifespan.py +++ b/tests/integration/test_lifespan.py @@ -25,6 +25,8 @@ def LifespanApp( import asyncio from contextlib import asynccontextmanager + from reflex_base.registry import RegistrationContext + import reflex as rx from reflex.istate.manager.token import BaseStateToken @@ -69,9 +71,9 @@ async def raw_asyncio_task_coro(): @asynccontextmanager async def assert_register_blocked_during_lifespan(app): """Negative test: registering a task after lifespan has started must raise.""" - from reflex.utils.prerequisites import get_app - - reflex_app = get_app().app + if (reflex_app := RegistrationContext.get().app) is None: + msg = "No App is registered with the active RegistrationContext." + raise RuntimeError(msg) task = asyncio.create_task(raw_asyncio_task_coro(), name="raw_asyncio_task") try: reflex_app.register_lifespan_task(task) @@ -113,9 +115,9 @@ def tick(self, date): pass async def modify_state_task(): - from reflex.utils.prerequisites import get_app - - reflex_app = get_app().app + if (reflex_app := RegistrationContext.get().app) is None: + msg = "No App is registered with the active RegistrationContext." + raise RuntimeError(msg) try: while True: for token in list(connected_tokens): diff --git a/tests/units/experimental/test_memo.py b/tests/units/experimental/test_memo.py index 7988d8c1c00..eb82fff7cfe 100644 --- a/tests/units/experimental/test_memo.py +++ b/tests/units/experimental/test_memo.py @@ -401,9 +401,7 @@ def wrapper() -> rx.Component: assert "inner" in imports -def test_compile_experimental_component_memo_does_not_mutate_definition( - monkeypatch: pytest.MonkeyPatch, -): +def test_compile_experimental_component_memo_does_not_mutate_definition(): """Experimental component memo compilation should not mutate stored components.""" @rx._x.memo @@ -414,13 +412,11 @@ def wrapper() -> rx.Component: assert isinstance(definition, ExperimentalMemoComponentDefinition) assert definition.component.style == Style() - monkeypatch.setattr( - "reflex.utils.prerequisites.get_and_validate_app", - lambda: SimpleNamespace( - app=SimpleNamespace( - style={type(definition.component): Style({"color": "red"})} - ) - ), + ctx = RegistrationContext.ensure_context() + object.__setattr__( + ctx, + "_app", + SimpleNamespace(style={type(definition.component): Style({"color": "red"})}), ) render, _ = compiler_utils.compile_experimental_component_memo(definition) diff --git a/tests/units/reflex_base/event/processor/test_base_state_processor.py b/tests/units/reflex_base/event/processor/test_base_state_processor.py index e414369a40e..a9e8b22b122 100644 --- a/tests/units/reflex_base/event/processor/test_base_state_processor.py +++ b/tests/units/reflex_base/event/processor/test_base_state_processor.py @@ -128,7 +128,6 @@ class MyState(State): def noop(self): pass - OnLoadInternalState._app_ref = None app = app_module_mock.app = App() assert real_base_state_processor._root_context is not None app._state_manager = real_base_state_processor._root_context.state_manager diff --git a/tests/units/test_app.py b/tests/units/test_app.py index 7985a95ed07..f4e72079e30 100644 --- a/tests/units/test_app.py +++ b/tests/units/test_app.py @@ -28,6 +28,7 @@ from reflex_base.vars.base import computed_var from reflex_components_core.base.bare import Bare from reflex_components_core.base.fragment import Fragment +from reflex_components_core.core._upload import upload from reflex_components_radix.themes.typography.text import Text from starlette.applications import Starlette from starlette.datastructures import FormData, Headers, UploadFile @@ -37,7 +38,7 @@ import reflex as rx from reflex import AdminDash, constants -from reflex.app import App, ComponentCallable, upload +from reflex.app import App, ComponentCallable from reflex.environment import environment from reflex.istate.manager.disk import StateManagerDisk from reflex.istate.manager.memory import StateManagerMemory @@ -1908,7 +1909,6 @@ async def test_dynamic_route_var_route_change_completed_on_load( emitted_deltas: List to store emitted deltas. emitted_events: List to store emitted events. """ - OnLoadInternalState._app_ref = None arg_name = "dynamic" route = f"test/[{arg_name}]" app = app_module_mock.app = App() diff --git a/tests/units/test_state.py b/tests/units/test_state.py index 678caeaafe4..7b6d7f1a42d 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -3186,7 +3186,6 @@ async def test_preprocess( mock_base_state_event_processor: The event processor. emitted_deltas: List to capture emitted deltas. """ - OnLoadInternalState._app_ref = None app = app_module_mock.app = App(_state=State) app._state_manager = mock_root_event_context.state_manager @@ -3250,7 +3249,6 @@ async def test_preprocess_multiple_load_events( mock_base_state_event_processor: The event processor. emitted_deltas: List to capture emitted deltas. """ - OnLoadInternalState._app_ref = None app = app_module_mock.app = App(_state=State) app._state_manager = mock_root_event_context.state_manager From 5836732b168c90ca10cfe79a5b6be466c2dcf804 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Mon, 27 Apr 2026 15:30:14 -0700 Subject: [PATCH 5/7] RegistrationContext properties app and config do not return None Accessing these properties raises a ReflexRuntimeError if the value was None at the time it is accessed. This cleans up LBYL None checking wherever these value are accessed. `RegistrationContext.fork` now does NOT copy the Config, allowing both Config and App to be loaded/reloaded in a forked registration context. --- .../src/reflex_base/components/component.py | 2 +- .../reflex-base/src/reflex_base/config.py | 8 ++-- .../reflex-base/src/reflex_base/registry.py | 37 +++++++++++++++---- reflex/compiler/utils.py | 2 +- reflex/istate/shared.py | 4 +- reflex/state.py | 13 ++----- tests/integration/test_lifespan.py | 8 +--- tests/units/reflex_base/test_registry.py | 14 ++++--- 8 files changed, 50 insertions(+), 38 deletions(-) diff --git a/packages/reflex-base/src/reflex_base/components/component.py b/packages/reflex-base/src/reflex_base/components/component.py index a043d790d21..d5016737998 100644 --- a/packages/reflex-base/src/reflex_base/components/component.py +++ b/packages/reflex-base/src/reflex_base/components/component.py @@ -2210,7 +2210,7 @@ def get_component(self) -> Component: style = ( app.style - if (app := RegistrationContext.ensure_context().app) is not None + if (app := RegistrationContext.ensure_context()._app) is not None else {} ) diff --git a/packages/reflex-base/src/reflex_base/config.py b/packages/reflex-base/src/reflex_base/config.py index 3097913b5a9..01a127a9beb 100644 --- a/packages/reflex-base/src/reflex_base/config.py +++ b/packages/reflex-base/src/reflex_base/config.py @@ -698,11 +698,9 @@ def get_config() -> Config: The app config. """ ctx = RegistrationContext.ensure_context() - config = ctx.config - if config is None: - config = _load_config() - ctx._set_config(config) - return config + if ctx._config is None: + ctx._set_config(_load_config()) + return ctx.config def reload_config() -> Config: diff --git a/packages/reflex-base/src/reflex_base/registry.py b/packages/reflex-base/src/reflex_base/registry.py index 7f868e47c3f..b630637c64b 100644 --- a/packages/reflex-base/src/reflex_base/registry.py +++ b/packages/reflex-base/src/reflex_base/registry.py @@ -66,7 +66,7 @@ class RegistrationContext(BaseContext): default_factory=dict, repr=False, ) - config: Config | None = dataclasses.field(default=None, repr=False) + _config: Config | None = dataclasses.field(default=None, repr=False) decorated_pages: list[tuple[Callable, dict[str, Any]]] = dataclasses.field( default_factory=list, repr=False, @@ -86,14 +86,35 @@ class RegistrationContext(BaseContext): _app: App | None = dataclasses.field(default=None, repr=False) @property - def app(self) -> App | None: + def app(self) -> App: """Get the App instance associated with this context. Returns: - The App instance, or None if no app has been registered. + The App instance. + + Raises: + ReflexRuntimeError: If no App has been registered with this context. """ + if self._app is None: + msg = "No App is registered with the active RegistrationContext." + raise ReflexRuntimeError(msg) return self._app + @property + def config(self) -> Config: + """Get the Config associated with this context. + + Returns: + The Config instance. + + Raises: + ReflexRuntimeError: If no Config has been loaded for this context. + """ + if self._config is None: + msg = "No Config has been loaded for the active RegistrationContext." + raise ReflexRuntimeError(msg) + return self._config + def _set_app(self, app: App) -> None: """Associate an App instance with this context. @@ -115,14 +136,15 @@ def _set_app(self, app: App) -> None: object.__setattr__(self, "_app", app) def fork(self) -> Self: - """Create a copy of this context with `_app` reset to None. + """Create a copy of this context with `_app` and `_config` reset to None. Existing registrations (event handlers, base states, decorated pages, etc.) are shallow-copied so the fork can evolve independently while preserving - already-registered classes. + already-registered classes. The next call to `get_config()` on the fork + will reload `rxconfig.py` from disk. Returns: - A new RegistrationContext with the same registrations but no app. + A new RegistrationContext with the same registrations but no app or config. """ return type(self)( event_handlers=dict(self.event_handlers), @@ -131,7 +153,6 @@ def fork(self) -> Self: k: set(v) for k, v in self.base_state_substates.items() }, tag_to_stateful_component=dict(self.tag_to_stateful_component), - config=self.config, decorated_pages=list(self.decorated_pages), custom_components=dict(self.custom_components), memo_definitions=dict(self.memo_definitions), @@ -144,7 +165,7 @@ def _set_config(self, config: Config) -> None: Args: config: The config to associate with this context. """ - object.__setattr__(self, "config", config) + object.__setattr__(self, "_config", config) @classmethod def ensure_context(cls) -> Self: diff --git a/reflex/compiler/utils.py b/reflex/compiler/utils.py index c0eef445725..d52f6f22d70 100644 --- a/reflex/compiler/utils.py +++ b/reflex/compiler/utils.py @@ -376,7 +376,7 @@ def _apply_component_style_for_compile(component: Component) -> Component: """ style = ( app.style - if (app := RegistrationContext.ensure_context().app) is not None + if (app := RegistrationContext.ensure_context()._app) is not None else {} ) diff --git a/reflex/istate/shared.py b/reflex/istate/shared.py index 37e753987b0..ce099017d96 100644 --- a/reflex/istate/shared.py +++ b/reflex/istate/shared.py @@ -50,9 +50,7 @@ def _do_update_other_tokens( Returns: The list of asyncio tasks created to perform the updates. """ - if (app := RegistrationContext.get().app) is None: - msg = "No App is registered with the active RegistrationContext." - raise RuntimeError(msg) + app = RegistrationContext.get().app async def _update_client(token: str): async with app.modify_state( diff --git a/reflex/state.py b/reflex/state.py index 9f99983d4ac..845b21dfb91 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -2287,8 +2287,7 @@ def handle_frontend_exception( "window.location.reload();" "}" ) - if (app := RegistrationContext.get().app) is not None: - app.frontend_exception_handler(Exception(info)) + RegistrationContext.get().app.frontend_exception_handler(Exception(info)) class UpdateVarsInternalState(State): @@ -2326,14 +2325,10 @@ def on_load_internal(self) -> list[Event | EventSpec | event.EventCallback] | No Returns: The list of events to queue for on load handling. - - Raises: - RuntimeError: If no App is registered with the active RegistrationContext. """ - if (app := RegistrationContext.get().app) is None: - msg = "No App is registered with the active RegistrationContext." - raise RuntimeError(msg) - load_events = app.get_load_events(self.router.url.path) + load_events = RegistrationContext.get().app.get_load_events( + self.router.url.path + ) if not load_events: self.is_hydrated = True return None # Fast path for navigation with no on_load events defined. diff --git a/tests/integration/test_lifespan.py b/tests/integration/test_lifespan.py index 6b99f0b212e..6adac44175c 100644 --- a/tests/integration/test_lifespan.py +++ b/tests/integration/test_lifespan.py @@ -71,9 +71,7 @@ async def raw_asyncio_task_coro(): @asynccontextmanager async def assert_register_blocked_during_lifespan(app): """Negative test: registering a task after lifespan has started must raise.""" - if (reflex_app := RegistrationContext.get().app) is None: - msg = "No App is registered with the active RegistrationContext." - raise RuntimeError(msg) + reflex_app = RegistrationContext.get().app task = asyncio.create_task(raw_asyncio_task_coro(), name="raw_asyncio_task") try: reflex_app.register_lifespan_task(task) @@ -115,9 +113,7 @@ def tick(self, date): pass async def modify_state_task(): - if (reflex_app := RegistrationContext.get().app) is None: - msg = "No App is registered with the active RegistrationContext." - raise RuntimeError(msg) + reflex_app = RegistrationContext.get().app try: while True: for token in list(connected_tokens): diff --git a/tests/units/reflex_base/test_registry.py b/tests/units/reflex_base/test_registry.py index 0eb2cc14571..059f379dab4 100644 --- a/tests/units/reflex_base/test_registry.py +++ b/tests/units/reflex_base/test_registry.py @@ -139,12 +139,14 @@ async def _tmp(): def test_clean_context_has_no_config(clean_registration_context: RegistrationContext): - """A fresh RegistrationContext starts with config=None. + """A fresh RegistrationContext has no cached config and raises on access. Args: clean_registration_context: A fresh, empty registration context. """ - assert clean_registration_context.config is None + assert clean_registration_context._config is None + with pytest.raises(ReflexRuntimeError, match="No Config"): + clean_registration_context.config def _write_rxconfig(path, app_name: str) -> None: @@ -169,7 +171,7 @@ def test_get_config_caches_on_context( """ _write_rxconfig(tmp_path, "ctx_app") with chdir(tmp_path): - assert clean_registration_context.config is None + assert clean_registration_context._config is None first = get_config() assert first is clean_registration_context.config assert first.app_name == "ctx_app" @@ -307,7 +309,9 @@ def test_app_registers_on_instantiation( """ import reflex as rx - assert clean_registration_context.app is None + assert clean_registration_context._app is None + with pytest.raises(ReflexRuntimeError, match="No App"): + clean_registration_context.app app = rx.App() assert clean_registration_context.app is app assert clean_registration_context._app is app @@ -353,7 +357,7 @@ async def _fn(): forked = clean_registration_context.fork() assert forked is not clean_registration_context - assert forked.app is None + assert forked._app is None assert forked.event_handlers == clean_registration_context.event_handlers assert forked.event_handlers is not clean_registration_context.event_handlers assert forked.base_states == clean_registration_context.base_states From 87cb874547d3f433d9b851d703baa772bee759f1 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Mon, 27 Apr 2026 16:11:36 -0700 Subject: [PATCH 6/7] When reloading the app, clear the RegistrationContext._app reference Otherwise the second load causes the RegistrationContext to raise an exception, because _app is already set from the initial load. --- reflex/utils/prerequisites.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/reflex/utils/prerequisites.py b/reflex/utils/prerequisites.py index 83aa3226724..df2c223b129 100644 --- a/reflex/utils/prerequisites.py +++ b/reflex/utils/prerequisites.py @@ -205,7 +205,9 @@ def get_app(reload: bool = False) -> ModuleType: # Reset rx.State subclasses to avoid conflict when reloading. reload_state_module(module=module) - RegistrationContext.ensure_context().decorated_pages.clear() + reg_ctx = RegistrationContext.ensure_context() + reg_ctx.decorated_pages.clear() + object.__setattr__(reg_ctx, "_app", None) # Reload the app module. importlib.reload(app) From 199d6c16d8c8c46633d22385e4436d61f267ea0f Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Mon, 27 Apr 2026 17:47:30 -0700 Subject: [PATCH 7/7] restore helpful comments --- packages/reflex-base/src/reflex_base/config.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/packages/reflex-base/src/reflex_base/config.py b/packages/reflex-base/src/reflex_base/config.py index 01a127a9beb..31d6d0d1a22 100644 --- a/packages/reflex-base/src/reflex_base/config.py +++ b/packages/reflex-base/src/reflex_base/config.py @@ -676,12 +676,15 @@ def _load_config() -> Config: try: return _get_config() except Exception: + # If the module import fails, try to import with the original sys.path. sys.path.extend(orig_sys_path) return _get_config() finally: + # Find any entries added to sys.path by rxconfig.py itself. extra_paths = [ p for p in sys.path if p not in orig_sys_path and p != str(Path.cwd()) ] + # Restore the original sys.path. sys.path.clear() sys.path.extend(extra_paths) sys.path.extend(orig_sys_path)