Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENG-4134]Allow specifying custom app module in rxconfig #4556

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
22 changes: 21 additions & 1 deletion .github/workflows/integration_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ env:
PR_TITLE: ${{ github.event.pull_request.title }}

jobs:
example-counter:
example-counter-and-nba-proxy:
env:
OUTPUT_FILE: import_benchmark.json
timeout-minutes: 30
Expand Down Expand Up @@ -119,6 +119,26 @@ jobs:
--benchmark-json "./reflex-examples/counter/${{ env.OUTPUT_FILE }}"
--branch-name "${{ github.head_ref || github.ref_name }}" --pr-id "${{ github.event.pull_request.id }}"
--app-name "counter"
- name: Install requirements for nba proxy example
working-directory: ./reflex-examples/nba-proxy
run: |
poetry run uv pip install -r requirements.txt
- name: Install additional dependencies for DB access
run: poetry run uv pip install psycopg
- name: Check export --backend-only before init for nba-proxy example
working-directory: ./reflex-examples/nba-proxy
run: |
poetry run reflex export --backend-only
- name: Init Website for nba-proxy example
working-directory: ./reflex-examples/nba-proxy
run: |
poetry run reflex init --loglevel debug
- name: Run Website and Check for errors
run: |
# Check that npm is home
npm -v
poetry run bash scripts/integration.sh ./reflex-examples/nba-proxy dev


reflex-web:
strategy:
Expand Down
7 changes: 3 additions & 4 deletions reflex/app_module_for_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,13 @@
from reflex import constants
from reflex.utils import telemetry
from reflex.utils.exec import is_prod_mode
from reflex.utils.prerequisites import get_app
from reflex.utils.prerequisites import get_and_validate_app

if constants.CompileVars.APP != "app":
raise AssertionError("unexpected variable name for 'app'")

telemetry.send("compile")
app_module = get_app(reload=False)
app = getattr(app_module, constants.CompileVars.APP)
app, app_module = get_and_validate_app(reload=False)
# For py3.9 compatibility when redis is used, we MUST add any decorator pages
# before compiling the app in a thread to avoid event loop error (REF-2172).
app._apply_decorated_pages()
Expand All @@ -30,7 +29,7 @@
# ensure only "app" is exposed.
del app_module
del compile_future
del get_app
del get_and_validate_app
del is_prod_mode
del telemetry
del constants
Expand Down
45 changes: 44 additions & 1 deletion reflex/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@
import sys
import threading
import urllib.parse
from importlib.util import find_spec
from importlib.util import find_spec, module_from_spec, spec_from_file_location
from pathlib import Path
from types import ModuleType
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -604,6 +605,9 @@ class Config:
# The name of the app (should match the name of the app directory).
app_name: str

# The path to the app module.
app_module_path: Optional[str] = None

# The log level to use.
loglevel: constants.LogLevel = constants.LogLevel.DEFAULT

Expand Down Expand Up @@ -726,13 +730,52 @@ def __init__(self, *args, **kwargs):
"REDIS_URL is required when using the redis state manager."
)

@staticmethod
def _load_via_spec(path: str) -> ModuleType:
"""Load a module dynamically using its file path.

Args:
path: The path to the module.

Returns:
The loaded module.

Raises:
ConfigError: If the module cannot be loaded.
"""
module_name = Path(path).stem
module_path = Path(path).resolve()
sys.path.insert(0, str(module_path.parent.parent))
Copy link
Collaborator

@masenf masenf Jan 15, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i don't like the assumption that the pythonpath entry is module_path.parent.parent.

i think a better approach would make the caller responsible for setting their sys.path / PYTHONPATH, and then providing an importable name.

That would save us this complicated import machinery that makes assumptions that would be hard to undo later, we could just use __import__, like we have been.

spec = spec_from_file_location(module_name, module_path)
if not spec:
raise ConfigError(f"Could not load module from path: {module_path}")
module = module_from_spec(spec)
# Set the package name to the parent directory of the module (for relative imports)
module.__package__ = module_path.parent.name
spec.loader.exec_module(module) # type: ignore
return module

@property
def app_module(self) -> ModuleType | None:
"""Return the app module if `app_module_path` is set.

Returns:
The app module.
"""
return (
self._load_via_spec(self.app_module_path) if self.app_module_path else None
)

@property
def module(self) -> str:
"""Get the module name of the app.

Returns:
The module name.
"""
if self.app_module and self.app_module.__file__:
module_file = Path(self.app_module.__file__)
return f"{module_file.parent.name}.{module_file.stem}"
return ".".join([self.app_name, self.app_name])

def update_from_env(self) -> dict[str, Any]:
Expand Down
2 changes: 1 addition & 1 deletion reflex/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -1591,7 +1591,7 @@ def get_handler_args(


def fix_events(
events: list[EventHandler | EventSpec] | None,
events: list[EventSpec | EventHandler] | None,
token: str,
router_data: dict[str, Any] | None = None,
) -> list[Event]:
Expand Down
29 changes: 15 additions & 14 deletions reflex/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -1759,9 +1759,9 @@ def _as_state_update(
except Exception as ex:
state._clean()

app_instance = getattr(prerequisites.get_app(), constants.CompileVars.APP)

event_specs = app_instance.backend_exception_handler(ex)
event_specs = (
prerequisites.get_and_validate_app().app.backend_exception_handler(ex)
)

if event_specs is None:
return StateUpdate()
Expand Down Expand Up @@ -1871,9 +1871,9 @@ async def _process_event(
except Exception as ex:
telemetry.send_error(ex, context="backend")

app_instance = getattr(prerequisites.get_app(), constants.CompileVars.APP)

event_specs = app_instance.backend_exception_handler(ex)
event_specs = (
prerequisites.get_and_validate_app().app.backend_exception_handler(ex)
)

yield state._as_state_update(
handler,
Expand Down Expand Up @@ -2383,8 +2383,9 @@ def handle_frontend_exception(self, stack: str, component_stack: str) -> None:
component_stack: The stack trace of the component where the exception occurred.

"""
app_instance = getattr(prerequisites.get_app(), constants.CompileVars.APP)
app_instance.frontend_exception_handler(Exception(stack))
prerequisites.get_and_validate_app().app.frontend_exception_handler(
Exception(stack)
)


class UpdateVarsInternalState(State):
Expand Down Expand Up @@ -2422,15 +2423,16 @@ def on_load_internal(self) -> list[Event | EventSpec] | None:
The list of events to queue for on load handling.
"""
# Do not app._compile()! It should be already compiled by now.
app = getattr(prerequisites.get_app(), constants.CompileVars.APP)
load_events = app.get_load_events(self.router.page.path)
load_events = prerequisites.get_and_validate_app().app.get_load_events(
self.router.page.path
)
if not load_events:
self.is_hydrated = True
return # Fast path for navigation with no on_load events defined.
self.is_hydrated = False
return [
*fix_events(
load_events,
cast(list[Union[EventSpec, EventHandler]], load_events),
self.router.session.client_token,
router_data=self.router_data,
),
Expand Down Expand Up @@ -2589,7 +2591,7 @@ def __init__(
"""
super().__init__(state_instance)
# compile is not relevant to backend logic
self._self_app = getattr(prerequisites.get_app(), constants.CompileVars.APP)
self._self_app = prerequisites.get_and_validate_app().app
self._self_substate_path = tuple(state_instance.get_full_name().split("."))
self._self_actx = None
self._self_mutable = False
Expand Down Expand Up @@ -3682,8 +3684,7 @@ def get_state_manager() -> StateManager:
Returns:
The state manager.
"""
app = getattr(prerequisites.get_app(), constants.CompileVars.APP)
return app.state_manager
return prerequisites.get_and_validate_app().app.state_manager


class MutableProxy(wrapt.ObjectProxy):
Expand Down
17 changes: 15 additions & 2 deletions reflex/utils/exec.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,19 @@ def run_backend(
run_uvicorn_backend(host, port, loglevel)


def get_reload_dirs() -> list[str]:
"""Get the reload directories for the backend.

Returns:
The reload directories for the backend.
"""
config = get_config()
reload_dirs = [config.app_name]
if app_module_path := config.app_module_path:
reload_dirs.append(str(Path(app_module_path).resolve().parent.parent))
Comment on lines +251 to +252
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think for this one, we should go up the parent directories until we don't find an __init__.py, this should be the root of the reload dir for an out-of-tree app module.

return reload_dirs


def run_uvicorn_backend(host, port, loglevel: LogLevel):
"""Run the backend in development mode using Uvicorn.

Expand All @@ -256,7 +269,7 @@ def run_uvicorn_backend(host, port, loglevel: LogLevel):
port=port,
log_level=loglevel.value,
reload=True,
reload_dirs=[get_config().app_name],
reload_dirs=get_reload_dirs(),
)


Expand All @@ -281,7 +294,7 @@ def run_granian_backend(host, port, loglevel: LogLevel):
interface=Interfaces.ASGI,
log_level=LogLevels(loglevel.value),
reload=True,
reload_paths=[Path(get_config().app_name)],
reload_paths=get_reload_dirs(),
reload_ignore_dirs=[".web"],
).serve()
except ImportError:
Expand Down
46 changes: 41 additions & 5 deletions reflex/utils/prerequisites.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@
import sys
import tempfile
import time
import typing
import zipfile
from datetime import datetime
from pathlib import Path
from types import ModuleType
from typing import Callable, List, Optional
from typing import Callable, List, NamedTuple, Optional

import httpx
import typer
Expand All @@ -42,9 +43,19 @@
from reflex.utils.format import format_library_name
from reflex.utils.registry import _get_npm_registry

if typing.TYPE_CHECKING:
from reflex.app import App

CURRENTLY_INSTALLING_NODE = False


class AppInfo(NamedTuple):
"""A tuple containing the app instance and module."""

app: App
module: ModuleType


@dataclasses.dataclass(frozen=True)
class Template:
"""A template for a Reflex app."""
Expand Down Expand Up @@ -291,8 +302,11 @@ def get_app(reload: bool = False) -> ModuleType:
)
module = config.module
sys.path.insert(0, str(Path.cwd()))
app = __import__(module, fromlist=(constants.CompileVars.APP,))

app = (
__import__(module, fromlist=(constants.CompileVars.APP,))
if not config.app_module
else config.app_module
)
if reload:
from reflex.state import reload_state_module

Expand All @@ -308,6 +322,29 @@ def get_app(reload: bool = False) -> ModuleType:
raise


def get_and_validate_app(reload: bool = False) -> AppInfo:
"""Get the app instance based on the default config and validate it.

Args:
reload: Re-import the app module from disk

Returns:
The app instance and the app module.

Raises:
RuntimeError: If the app instance is not an instance of rx.App.
"""
from reflex.app import App

app_module = get_app(reload=reload)
app = getattr(app_module, constants.CompileVars.APP)
if not isinstance(app, App):
raise RuntimeError(
"The app instance in the specified app_module_path in rxconfig must be an instance of rx.App."
)
return AppInfo(app=app, module=app_module)


def get_compiled_app(reload: bool = False, export: bool = False) -> ModuleType:
"""Get the app module based on the default config after first compiling it.

Expand All @@ -318,8 +355,7 @@ def get_compiled_app(reload: bool = False, export: bool = False) -> ModuleType:
Returns:
The compiled app based on the default config.
"""
app_module = get_app(reload=reload)
app = getattr(app_module, constants.CompileVars.APP)
app, app_module = get_and_validate_app(reload=reload)
# For py3.9 compatibility when redis is used, we MUST add any decorator pages
# before compiling the app in a thread to avoid event loop error (REF-2172).
app._apply_decorated_pages()
Expand Down
Loading