diff --git a/.gitignore b/.gitignore
index 4a08c5f..e2830d7 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,6 @@
+# Local git worktrees (see: git worktree add)
+.worktrees/
+
 # Generated by Cargo
 # will have compiled files and executables
 debug
diff --git a/docs/guide/backend.md b/docs/guide/backend.md
new file mode 100644
index 0000000..730eee9
--- /dev/null
+++ b/docs/guide/backend.md
@@ -0,0 +1,401 @@
+# Backend Guide
+
+Ferro supports SQLite and PostgreSQL through one Python API and an explicit Rust backend layer. Application code still calls `connect()`, defines Pydantic-style models, and uses the query builder. The Rust core decides which typed SQLx driver, SeaQuery dialect, transaction connection, and value conversion rules apply for the active database.
+
+This guide starts with the user-facing behavior, then explains the implementation details that maintainers need when changing the backend.
+
+## What The Backend Is
+
+The backend is the runtime database engine behind Ferro's Python API. It owns:
+
+- the active database kind, currently SQLite or PostgreSQL
+- the typed SQLx connection pool
+- SQL execution and row materialization
+- transaction-bound typed connections
+- backend-specific SQL generation choices
+- value binding and hydration rules
+
+The backend does not introduce a new public routing API. Ferro still uses one active engine per process. Named databases, replicas, and `using("name")`-style routing are intentionally deferred.
+
+## Supported Backends
+
+Ferro currently treats these URL schemes as first-class runtime targets:
+
+```python
+await connect("sqlite:app.db?mode=rwc")
+await connect("sqlite::memory:")
+await connect("postgresql://user:password@localhost:5432/app")
+await connect("postgres://user:password@localhost:5432/app")
+```
+
+Unsupported schemes fail during connection setup:
+
+```python
+await connect("mysql://user:password@localhost/app")
+# raises a connection error: supported schemes are sqlite, postgres, postgresql
+```
+
+The important implementation detail is that URL detection happens once, during `connect()`. After that, the active `EngineHandle` carries the backend kind and typed pool, so operations do not need to rediscover the database from global state or URL strings.
+
+## Connection Lifecycle
+
+`ferro.connect()` is the public entry point. Internally, the Rust connection layer does four things:
+
+1. Splits Ferro-only query parameters from the database URL.
+2. Classifies the backend from the URL scheme.
+3. Creates a typed SQLx pool for that backend.
+4. Stores an `Arc<EngineHandle>` in global engine state.
+
+SQLite uses `SqlitePoolOptions` and PostgreSQL uses `PgPoolOptions`. Both currently use a fixed pool size of 5 connections.
+
+```text
+connect(url, auto_migrate)
+  -> split ferro_search_path
+  -> BackendKind::from_url(url)
+  -> connect typed pool
+  -> optionally create tables
+  -> store EngineHandle globally
+```
+
+### PostgreSQL Search Paths
+
+Ferro supports a private `ferro_search_path` URL parameter for test isolation:
+
+```python
+await connect(
+    "postgresql://localhost/ferro?ferro_search_path=ferro_test_schema",
+    auto_migrate=True,
+)
+```
+
+The parameter is removed before SQLx connects. If present, Ferro installs an `after_connect` hook that runs:
+
+```sql
+SET search_path TO ferro_test_schema
+```
+
+Search path names must be ASCII alphanumeric or `_`. This keeps the test helper ergonomic without allowing arbitrary SQL in the connection URL.
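+To make the classification step above concrete, here is a minimal, illustrative Rust sketch of the scheme rule. It is not the shipped implementation: the real classifier is `BackendKind::from_url` in `src/backend.rs`, and the local `Kind` enum and `classify` function here are stand-ins for exposition only.
+
+```rust
+// Illustrative sketch: classify a connection URL by scheme, the way the
+// lifecycle above describes. Not Ferro's actual code.
+#[derive(Debug, PartialEq)]
+enum Kind {
+    Sqlite,
+    Postgres,
+}
+
+fn classify(url: &str) -> Result<Kind, String> {
+    match url.split(':').next() {
+        Some("sqlite") => Ok(Kind::Sqlite),
+        Some("postgres") | Some("postgresql") => Ok(Kind::Postgres),
+        other => Err(format!(
+            "Unsupported database URL scheme '{}'. Supported schemes: sqlite, postgres, postgresql",
+            other.unwrap_or("")
+        )),
+    }
+}
+
+fn main() {
+    assert_eq!(classify("sqlite::memory:"), Ok(Kind::Sqlite));
+    assert_eq!(classify("postgresql://localhost/app"), Ok(Kind::Postgres));
+    assert!(classify("mysql://localhost/app").is_err());
+}
+```
+
+Because classification happens exactly once, everything downstream can branch on the enum instead of re-parsing URL strings.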
+
+Use this when several test runs share one PostgreSQL database, but each test should see its own tables. Instead of creating and dropping a whole database for every test, create a temporary schema, connect with that schema as the search path, and let `auto_migrate=True` create the model tables there:
+
+```python
+import uuid
+
+import psycopg
+from ferro import connect, reset_engine
+
+
+async def run_isolated_postgres_test(base_url: str):
+    schema_name = f"ferro_{uuid.uuid4().hex[:16]}"
+
+    with psycopg.connect(base_url, autocommit=True) as conn:
+        conn.execute(f'CREATE SCHEMA "{schema_name}"')
+
+    try:
+        await connect(
+            f"{base_url}?ferro_search_path={schema_name}",
+            auto_migrate=True,
+        )
+
+        # Test code now reads and writes tables in only this schema.
+        # A second test can use the same database with a different schema.
+    finally:
+        reset_engine()
+        with psycopg.connect(base_url, autocommit=True) as conn:
+            conn.execute(f'DROP SCHEMA IF EXISTS "{schema_name}" CASCADE')
+```
+
+This is how Ferro's PostgreSQL matrix keeps tests isolated while still supporting both local `pytest-postgresql` databases and externally managed databases such as Supabase.
+
+## Typed Engine Internals
+
+The core backend types live in `src/backend.rs`.
+
+```text
+BackendKind
+  Sqlite
+  Postgres
+
+EngineHandle
+  backend: BackendKind
+  pool: BackendPool
+
+BackendPool
+  Sqlite(Arc<SqlitePool>)
+  Postgres(Arc<PgPool>)
+
+EngineConnection
+  Sqlite(PoolConnection<Sqlite>)
+  Postgres(PoolConnection<Postgres>)
+```
+
+This replaced the old `sqlx::Any`-centered execution path. Instead of one generic pool that tries to behave like every database, Ferro stores exactly the pool it connected:
+
+- SQLite connections are executed through SQLx's SQLite driver.
+- PostgreSQL connections are executed through SQLx's PostgreSQL driver.
+- Transaction connections keep the same typed distinction.
+- Backend dispatch is a small enum match at the boundary where SQL actually runs.
+
+This gives Ferro access to backend-specific SQLx behavior without making the Python API backend-specific.
+
+## Query And Mutation Execution
+
+Most ORM operations follow the same high-level pipeline:
+
+```text
+Python Query / Model API
+  -> JSON query or mutation payload
+  -> Rust operation function
+  -> SeaQuery statement
+  -> backend-specific SQL builder
+  -> EngineBindValue list
+  -> EngineHandle or EngineConnection execution
+  -> EngineRow values
+  -> RustValue values
+  -> Python model instances
+```
+
+SeaQuery remains the SQL construction layer. The backend controls which SeaQuery builder lowers the statement:
+
+- SQLite uses `SqliteQueryBuilder`
+- PostgreSQL uses `PostgresQueryBuilder`
+
+Bind values are converted into a backend-neutral Ferro enum before execution:
+
+```text
+EngineBindValue
+  Bool
+  I64
+  F64
+  String
+  Bytes
+  Null
+```
+
+The backend then binds those values to the typed SQLx query. This keeps most operation code independent of SQLx's generic types, while still executing through real SQLite or PostgreSQL drivers.
+
+### Reads
+
+Read operations fetch typed rows through the engine, materialize each SQLx row into `EngineRow`, then convert the values into Ferro's internal `RustValue` representation. `RustValue` is the final GIL-free representation before Python objects are created.
+
+This split matters because database values are not the same as Python field values.
For example: + +- a PostgreSQL `integer` may decode as `i32`, but Ferro model IDs use Python `int` +- PostgreSQL UUIDs are selected as text before becoming Python `uuid.UUID` +- Decimal values are selected as text before becoming Python `Decimal` +- JSON values are selected as text before becoming Python dicts or lists + +### Writes + +Create, update, relationship, and delete operations build SeaQuery statements and execute them through either: + +- the active `EngineHandle`, if no transaction is active +- the transaction's `EngineConnection`, if a transaction ID is present + +SQLite insert results can report `last_insert_rowid()`. PostgreSQL insert paths rely on explicit `RETURNING` where Ferro needs generated values. + +## Schema Metadata And DDL + +The backend depends on normalized schema metadata from Python. `src/ferro/schema_metadata.py` enriches Pydantic's JSON schema with Ferro-specific keys before Rust consumes it. + +Important metadata includes: + +- `primary_key` +- `autoincrement` +- `unique` +- `index` +- `foreign_key` +- `ferro_nullable` +- `format: "decimal"` +- `enum_type_name` + +That metadata is shared by: + +- Rust runtime DDL in `src/schema.rs` +- Alembic metadata generation in `src/ferro/migrations/alembic.py` +- query and mutation casting decisions in `src/operations.rs` +- relationship join-table generation in `src/ferro/relations/__init__.py` + +The goal is to make the Python schema the contract. Runtime DDL and Alembic may lower it differently, but they should not infer conflicting meanings from the same model. + +### Auto-Migration + +When `auto_migrate=True`, `connect()` creates the typed engine first, then asks Rust to create tables for all registered models. + +```python +await connect("sqlite:dev.db?mode=rwc", auto_migrate=True) +await connect("postgresql://localhost/ferro", auto_migrate=True) +``` + +Runtime DDL uses the active backend: + +- SQLite gets SQLite-compatible column definitions and index SQL. +- PostgreSQL gets PostgreSQL-compatible column definitions, native casts, and SQL syntax. + +## Type Handling Across SQLite And Postgres + +SQLite and PostgreSQL do not store or decode every logical type the same way. Ferro's backend layer aims to preserve the Python model contract while allowing backend-specific SQL where needed. + +### Integer Primary Keys + +SQLite autoincrement IDs come from `last_insert_rowid()`. PostgreSQL `SERIAL` / integer values may decode as `i32`; Ferro materializes them as `i64` and then Python `int`. + +### UUID + +UUIDs are a bridge-boundary type. They can appear as: + +- Python `uuid.UUID` +- JSON query payload strings +- SQL bind values +- PostgreSQL `uuid` columns +- SQLite text-like columns + +Ferro serializes UUIDs before JSON query payloads cross the Python/Rust boundary. For PostgreSQL SQL expressions, Ferro adds explicit `uuid` casts where SQLx or PostgreSQL would otherwise see text. Many-to-many add, remove, and clear operations use the same backend-aware cast path for UUID join-table columns. + +### Decimal + +Python `Decimal` fields are marked with `format: "decimal"` in schema metadata. PostgreSQL can use numeric storage, while SQLite remains more flexible. On reads, Ferro selects Decimal values as text when needed so Python can reconstruct an exact `Decimal`. + +### JSON Objects And Arrays + +Python `dict` and `list` fields are represented as JSON object or array schema types. PostgreSQL writes cast JSON strings to `json` so inserts and updates target native JSON columns correctly. 
Reads select JSON values as text when required, then parse them back into Python values.
+
+### Dates And Datetimes
+
+Temporal values cross the bridge as ISO strings and are reconstructed into Python `date` or `datetime` objects. PostgreSQL SQL generation applies explicit casts for temporal comparisons and nulls where needed.
+
+### Enums
+
+Enums are represented through schema metadata, including the enum type name. PostgreSQL-specific enum casts are applied where the column uses a native enum type. Portable text-like enum behavior remains available through the same Python model shape.
+
+## Transactions
+
+Transactions use the same typed backend model as normal operations.
+
+When a root transaction begins:
+
+```text
+active EngineHandle
+  -> acquire typed pool connection
+  -> BEGIN
+  -> TransactionHandle::root(EngineConnection)
+```
+
+Nested transactions reuse the same typed connection and create savepoints:
+
+```text
+parent TransactionConnection
+  -> SAVEPOINT sp_<n>
+  -> TransactionHandle::nested(parent_conn, savepoint_name)
+```
+
+The transaction registry stores a transaction ID mapped to:
+
+- a shared `Arc<Mutex<EngineConnection>>`
+- an optional savepoint name
+
+This means all operations inside a transaction execute on the same typed database connection. Commit and rollback dispatch through the `EngineConnection` enum, not through a generic SQLx connection.
+
+## Testing The Backend Matrix
+
+Backend correctness is tested with the same public API users call. Tests that should run on both databases use the backend matrix fixtures:
+
+```python
+@pytest.mark.backend_matrix
+async def test_create_and_fetch(db_url):
+    await connect(db_url, auto_migrate=True)
+    ...
+```
+
+Run the SQLite default suite:
+
+```bash
+uv run pytest -q
+```
+
+Run the SQLite/PostgreSQL matrix:
+
+```bash
+uv run pytest -m "backend_matrix or postgres_only" --db-backends=sqlite,postgres -q
+```
+
+Run only the PostgreSQL side:
+
+```bash
+uv run pytest -m "backend_matrix or postgres_only" --db-backends=postgres -q
+```
+
+### Local PostgreSQL Provider
+
+The test harness supports local ephemeral PostgreSQL through `pytest-postgresql`.
+
+Install PostgreSQL server binaries, then force the local provider:
+
+```bash
+brew install postgresql@16
+FERRO_POSTGRES_PROVIDER=local uv run pytest -m "backend_matrix or postgres_only" --db-backends=postgres -q
+```
+
+If `FERRO_POSTGRES_PROVIDER=local` is not set, tests prefer an external URL:
+
+1. `FERRO_POSTGRES_URL`
+2. legacy `FERRO_SUPABASE_URL`
+3. local `pytest-postgresql` fallback
+
+Each PostgreSQL test gets an isolated schema through `ferro_search_path`, so externally managed databases can still run isolated test cases.
+
+## How To Extend This Later
+
+The current backend design makes a future backend, such as MySQL, more approachable but not automatic. A new backend would need:
+
+1. A new `BackendKind` variant.
+2. A typed SQLx pool and connection variant.
+3. URL classification.
+4. SeaQuery builder dispatch.
+5. DDL type mapping in `src/schema.rs`.
+6. Bind and row materialization support in `src/backend.rs`.
+7. Schema-value casting rules in `src/operations.rs`.
+8. Backend-matrix test coverage.
+9. Docs that clearly state support level and known differences.
+
+Avoid adding a backend by sprinkling one-off branches through query, schema, and operation code. The maintainable path is to make the backend identity explicit first, then lower shared ORM semantics through that backend, as the sketch below illustrates.
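+Here is a minimal, illustrative Rust sketch of that dispatch pattern. The `MySql` variant is hypothetical and not part of Ferro today; the point is that when backend identity is an exhaustive enum, the compiler lists every execution-boundary site a new backend must handle.
+
+```rust
+// Hypothetical sketch: what adding a third backend looks like when the
+// backend is an explicit enum. None of this is shipped Ferro code.
+#[derive(Clone, Copy, Debug, PartialEq)]
+enum BackendKind {
+    Sqlite,
+    Postgres,
+    MySql, // new variant: every exhaustive match below now fails to compile until handled
+}
+
+fn placeholder_style(backend: BackendKind) -> &'static str {
+    // Stand-in for SeaQuery builder dispatch: each backend lowers SQL its own way.
+    match backend {
+        BackendKind::Sqlite => "?",
+        BackendKind::Postgres => "$n",
+        BackendKind::MySql => "?",
+    }
+}
+
+fn main() {
+    assert_eq!(placeholder_style(BackendKind::Postgres), "$n");
+    assert_eq!(placeholder_style(BackendKind::MySql), "?");
+}
+```
+
+Because dispatch is a plain enum match rather than scattered string checks, forgetting a site is a compile error instead of a runtime surprise.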
+ +## Troubleshooting And Gotchas + +### `Engine not initialized` + +You called a model or query method before `await connect(...)`. Importing models registers schema, but it does not connect to the database. + +### Unsupported URL scheme + +Only `sqlite:`, `postgres://`, and `postgresql://` are supported. MySQL is planned for later, not accepted by this backend. + +### PostgreSQL tests use the wrong database + +If `.env` contains `FERRO_POSTGRES_URL` or `FERRO_SUPABASE_URL`, the test harness will use it by default. Set `FERRO_POSTGRES_PROVIDER=local` to force `pytest-postgresql`. + +### Local PostgreSQL tests skip or fail to start + +`pytest-postgresql` needs server binaries such as `pg_ctl`, `postgres`, and `initdb` on `PATH`. On macOS with Homebrew, installing `postgresql@16` usually provides them. + +### UUID or Decimal values fail only on PostgreSQL + +Check whether the value crosses the Python/Rust boundary as JSON or as a direct PyO3 argument. Query payloads must serialize non-JSON-native Python values before `json.dumps`; direct relationship operations must preserve typed values long enough for backend-aware SQL casts. + +### Runtime DDL and Alembic disagree + +Start with schema metadata. If `ferro_nullable`, `format`, `primary_key`, or relationship metadata is missing from the normalized Python schema, Rust DDL and Alembic may lower the same model differently. Fix the metadata source before adding more backend-specific lowering rules. + +## Mental Model + +The shortest way to understand the backend is: + +```text +Python owns the model contract. +Rust owns execution. +SeaQuery owns SQL shape. +SQLx owns typed database I/O. +BackendKind decides which database-specific path is legal. +``` + +When changing backend behavior, preserve that separation. Put shared ORM meaning in schema/query metadata, then make the backend choose the correct SQLite or PostgreSQL lowering at the execution boundary. diff --git a/docs/howto/testing.md b/docs/howto/testing.md index f9a5a06..7dce2d3 100644 --- a/docs/howto/testing.md +++ b/docs/howto/testing.md @@ -7,7 +7,7 @@ Test your Ferro applications with pytest and test database isolation strategies. The repository test suite supports two database modes: - **Default SQLite run** for the full fast suite -- **Dual-backend matrix** for ORM coverage on both SQLite and PostgreSQL/Supabase +- **Dual-backend matrix** for ORM coverage on both SQLite and PostgreSQL The matrix is opt-in so day-to-day test runs stay quick and deterministic. @@ -20,13 +20,25 @@ uv sync --group dev uv run maturin develop ``` -Set `FERRO_SUPABASE_URL` to a PostgreSQL connection string. A root `.env` file works well for local development: +For local PostgreSQL matrix runs, install PostgreSQL server binaries so `pytest-postgresql` can start an ephemeral database: ```bash -FERRO_SUPABASE_URL='postgresql://...' +brew install postgresql@16 ``` -The Postgres matrix reads `FERRO_SUPABASE_URL` from either the environment or the project `.env` file. Tests create a dedicated schema per test and use that schema as the search path so one shared Supabase database can still run isolated tests safely. +You can also point the suite at an externally managed PostgreSQL database. A root `.env` file works well for local development: + +```bash +FERRO_POSTGRES_URL='postgresql://...' +``` + +The Postgres matrix first reads `FERRO_POSTGRES_URL` from either the environment or the project `.env` file. It still accepts the older `FERRO_SUPABASE_URL` name as a compatibility fallback. 
Tests create a dedicated schema per test and use that schema as the search path so one shared external database can still run isolated tests safely. + +To force the local `pytest-postgresql` provider even when `.env` contains an external URL: + +```bash +FERRO_POSTGRES_PROVIDER=local uv run pytest -m "backend_matrix or postgres_only" --db-backends=postgres -q +``` ### Run The Default Suite @@ -56,9 +68,22 @@ The repository uses three database markers: - `backend_matrix`: run this test once per selected backend - `sqlite_only`: keep SQLite-specific catalog, file-path, or pragma assertions on SQLite -- `postgres_only`: run Postgres/Supabase-specific assertions only when `FERRO_SUPABASE_URL` is configured +- `postgres_only`: run Postgres-specific assertions when either an external Postgres URL is configured or `pytest-postgresql` can start a local server + +If no external Postgres URL is set and local PostgreSQL server binaries are unavailable, `postgres_only` tests are skipped and `backend_matrix` tests run only on SQLite. + +### Bridge-Boundary Regressions + +When a bug involves values crossing the Python/Rust bridge, preserve the public API shape in the regression test. These issues often depend on whether a value travels as JSON (`Query.all()`, `Query.count()`, `Query.update()`, `Query.delete()`) or as a typed Python value passed directly to Rust (`ManyToManyField.add()`, `.remove()`, `.clear()`). + +Use these conventions: -If `FERRO_SUPABASE_URL` is not set, `postgres_only` tests are skipped and `backend_matrix` tests run only on SQLite. +- Put relationship and auto-migration regressions in `tests/test_auto_migrate.py` when they strengthen the backend matrix. +- Put structural type regressions in `tests/test_structural_types.py` when they involve UUID, Decimal, JSON, enum, binary, date, or datetime behavior. +- Use `backend_matrix` when the public behavior should work on both SQLite and PostgreSQL. +- Use `postgres_only` when the assertion depends on native PostgreSQL types, catalogs, or casts. +- Convert user repro scripts with minimal translation: keep the same model shape and public method sequence, trim incidental setup, and assert the original failure mode is gone. +- Add a fast serializer or static-contract test when the bug is caused by a Python boundary rule, such as raw `json.dumps(query_def)` bypassing Ferro's query serializer. ## Basic Setup @@ -83,7 +108,7 @@ async def db_transaction(db): yield ``` -For backend-matrix tests, Ferro's own suite uses `--db-backends=sqlite,postgres` together with `backend_matrix` / `postgres_only` markers and a `FERRO_SUPABASE_URL` environment variable. +For backend-matrix tests, Ferro's own suite uses `--db-backends=sqlite,postgres` together with `backend_matrix` / `postgres_only` markers. Postgres coverage uses `pytest-postgresql` locally, or `FERRO_POSTGRES_URL` / `FERRO_SUPABASE_URL` when an external database is configured. 
 ## Test Example
diff --git a/mkdocs.yml b/mkdocs.yml
index 51f1f88..921b156 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -54,6 +54,7 @@ plugins:
           - guide/queries.md
           - guide/mutations.md
          - guide/transactions.md
+          - guide/backend.md
           - guide/database.md
           - guide/migrations.md
         How-To:
@@ -130,6 +131,7 @@ nav:
       - Queries: guide/queries.md
       - Mutations: guide/mutations.md
       - Transactions: guide/transactions.md
+      - Backend: guide/backend.md
       - Database Setup: guide/database.md
       - Schema Management: guide/migrations.md
diff --git a/pyproject.toml b/pyproject.toml
index 2d7296a..7acbedb 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -42,6 +42,7 @@ ci-test = [
     "pytest-asyncio>=0.23.0",
     "pytest-cov>=7.0.0",
     "pytest-examples>=0.0.18",
+    "pytest-postgresql>=8.0.0",
 ]
 docs = [
     "mkdocs-material>=9.5.0",
@@ -71,6 +72,7 @@ dev = [
     "pymdown-extensions>=10.7.0",
     "pytest-examples>=0.0.18",
     "psycopg[binary]>=3.3.3",
+    "pytest-postgresql>=8.0.0",
 ]
 
 [tool.pytest.ini_options]
diff --git a/src/backend.rs b/src/backend.rs
index 37e5ca1..03699f7 100644
--- a/src/backend.rs
+++ b/src/backend.rs
@@ -1,4 +1,6 @@
-use sqlx::{Any, Pool};
+use sqlx::ColumnIndex;
+use sqlx::pool::PoolConnection;
+use sqlx::{Column, PgPool, Postgres, Row, Sqlite, SqlitePool};
 use std::fmt;
 use std::sync::Arc;
 
@@ -27,14 +29,74 @@ impl BackendKind {
 #[derive(Clone, Debug)]
 pub struct EngineHandle {
     backend: BackendKind,
-    pool: Arc<Pool<Any>>,
+    pool: BackendPool,
+}
+
+#[derive(Clone, Debug)]
+enum BackendPool {
+    Sqlite(Arc<SqlitePool>),
+    Postgres(Arc<PgPool>),
+}
+
+#[allow(dead_code)]
+pub enum EngineConnection {
+    Sqlite(PoolConnection<Sqlite>),
+    Postgres(PoolConnection<Postgres>),
+}
+
+#[derive(Clone, Debug, PartialEq)]
+pub enum EngineBindValue {
+    Bool(bool),
+    I64(i64),
+    F64(f64),
+    String(String),
+    Bytes(Vec<u8>),
+    Null,
+}
+
+#[derive(Clone, Debug, PartialEq)]
+pub struct EngineRow {
+    pub values: Vec<(String, EngineValue)>,
+}
+
+#[derive(Clone, Debug, PartialEq, Eq)]
+pub struct EngineExecuteResult {
+    pub rows_affected: u64,
+    pub last_insert_id: Option<i64>,
+}
+
+#[derive(Clone, Debug, PartialEq)]
+pub enum EngineValue {
+    Bool(bool),
+    I64(i64),
+    F64(f64),
+    String(String),
+    Bytes(Vec<u8>),
+    Null,
+}
+
+impl EngineValue {
+    pub fn as_i64(&self) -> Option<i64> {
+        match self {
+            Self::I64(value) => Some(*value),
+            Self::String(value) => value.parse().ok(),
+            _ => None,
+        }
+    }
 }
 
 impl EngineHandle {
-    pub fn new(backend: BackendKind, pool: Pool<Any>) -> Self {
+    pub fn new_sqlite(pool: SqlitePool) -> Self {
+        Self {
+            backend: BackendKind::Sqlite,
+            pool: BackendPool::Sqlite(Arc::new(pool)),
+        }
+    }
+
+    pub fn new_postgres(pool: PgPool) -> Self {
         Self {
-            backend,
-            pool: Arc::new(pool),
+            backend: BackendKind::Postgres,
+            pool: BackendPool::Postgres(Arc::new(pool)),
         }
     }
 
@@ -42,8 +104,261 @@ impl EngineHandle {
         self.backend
     }
 
-    pub fn pool(&self) -> Arc<Pool<Any>> {
-        self.pool.clone()
+    #[allow(dead_code)]
+    pub fn sqlite_pool(&self) -> Option<Arc<SqlitePool>> {
+        match &self.pool {
+            BackendPool::Sqlite(pool) => Some(pool.clone()),
+            BackendPool::Postgres(_) => None,
+        }
+    }
+
+    #[allow(dead_code)]
+    pub fn postgres_pool(&self) -> Option<Arc<PgPool>> {
+        match &self.pool {
+            BackendPool::Postgres(pool) => Some(pool.clone()),
+            BackendPool::Sqlite(_) => None,
+        }
+    }
+
+    pub async fn execute_sql(&self, sql: &str) -> Result<u64, sqlx::Error> {
+        match &self.pool {
+            BackendPool::Sqlite(pool) => {
+                let result = sqlx::query(sql).execute(pool.as_ref()).await?;
+                Ok(result.rows_affected())
+            }
+            BackendPool::Postgres(pool) => {
+                let result = sqlx::query(sql).execute(pool.as_ref()).await?;
+                Ok(result.rows_affected())
+            }
+        }
+    }
+
+    pub async fn execute_sql_with_binds(
+        &self,
+        sql: &str,
+        values: &[EngineBindValue],
+    ) -> Result<u64, sqlx::Error> {
+        Ok(self
+            .execute_sql_with_binds_result(sql, values)
+            .await?
+            .rows_affected)
+    }
+
+    pub async fn execute_sql_with_binds_result(
+        &self,
+        sql: &str,
+        values: &[EngineBindValue],
+    ) -> Result<EngineExecuteResult, sqlx::Error> {
+        match &self.pool {
+            BackendPool::Sqlite(pool) => {
+                let mut query = sqlx::query(sql);
+                for value in values {
+                    query = bind_engine_value(query, value);
+                }
+                let result = query.execute(pool.as_ref()).await?;
+                Ok(EngineExecuteResult {
+                    rows_affected: result.rows_affected(),
+                    last_insert_id: Some(result.last_insert_rowid()),
+                })
+            }
+            BackendPool::Postgres(pool) => {
+                let mut query = sqlx::query(sql);
+                for value in values {
+                    query = bind_engine_value(query, value);
+                }
+                let result = query.execute(pool.as_ref()).await?;
+                Ok(EngineExecuteResult {
+                    rows_affected: result.rows_affected(),
+                    last_insert_id: None,
+                })
+            }
+        }
+    }
+
+    pub async fn fetch_all_sql_with_binds(
+        &self,
+        sql: &str,
+        values: &[EngineBindValue],
+    ) -> Result<Vec<EngineRow>, sqlx::Error> {
+        match &self.pool {
+            BackendPool::Sqlite(pool) => {
+                let mut query = sqlx::query(sql);
+                for value in values {
+                    query = bind_engine_value(query, value);
+                }
+                let rows = query.fetch_all(pool.as_ref()).await?;
+                Ok(rows.iter().map(materialize_engine_row).collect())
+            }
+            BackendPool::Postgres(pool) => {
+                let mut query = sqlx::query(sql);
+                for value in values {
+                    query = bind_engine_value(query, value);
+                }
+                let rows = query.fetch_all(pool.as_ref()).await?;
+                Ok(rows.iter().map(materialize_engine_row).collect())
+            }
+        }
+    }
+
+    #[allow(dead_code)]
+    pub async fn begin_transaction_connection(&self) -> Result<EngineConnection, sqlx::Error> {
+        match &self.pool {
+            BackendPool::Sqlite(pool) => {
+                let mut conn = pool.acquire().await?;
+                sqlx::query("BEGIN").execute(&mut *conn).await?;
+                Ok(EngineConnection::Sqlite(conn))
+            }
+            BackendPool::Postgres(pool) => {
+                let mut conn = pool.acquire().await?;
+                sqlx::query("BEGIN").execute(&mut *conn).await?;
+                Ok(EngineConnection::Postgres(conn))
+            }
+        }
+    }
+}
+
+#[allow(dead_code)]
+impl EngineConnection {
+    pub async fn execute_sql(&mut self, sql: &str) -> Result<u64, sqlx::Error> {
+        self.execute_sql_with_binds(sql, &[]).await
+    }
+
+    pub async fn execute_sql_with_binds(
+        &mut self,
+        sql: &str,
+        values: &[EngineBindValue],
+    ) -> Result<u64, sqlx::Error> {
+        Ok(self
+            .execute_sql_with_binds_result(sql, values)
+            .await?
+            .rows_affected)
+    }
+
+    pub async fn execute_sql_with_binds_result(
+        &mut self,
+        sql: &str,
+        values: &[EngineBindValue],
+    ) -> Result<EngineExecuteResult, sqlx::Error> {
+        match self {
+            EngineConnection::Sqlite(conn) => {
+                let mut query = sqlx::query(sql);
+                for value in values {
+                    query = bind_engine_value(query, value);
+                }
+                let result = query.execute(&mut **conn).await?;
+                Ok(EngineExecuteResult {
+                    rows_affected: result.rows_affected(),
+                    last_insert_id: Some(result.last_insert_rowid()),
+                })
+            }
+            EngineConnection::Postgres(conn) => {
+                let mut query = sqlx::query(sql);
+                for value in values {
+                    query = bind_engine_value(query, value);
+                }
+                let result = query.execute(&mut **conn).await?;
+                Ok(EngineExecuteResult {
+                    rows_affected: result.rows_affected(),
+                    last_insert_id: None,
+                })
+            }
+        }
+    }
+
+    pub async fn fetch_all_sql_with_binds(
+        &mut self,
+        sql: &str,
+        values: &[EngineBindValue],
+    ) -> Result<Vec<EngineRow>, sqlx::Error> {
+        match self {
+            EngineConnection::Sqlite(conn) => {
+                let mut query = sqlx::query(sql);
+                for value in values {
+                    query = bind_engine_value(query, value);
+                }
+                let rows = query.fetch_all(&mut **conn).await?;
+                Ok(rows.iter().map(materialize_engine_row).collect())
+            }
+            EngineConnection::Postgres(conn) => {
+                let mut query = sqlx::query(sql);
+                for value in values {
+                    query = bind_engine_value(query, value);
+                }
+                let rows = query.fetch_all(&mut **conn).await?;
+                Ok(rows.iter().map(materialize_engine_row).collect())
+            }
+        }
+    }
+
+    pub async fn commit(&mut self) -> Result<(), sqlx::Error> {
+        self.execute_sql("COMMIT").await?;
+        Ok(())
+    }
+
+    pub async fn rollback(&mut self) -> Result<(), sqlx::Error> {
+        self.execute_sql("ROLLBACK").await?;
+        Ok(())
+    }
+}
+
+fn materialize_engine_row<R>(row: &R) -> EngineRow
+where
+    R: Row,
+    for<'r> i32: sqlx::Decode<'r, R::Database> + sqlx::Type<R::Database>,
+    for<'r> i64: sqlx::Decode<'r, R::Database> + sqlx::Type<R::Database>,
+    for<'r> f64: sqlx::Decode<'r, R::Database> + sqlx::Type<R::Database>,
+    for<'r> Vec<u8>: sqlx::Decode<'r, R::Database> + sqlx::Type<R::Database>,
+    for<'r> String: sqlx::Decode<'r, R::Database> + sqlx::Type<R::Database>,
+    for<'r> bool: sqlx::Decode<'r, R::Database> + sqlx::Type<R::Database>,
+    usize: ColumnIndex<R>,
+{
+    let values = row
+        .columns()
+        .iter()
+        .map(|column| {
+            let name = column.name().to_string();
+            let value = if let Ok(value) = row.try_get::<i64, _>(column.ordinal()) {
+                EngineValue::I64(value)
+            } else if let Ok(value) = row.try_get::<i32, _>(column.ordinal()) {
+                EngineValue::I64(i64::from(value))
+            } else if let Ok(value) = row.try_get::<f64, _>(column.ordinal()) {
+                EngineValue::F64(value)
+            } else if let Ok(value) = row.try_get::<String, _>(column.ordinal()) {
+                EngineValue::String(value)
+            } else if let Ok(value) = row.try_get::<Vec<u8>, _>(column.ordinal()) {
+                EngineValue::Bytes(value)
+            } else if let Ok(value) = row.try_get::<bool, _>(column.ordinal()) {
+                EngineValue::Bool(value)
+            } else {
+                EngineValue::Null
+            };
+            (name, value)
+        })
+        .collect();
+
+    EngineRow { values }
+}
+
+fn bind_engine_value<'q, DB>(
+    query: sqlx::query::Query<'q, DB, <DB as sqlx::Database>::Arguments<'q>>,
+    value: &'q EngineBindValue,
+) -> sqlx::query::Query<'q, DB, <DB as sqlx::Database>::Arguments<'q>>
+where
+    DB: sqlx::Database,
+    bool: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
+    i64: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
+    f64: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
+    String: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
+    Vec<u8>: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
+    Option<Vec<u8>>: sqlx::Encode<'q, DB> + sqlx::Type<DB>,
+{
+    match value {
+        EngineBindValue::Bool(v) => query.bind(*v),
+        EngineBindValue::I64(v) => query.bind(*v),
+        EngineBindValue::F64(v) => query.bind(*v),
+        EngineBindValue::String(v) => query.bind(v.clone()),
+        EngineBindValue::Bytes(v) => query.bind(v.clone()),
+        EngineBindValue::Null => query.bind(Option::<Vec<u8>>::None),
     }
 }
 
@@ -72,6 +387,11 @@ impl fmt::Display for UnsupportedDatabaseUrl {
 #[cfg(test)]
 mod tests {
     use super::BackendKind;
+    use super::EngineBindValue;
+    use super::EngineHandle;
+    use super::EngineValue;
+    use sqlx::postgres::PgPoolOptions;
+    use sqlx::sqlite::SqlitePoolOptions;
 
     #[test]
     fn classifies_sqlite_urls() {
@@ -105,4 +425,226 @@
             "Unsupported database URL scheme 'mysql'. Supported schemes: sqlite, postgres, postgresql"
         );
     }
+
+    #[tokio::test]
+    async fn engine_handle_preserves_typed_sqlite_pool() {
+        let pool = SqlitePoolOptions::new()
+            .max_connections(1)
+            .connect("sqlite::memory:")
+            .await
+            .unwrap();
+
+        let engine = EngineHandle::new_sqlite(pool);
+
+        assert_eq!(engine.backend(), BackendKind::Sqlite);
+        assert!(engine.sqlite_pool().is_some());
+        assert!(engine.postgres_pool().is_none());
+    }
+
+    #[tokio::test]
+    async fn engine_handle_preserves_typed_postgres_pool() {
+        let pool = PgPoolOptions::new()
+            .max_connections(1)
+            .connect_lazy("postgresql://example.invalid/postgres")
+            .unwrap();
+
+        let engine = EngineHandle::new_postgres(pool);
+
+        assert_eq!(engine.backend(), BackendKind::Postgres);
+        assert!(engine.sqlite_pool().is_none());
+        assert!(engine.postgres_pool().is_some());
+    }
+
+    #[tokio::test]
+    async fn engine_handle_executes_sqlite_sql_without_legacy_pool() {
+        let pool = SqlitePoolOptions::new()
+            .max_connections(1)
+            .connect("sqlite::memory:")
+            .await
+            .unwrap();
+        let engine = EngineHandle::new_sqlite(pool);
+
+        engine
+            .execute_sql("CREATE TABLE typed_exec_check (id integer primary key)")
+            .await
+            .unwrap();
+        engine
+            .execute_sql("INSERT INTO typed_exec_check (id) VALUES (1)")
+            .await
+            .unwrap();
+
+        assert_eq!(
+            engine
+                .execute_sql("UPDATE typed_exec_check SET id = 2")
+                .await
+                .unwrap(),
+            1
+        );
+    }
+
+    #[tokio::test]
+    async fn engine_handle_executes_sqlite_sql_with_bound_values() {
+        let pool = SqlitePoolOptions::new()
+            .max_connections(1)
+            .connect("sqlite::memory:")
+            .await
+            .unwrap();
+        let engine = EngineHandle::new_sqlite(pool);
+
+        engine
+            .execute_sql("CREATE TABLE typed_bind_check (id integer primary key, name text)")
+            .await
+            .unwrap();
+
+        let inserted = engine
+            .execute_sql_with_binds(
+                "INSERT INTO typed_bind_check (id, name) VALUES (?, ?)",
+                &[
+                    EngineBindValue::I64(7),
+                    EngineBindValue::String("ferro".to_string()),
+                ],
+            )
+            .await
+            .unwrap();
+
+        assert_eq!(inserted, 1);
+    }
+
+    #[tokio::test]
+    async fn engine_handle_fetches_sqlite_rows_with_bound_values() {
+        let pool = SqlitePoolOptions::new()
+            .max_connections(1)
+            .connect("sqlite::memory:")
+            .await
+            .unwrap();
+        let engine = EngineHandle::new_sqlite(pool);
+
+        engine
+            .execute_sql("CREATE TABLE typed_fetch_check (id integer primary key, name text)")
+            .await
+            .unwrap();
+        engine
+            .execute_sql_with_binds(
+                "INSERT INTO typed_fetch_check (id, name) VALUES (?, ?)",
+                &[
+                    EngineBindValue::I64(7),
+                    EngineBindValue::String("ferro".to_string()),
+                ],
+            )
+            .await
+            .unwrap();
+
+        let rows = engine
+            .fetch_all_sql_with_binds(
+                "SELECT id, name FROM typed_fetch_check WHERE id = ?",
+                &[EngineBindValue::I64(7)],
+            )
+            .await
+            .unwrap();
+
+        assert_eq!(rows.len(), 1);
+        assert_eq!(rows[0].values[0], ("id".to_string(), EngineValue::I64(7)));
+        assert_eq!(
+            rows[0].values[1],
+            ("name".to_string(), EngineValue::String("ferro".to_string()))
+        );
+    }
+
+    #[tokio::test]
+    async fn engine_handle_execute_result_includes_sqlite_last_insert_id() {
+        let pool = SqlitePoolOptions::new()
+            .max_connections(1)
+            .connect("sqlite::memory:")
+            .await
+            .unwrap();
+        let engine = EngineHandle::new_sqlite(pool);
+
+        engine
+            .execute_sql("CREATE TABLE typed_insert_result (id integer primary key, name text)")
+            .await
+            .unwrap();
+
+        let result = engine
+            .execute_sql_with_binds_result(
+                "INSERT INTO typed_insert_result (name) VALUES (?)",
+                &[EngineBindValue::String("ferro".to_string())],
+            )
+            .await
+            .unwrap();
+
+        assert_eq!(result.rows_affected, 1);
+        assert_eq!(result.last_insert_id, Some(1));
+    }
+
+    #[tokio::test]
+    async fn engine_handle_commits_sqlite_transaction_connection() {
+        let pool = SqlitePoolOptions::new()
+            .max_connections(1)
+            .connect("sqlite::memory:")
+            .await
+            .unwrap();
+        let engine = EngineHandle::new_sqlite(pool);
+
+        engine
+            .execute_sql("CREATE TABLE typed_tx_check (id integer primary key, name text)")
+            .await
+            .unwrap();
+
+        let mut tx = engine.begin_transaction_connection().await.unwrap();
+        tx.execute_sql_with_binds(
+            "INSERT INTO typed_tx_check (name) VALUES (?)",
+            &[EngineBindValue::String("ferro".to_string())],
+        )
+        .await
+        .unwrap();
+        tx.commit().await.unwrap();
+        drop(tx);
+
+        let rows = engine
+            .fetch_all_sql_with_binds("SELECT name FROM typed_tx_check", &[])
+            .await
+            .unwrap();
+        assert_eq!(
+            rows[0].values[0],
+            ("name".to_string(), EngineValue::String("ferro".to_string()))
+        );
+    }
+
+    #[tokio::test]
+    async fn engine_handle_rolls_back_sqlite_transaction_connection() {
+        let pool = SqlitePoolOptions::new()
+            .max_connections(1)
+            .connect("sqlite::memory:")
+            .await
+            .unwrap();
+        let engine = EngineHandle::new_sqlite(pool);
+
+        engine
+            .execute_sql("CREATE TABLE typed_tx_rollback_check (id integer primary key, name text)")
+            .await
+            .unwrap();
+
+        let mut tx = engine.begin_transaction_connection().await.unwrap();
+        tx.execute_sql_with_binds(
+            "INSERT INTO typed_tx_rollback_check (name) VALUES (?)",
+            &[EngineBindValue::String("ferro".to_string())],
+        )
+        .await
+        .unwrap();
+        tx.rollback().await.unwrap();
+        drop(tx);
+
+        let rows = engine
+            .fetch_all_sql_with_binds("SELECT name FROM typed_tx_rollback_check", &[])
+            .await
+            .unwrap();
+        assert!(rows.is_empty());
+    }
+
+    #[test]
+    fn engine_value_converts_integer_like_values_to_i64() {
+        assert_eq!(EngineValue::I64(42).as_i64(), Some(42));
+        assert_eq!(EngineValue::String("42".to_string()).as_i64(), Some(42));
+        assert_eq!(EngineValue::Null.as_i64(), None);
+    }
 }
diff --git a/src/connection.rs b/src/connection.rs
index 648ea67..97cc750 100644
--- a/src/connection.rs
+++ b/src/connection.rs
@@ -7,7 +7,8 @@ use crate::backend::{BackendKind, EngineHandle};
 use crate::schema::internal_create_tables;
 use crate::state::{ENGINE, IDENTITY_MAP};
 use pyo3::prelude::*;
-use sqlx::any::AnyPoolOptions;
+use sqlx::postgres::PgPoolOptions;
+use sqlx::sqlite::SqlitePoolOptions;
 use std::sync::Arc;
 
 fn split_search_path(url: &str) -> (String, Option<String>) {
@@ -42,6 +43,40 @@ fn is_safe_search_path(search_path: &str) -> bool {
         .all(|ch| ch.is_ascii_alphanumeric() || ch == '_')
 }
 
+async fn connect_engine_handle(
+    connection_url: &str,
+    backend: BackendKind,
+    search_path: Option<String>,
+) -> Result<EngineHandle, sqlx::Error> {
+    match backend {
+        BackendKind::Sqlite => {
+            let pool = SqlitePoolOptions::new()
+                .max_connections(5)
+                .connect(connection_url)
+                .await?;
+            Ok(EngineHandle::new_sqlite(pool))
+        }
+        BackendKind::Postgres => {
+            let mut pool_options = PgPoolOptions::new().max_connections(5);
+            if let
Some(search_path) = search_path { + let set_search_path_sql = Arc::new(format!("SET search_path TO {}", search_path)); + pool_options = pool_options.after_connect(move |conn, _meta| { + let set_search_path_sql = set_search_path_sql.clone(); + Box::pin(async move { + sqlx::query(set_search_path_sql.as_str()) + .execute(conn) + .await?; + Ok(()) + }) + }); + } + + let pool = pool_options.connect(connection_url).await?; + Ok(EngineHandle::new_postgres(pool)) + } + } +} + /// Initializes the global database connection pool. /// /// This is an asynchronous function that returns a Python coroutine. @@ -73,31 +108,16 @@ pub fn connect(py: Python<'_>, url: String, auto_migrate: bool) -> PyResult PyResult<()> { IDENTITY_MAP.clear(); Ok(()) } + +#[cfg(test)] +mod tests { + use super::connect_engine_handle; + use crate::backend::BackendKind; + + #[tokio::test] + async fn connect_engine_handle_uses_typed_sqlite_backend() { + let engine = connect_engine_handle("sqlite::memory:", BackendKind::Sqlite, None) + .await + .unwrap(); + + assert_eq!(engine.backend(), BackendKind::Sqlite); + assert!(engine.sqlite_pool().is_some()); + assert!(engine.postgres_pool().is_none()); + } + + #[tokio::test] + async fn connect_engine_handle_supports_sqlite_runtime_execution() { + let engine = connect_engine_handle("sqlite::memory:", BackendKind::Sqlite, None) + .await + .unwrap(); + + assert_eq!(engine.backend(), BackendKind::Sqlite); + assert!(engine.sqlite_pool().is_some()); + assert_eq!(engine.execute_sql("SELECT 1").await.unwrap(), 0); + } +} diff --git a/src/ferro/migrations/alembic.py b/src/ferro/migrations/alembic.py index 42e5f1e..7592aa4 100644 --- a/src/ferro/migrations/alembic.py +++ b/src/ferro/migrations/alembic.py @@ -8,8 +8,6 @@ except ImportError: sa = None -from .._annotation_utils import annotation_allows_none -from ..base import ForeignKey, foreign_key_allows_none from ..schema_metadata import build_model_schema from ..state import _JOIN_TABLE_REGISTRY, _MODEL_REGISTRY_PY @@ -79,7 +77,13 @@ def _resolve_ref(schema: Dict[str, Any], col_info: Dict[str, Any]) -> Dict[str, ref_path = col_info["$ref"] if ref_path.startswith("#/$defs/"): def_name = ref_path.split("/")[-1] - return schema.get("$defs", {}).get(def_name, col_info) + resolved = schema.get("$defs", {}).get(def_name, col_info) + if resolved is col_info: + return col_info + return { + **resolved, + **{k: v for k, v in col_info.items() if k != "$ref"}, + } return col_info def _strip_optional_union(annotation: Any) -> Any: @@ -136,10 +140,7 @@ def _infer_nullable_join_table( def _resolve_sa_column_nullable( - col_name: str, - col_info: Dict[str, Any], - model_cls: type[Any] | None, - required_fields: list[str], + col_name: str, col_info: Dict[str, Any], required_fields: list[str] ) -> bool: """SQLAlchemy ``Column.nullable`` for one table column.""" if col_info.get("primary_key"): @@ -149,25 +150,6 @@ def _resolve_sa_column_nullable( if isinstance(override, bool): return override - if model_cls is not None: - model_fields = getattr(model_cls, "model_fields", None) - fk_info = col_info.get("foreign_key") or {} - if model_fields and fk_info and col_name.endswith("_id"): - rel_name = col_name[:-3] - rels = getattr(model_cls, "ferro_relations", {}) - meta = rels.get(rel_name) - if isinstance(meta, ForeignKey): - fk_nullable = foreign_key_allows_none(meta) - if fk_nullable is not None: - return fk_nullable - rel_field = model_fields.get(rel_name) - if rel_field is not None: - return annotation_allows_none(rel_field.annotation) - if model_fields: - 
field_info = model_fields.get(col_name) - if field_info is not None: - return annotation_allows_none(field_info.annotation) - return _infer_nullable_join_table(col_name, col_info, required_fields) @@ -190,9 +172,7 @@ def _build_sa_table( python_enum = _field_python_enum(model_cls, col_name) sa_type = _map_to_sa_type(schema, col_info, col_name, python_enum) - is_nullable = _resolve_sa_column_nullable( - col_name, col_info, model_cls, required_fields - ) + is_nullable = _resolve_sa_column_nullable(col_name, col_info, required_fields) fk_info = col_info.get("foreign_key") or {} column_unique = bool(col_info.get("unique")) or bool(fk_info.get("unique")) @@ -261,7 +241,7 @@ def _map_to_sa_type( item = _resolve_ref(schema, item) if item.get("type") != "null": json_type = item.get("type") - format = item.get("format") + format = item.get("format") or format enum_values = item.get("enum") or enum_values break diff --git a/src/ferro/query/builder.py b/src/ferro/query/builder.py index d7babfc..fb2c3c2 100644 --- a/src/ferro/query/builder.py +++ b/src/ferro/query/builder.py @@ -12,6 +12,7 @@ remove_m2m_links, update_filtered, ) +from .nodes import _serialize_query_value if TYPE_CHECKING: from .nodes import QueryNode @@ -19,6 +20,11 @@ T = TypeVar("T") +def _query_def_to_json(query_def: dict[str, Any]) -> str: + """Serialize query definitions while preserving typed values in live Query state.""" + return json.dumps(_serialize_query_value(query_def)) + + class Query(Generic[T]): """Build and execute fluent ORM queries. @@ -158,7 +164,7 @@ async def all(self) -> list[T]: from ..state import _CURRENT_TRANSACTION tx_id = _CURRENT_TRANSACTION.get() - results = await fetch_filtered(self.model_cls, json.dumps(query_def), tx_id) + results = await fetch_filtered(self.model_cls, _query_def_to_json(query_def), tx_id) for instance in results: if hasattr(self.model_cls, "_fix_types"): self.model_cls._fix_types(instance) @@ -184,7 +190,7 @@ async def count(self) -> int: tx_id = _CURRENT_TRANSACTION.get() return await count_filtered( - self.model_cls.__name__, json.dumps(query_def), tx_id + self.model_cls.__name__, _query_def_to_json(query_def), tx_id ) async def update(self, **fields) -> int: @@ -215,7 +221,7 @@ async def update(self, **fields) -> int: # Use pydantic_core.to_json to handle Decimals, UUIDs, etc. 
in kwargs return await update_filtered( self.model_cls.__name__, - json.dumps(query_def), + _query_def_to_json(query_def), to_json(fields).decode(), tx_id, ) @@ -260,7 +266,7 @@ async def delete(self) -> int: tx_id = _CURRENT_TRANSACTION.get() return await delete_filtered( - self.model_cls.__name__, json.dumps(query_def), tx_id + self.model_cls.__name__, _query_def_to_json(query_def), tx_id ) async def exists(self) -> bool: diff --git a/src/ferro/relations/__init__.py b/src/ferro/relations/__init__.py index a1b0d76..80fa372 100644 --- a/src/ferro/relations/__init__.py +++ b/src/ferro/relations/__init__.py @@ -109,6 +109,7 @@ def resolve_relationships(): "properties": { source_col: { **source_schema, + "ferro_nullable": False, "foreign_key": { "to_table": model_name.lower(), "on_delete": "CASCADE", @@ -116,12 +117,14 @@ def resolve_relationships(): }, target_col: { **target_schema, + "ferro_nullable": False, "foreign_key": { "to_table": target_model.__name__.lower(), "on_delete": "CASCADE", }, }, }, + "required": [source_col, target_col], "ferro_composite_uniques": [[source_col, target_col]], } register_model_schema(join_table, json.dumps(join_schema)) diff --git a/src/ferro/schema_metadata.py b/src/ferro/schema_metadata.py index c3115a7..da41c59 100644 --- a/src/ferro/schema_metadata.py +++ b/src/ferro/schema_metadata.py @@ -3,6 +3,7 @@ from __future__ import annotations import types +from decimal import Decimal from enum import Enum from typing import ( Annotated, @@ -50,6 +51,16 @@ def _enum_subclass_from_annotation(hint: Any) -> type[Enum] | None: return None +def _annotation_is_decimal(hint: Any) -> bool: + hint = _strip_optional_union(hint) + if get_origin(hint) is Annotated: + args = get_args(hint) + if args: + return _annotation_is_decimal(args[0]) + return False + return hint is Decimal + + def _target_table_name(target: Any) -> str: if isinstance(target, ForwardRef): return target.__forward_arg__.lower() @@ -126,7 +137,12 @@ def build_model_schema( for field_name, finfo in model_fields.items(): if field_name not in properties or not isinstance(properties[field_name], dict): continue + properties[field_name].setdefault( + "ferro_nullable", annotation_allows_none(finfo.annotation) + ) ann_hint = resolved_annotations.get(field_name, finfo.annotation) + if _annotation_is_decimal(ann_hint): + properties[field_name]["format"] = "decimal" enum_cls = _enum_subclass_from_annotation(ann_hint) if enum_cls is not None: properties[field_name]["enum_type_name"] = enum_cls.__name__.lower() diff --git a/src/operations.rs b/src/operations.rs index 5ed4aca..0939c78 100644 --- a/src/operations.rs +++ b/src/operations.rs @@ -3,132 +3,241 @@ //! This module implements high-performance CRUD operations, leveraging //! GIL-free parsing and zero-copy Direct Injection into Python objects. 
+use crate::backend::{EngineBindValue, EngineHandle, EngineRow, EngineValue};
 use crate::query::QueryDef;
 use crate::state::{
-    engine_pool, IDENTITY_MAP, MODEL_REGISTRY, RustValue, TRANSACTION_REGISTRY, TransactionHandle,
+    IDENTITY_MAP, MODEL_REGISTRY, RustValue, SqlDialect, TRANSACTION_REGISTRY,
+    TransactionConnection, TransactionHandle, engine_handle,
 };
 use pyo3::prelude::*;
 use sea_query::{
-    Alias, Expr, Iden, InsertStatement, OnConflict, Order, PostgresQueryBuilder, Query,
-    SimpleExpr, SqliteQueryBuilder, UpdateStatement, Value as SeaValue,
+    Alias, Expr, Iden, InsertStatement, OnConflict, Order, PostgresQueryBuilder, Query, SimpleExpr,
+    SqliteQueryBuilder, UpdateStatement, Value as SeaValue,
 };
-use sqlx::{Any, AnyConnection, Column, Pool, Row};
 use std::collections::{HashMap, HashSet};
 use std::sync::Arc;
-use tokio::sync::Mutex;
-
-macro_rules! get_conn {
-    ($pool:expr, $tx_id:expr) => {{
-        if let Some(tx_id) = $tx_id {
-            if let Some(tx_handle) = TRANSACTION_REGISTRY.get(&tx_id) {
-                let conn = tx_handle.value().conn.clone();
-                Some(conn)
-            } else {
-                None
-            }
-        } else {
-            None
-        }
-    }};
+
+fn get_transaction_connection(tx_id: Option<u64>) -> Option<TransactionConnection> {
+    tx_id.and_then(|id| {
+        TRANSACTION_REGISTRY
+            .get(&id)
+            .map(|tx| tx.value().conn.clone())
+    })
 }
 
-/// Build SQL with the dialect set at `connect()` time (`?` for SQLite, `$n` for Postgres).
-macro_rules! sea_query_build {
-    ($stmt:expr) => {{
-        match crate::state::sql_dialect() {
-            crate::state::SqlDialect::Sqlite => $stmt.build(SqliteQueryBuilder),
-            crate::state::SqlDialect::Postgres => $stmt.build(PostgresQueryBuilder),
-        }
-    }};
+fn active_engine() -> PyResult<Arc<EngineHandle>> {
+    let engine = engine_handle()
+        .ok_or_else(|| pyo3::exceptions::PyRuntimeError::new_err("Engine not initialized"))?;
+    Ok(engine)
 }
 
-macro_rules! sea_query_to_string {
-    ($stmt:expr) => {{
-        match crate::state::sql_dialect() {
-            crate::state::SqlDialect::Sqlite => $stmt.to_string(SqliteQueryBuilder),
-            crate::state::SqlDialect::Postgres => $stmt.to_string(PostgresQueryBuilder),
+fn engine_bind_values_from_sea(values: &[SeaValue]) -> Vec<EngineBindValue> {
+    values
+        .iter()
+        .map(|val| match val {
+            SeaValue::Bool(Some(b)) => EngineBindValue::Bool(*b),
+            SeaValue::TinyInt(Some(i)) => EngineBindValue::I64(*i as i64),
+            SeaValue::SmallInt(Some(i)) => EngineBindValue::I64(*i as i64),
+            SeaValue::Int(Some(i)) => EngineBindValue::I64(*i as i64),
+            SeaValue::BigInt(Some(i)) => EngineBindValue::I64(*i),
+            SeaValue::BigUnsigned(Some(i)) => EngineBindValue::I64(*i as i64),
+            SeaValue::Float(Some(f)) => EngineBindValue::F64(*f as f64),
+            SeaValue::Double(Some(f)) => EngineBindValue::F64(*f),
+            SeaValue::String(Some(s)) => EngineBindValue::String(s.as_ref().clone()),
+            SeaValue::Char(Some(c)) => EngineBindValue::String(c.to_string()),
+            SeaValue::Bytes(Some(b)) => EngineBindValue::Bytes(b.as_ref().clone()),
+            _ => EngineBindValue::Null,
+        })
+        .collect()
+}
+
+async fn execute_statement_with_optional_tx(
+    engine: &EngineHandle,
+    tx_conn: Option<TransactionConnection>,
+    sql: &str,
+    bind_values: &[SeaValue],
+) -> Result<u64, sqlx::Error> {
+    match tx_conn {
+        Some(conn_arc) => {
+            let engine_bind_values = engine_bind_values_from_sea(bind_values);
+            let mut conn = conn_arc.lock().await;
+            conn.execute_sql_with_binds(sql, &engine_bind_values).await
         }
-    }};
+        None => {
+            let engine_bind_values = engine_bind_values_from_sea(bind_values);
+            engine
+                .execute_sql_with_binds(sql, &engine_bind_values)
+                .await
+        }
+    }
 }
 
-macro_rules! decode_column {
-    ($row:expr, $name:expr, $col_name:expr) => {{
-        let registry = MODEL_REGISTRY.read().unwrap();
-        let prop = registry
-            .get($name)
-            .and_then(|s| s.get("properties"))
-            .and_then(|p| p.get($col_name));
-
-        let format = prop.and_then(property_format);
-        let is_decimal = prop
-            .and_then(|p| p.get("anyOf"))
-            .and_then(|a| a.as_array())
-            .map(|types| {
-                let has_number = types
-                    .iter()
-                    .any(|t| t.get("type").and_then(|ty| ty.as_str()) == Some("number"));
-                let has_patterned_string = types.iter().any(|t| {
-                    t.get("type").and_then(|ty| ty.as_str()) == Some("string")
-                        && t.get("pattern").is_some()
-                });
-                has_number && has_patterned_string
-            })
-            .unwrap_or(false);
+async fn execute_transaction_sql(
+    tx_conn: &TransactionConnection,
+    sql: &str,
+) -> Result<u64, sqlx::Error> {
+    let mut conn = tx_conn.lock().await;
+    conn.execute_sql(sql).await
+}
 
-        let json_type = prop.and_then(property_json_type);
+fn engine_value_to_rust_value(
+    value: EngineValue,
+    schema: &serde_json::Value,
+    col_name: &str,
+) -> RustValue {
+    let prop = schema
+        .get("properties")
+        .and_then(|p| p.get(col_name))
+        .map(|col_info| resolve_ref(schema, col_info));
 
-        if is_decimal {
-            if let Ok(v) = $row.try_get::<f64, _>($col_name) {
-                RustValue::Decimal(v.to_string())
-            } else if let Ok(v) = $row.try_get::<String, _>($col_name) {
-                RustValue::Decimal(v)
-            } else {
-                RustValue::None
-            }
-        } else if format == Some("binary") {
-            if let Ok(v) = $row.try_get::<Vec<u8>, _>($col_name) {
-                RustValue::Blob(v)
-            } else if let Ok(v) = $row.try_get::<String, _>($col_name) {
-                RustValue::Blob(v.into_bytes())
-            } else {
-                RustValue::None
-            }
-        } else if let Ok(val) = $row.try_get::<i64, _>($col_name) {
-            let is_bool = json_type == Some("boolean");
-            if is_bool {
-                RustValue::Bool(val != 0)
-            } else {
-                RustValue::BigInt(val)
+    let format = prop.and_then(property_format);
+    let is_decimal = prop
+        .and_then(|p| p.get("anyOf"))
+        .and_then(|a| a.as_array())
+        .map(|types| {
+            let has_number = types
+                .iter()
+                .any(|t| t.get("type").and_then(|ty| ty.as_str()) == Some("number"));
+            let has_patterned_string = types.iter().any(|t| {
+                t.get("type").and_then(|ty| ty.as_str()) == Some("string")
+                    && t.get("pattern").is_some()
+            });
+            has_number && has_patterned_string
+        })
+        .unwrap_or(false);
+    let json_type = prop.and_then(property_json_type);
+
+    if is_decimal {
+        return match value {
+            EngineValue::F64(v) => RustValue::Decimal(v.to_string()),
+            EngineValue::String(v) => RustValue::Decimal(v),
+            _ => RustValue::None,
+        };
+    }
+
+    if format == Some("binary") {
+        return match value {
+            EngineValue::Bytes(v) => RustValue::Blob(v),
+            EngineValue::String(v) => RustValue::Blob(v.into_bytes()),
+            _ => RustValue::None,
+        };
+    }
+
+    match value {
+        EngineValue::I64(v) if json_type == Some("boolean") => RustValue::Bool(v != 0),
+        EngineValue::I64(v) => RustValue::BigInt(v),
+        EngineValue::F64(v) => RustValue::Double(v),
+        EngineValue::Bytes(v) => RustValue::Blob(v),
+        EngineValue::String(v) => match (json_type, format) {
+            (_, Some("date-time")) => RustValue::DateTime(v),
+            (_, Some("date")) => RustValue::Date(v),
+            (_, Some("uuid")) => RustValue::Uuid(v),
+            (Some("object"), _) | (Some("array"), _) => {
+                if let Ok(json_val) = serde_json::from_str(&v) {
+                    RustValue::Json(json_val)
+                } else {
+                    RustValue::String(v)
+                }
             }
-        } else if let Ok(val) = $row.try_get::<f64, _>($col_name) {
-            RustValue::Double(val)
-        } else if let Ok(val) = $row.try_get::<Vec<u8>, _>($col_name) {
-            RustValue::Blob(val)
-        } else if let Ok(val) = $row.try_get::<String, _>($col_name) {
-            match (json_type, format) {
-                (_, Some("date-time")) => RustValue::DateTime(val),
-                (_, Some("date")) => RustValue::Date(val),
-                (_, Some("uuid")) => RustValue::Uuid(val),
-                (Some("object"), _) | (Some("array"), _) => {
-                    if let Ok(json_val) = serde_json::from_str(&val) {
-                        RustValue::Json(json_val)
-                    } else {
-                        RustValue::String(val)
-                    }
+            _ => RustValue::String(v),
+        },
+        EngineValue::Bool(v) => RustValue::Bool(v),
+        EngineValue::Null => RustValue::None,
+    }
+}
+
+fn typed_rows_to_parsed_data(
+    rows: Vec<EngineRow>,
+    schema: &serde_json::Value,
+    pk_col: Option<&str>,
+) -> Vec<(Option<String>, Vec<(String, RustValue)>)> {
+    rows.into_iter()
+        .map(|row| {
+            let mut row_pk_val = None;
+            let mut fields = Vec::with_capacity(row.values.len());
+
+            for (col_name, value) in row.values {
+                if pk_col == Some(col_name.as_str()) {
+                    row_pk_val = match &value {
+                        EngineValue::I64(v) => Some(v.to_string()),
+                        EngineValue::String(v) => Some(v.clone()),
+                        _ => None,
+                    };
                 }
-                _ => RustValue::String(val),
+                let value = engine_value_to_rust_value(value, schema, &col_name);
+                fields.push((col_name, value));
             }
-        } else if let Ok(val) = $row.try_get::<bool, _>($col_name) {
-            RustValue::Bool(val)
-        } else {
-            RustValue::None
+
+            (row_pk_val, fields)
+        })
+        .collect()
+}
+
+fn engine_row_string(row: &EngineRow, column_name: &str) -> Option<String> {
+    row.values
+        .iter()
+        .find(|(name, _)| name == column_name)
+        .and_then(|(_, value)| match value {
+            EngineValue::String(value) => Some(value.clone()),
+            EngineValue::I64(value) => Some(value.to_string()),
+            _ => None,
+        })
+}
+
+async fn postgres_catalog_rows(
+    engine: &EngineHandle,
+    tx_conn: &Option<TransactionConnection>,
+    sql: &str,
+    table_name: &str,
+    label: &str,
+) -> PyResult<Vec<EngineRow>> {
+    let values = [EngineBindValue::String(table_name.to_string())];
+    let rows = match tx_conn {
+        Some(conn_arc) => {
+            let mut conn = conn_arc.lock().await;
+            conn.fetch_all_sql_with_binds(sql, &values)
+                .await
+                .map_err(|e| {
+                    pyo3::exceptions::PyRuntimeError::new_err(format!(
+                        "Failed to inspect {} for '{}': {}",
+                        label, table_name, e
+                    ))
+                })?
+        }
+        None => engine
+            .fetch_all_sql_with_binds(sql, &values)
+            .await
+            .map_err(|e| {
+                pyo3::exceptions::PyRuntimeError::new_err(format!(
+                    "Failed to inspect {} for '{}': {}",
+                    label, table_name, e
+                ))
+            })?,
+    };
+
+    Ok(rows)
+}
+
+macro_rules! sea_query_build_for_backend {
+    ($stmt:expr, $backend:expr) => {{
+        match $backend {
+            crate::state::SqlDialect::Sqlite => $stmt.build(SqliteQueryBuilder),
+            crate::state::SqlDialect::Postgres => $stmt.build(PostgresQueryBuilder),
+        }
+    }};
+}
+
+macro_rules! sea_query_to_string_for_backend {
+    ($stmt:expr, $backend:expr) => {{
+        match $backend {
+            crate::state::SqlDialect::Sqlite => $stmt.to_string(SqliteQueryBuilder),
+            crate::state::SqlDialect::Postgres => $stmt.to_string(PostgresQueryBuilder),
         }
     }};
 }
 
-/// On Postgres, `sqlx::Any` cannot decode native `uuid` columns. When the model schema marks
-/// UUID fields, expand `SELECT *` into explicit columns with `::text` casts so decoding uses
-/// text (same representation as SQLite).
+/// On Postgres, cast text-like special columns in SELECT output so Python hydration
+/// sees the same string representation as SQLite.
 fn property_json_type(col_info: &serde_json::Value) -> Option<&str> {
     col_info.get("type").and_then(|t| t.as_str()).or_else(|| {
         col_info
@@ -170,11 +279,12 @@ fn apply_postgres_text_select_columns(
     table_name: &str,
     schema: &serde_json::Value,
     pg_native_enum_columns: &HashSet<String>,
+    backend: SqlDialect,
 ) {
     use sea_query::{Alias, Expr};
 
     let tbl = Alias::new(table_name);
-    if crate::state::sql_dialect() != crate::state::SqlDialect::Postgres {
+    if backend != SqlDialect::Postgres {
         select.column((tbl.clone(), sea_query::Asterisk));
         return;
     }
@@ -186,8 +296,9 @@ fn apply_postgres_text_select_columns(
         let resolved = resolve_ref(schema, col_info);
         matches!(
            property_format(resolved),
-            Some("uuid" | "date-time" | "date")
-        ) || property_is_enum(resolved)
+            Some("uuid" | "date-time" | "date" | "decimal")
+        ) || matches!(property_json_type(resolved), Some("object" | "array"))
+            || property_is_enum(resolved)
     });
     let need_text_from_native_enum = properties
         .keys()
@@ -199,7 +310,10 @@ fn apply_postgres_text_select_columns(
     for (col_name, col_info) in properties {
         let col_iden = Alias::new(col_name.as_str());
         let col_info = resolve_ref(schema, col_info);
-        if matches!(property_format(col_info), Some("uuid" | "date-time" | "date"))
+        if matches!(
+            property_format(col_info),
+            Some("uuid" | "date-time" | "date" | "decimal")
+        ) || matches!(property_json_type(col_info), Some("object" | "array"))
             || property_is_enum(col_info)
             || pg_native_enum_columns.contains(col_name.as_str())
         {
@@ -214,30 +328,6 @@ fn apply_postgres_text_select_columns(
     }
 }
 
-/// Helper to bind Sea-Query values to a SQLx Any query.
-fn bind_query<'a>(
-    mut query: sqlx::query::Query<'a, sqlx::Any, sqlx::any::AnyArguments<'a>>,
-    values: &'a [SeaValue],
-) -> sqlx::query::Query<'a, sqlx::Any, sqlx::any::AnyArguments<'a>> {
-    for val in values {
-        query = match val {
-            SeaValue::Bool(Some(b)) => query.bind(*b),
-            SeaValue::TinyInt(Some(i)) => query.bind(*i as i64),
-            SeaValue::SmallInt(Some(i)) => query.bind(*i as i64),
-            SeaValue::Int(Some(i)) => query.bind(*i as i64),
-            SeaValue::BigInt(Some(i)) => query.bind(*i),
-            SeaValue::BigUnsigned(Some(i)) => query.bind(*i as i64),
-            SeaValue::Float(Some(f)) => query.bind(*f as f64),
-            SeaValue::Double(Some(f)) => query.bind(*f),
-            SeaValue::String(Some(s)) => query.bind(s.as_ref().clone()),
-            SeaValue::Char(Some(c)) => query.bind(c.to_string()),
-            SeaValue::Bytes(Some(b)) => query.bind(b.as_ref().clone()),
-            _ => query.bind(Option::<Vec<u8>>::None),
-        };
-    }
-    query
-}
-
 fn resolve_ref<'a>(
     schema: &'a serde_json::Value,
     col_info: &'a serde_json::Value,
 ) -> &'a serde_json::Value {
@@ -254,15 +344,15 @@
 /// Maps each table column to its PostgreSQL enum `typname` (``typtype = 'e'``) for the current schema.
 async fn postgres_enum_udt_by_column(
     table_name: &str,
-    pool: &Arc<Pool<Any>>,
-    tx_conn: Option<Arc<Mutex<AnyConnection>>>,
+    engine: &EngineHandle,
+    tx_conn: &Option<TransactionConnection>,
+    backend: SqlDialect,
 ) -> PyResult<HashMap<String, String>> {
-    if crate::state::sql_dialect() != crate::state::SqlDialect::Postgres {
+    if backend != SqlDialect::Postgres {
         return Ok(HashMap::new());
     }
 
-    let query = sqlx::query(
-        r#"
+    let sql = r#"
         SELECT a.attname::text AS column_name, t.typname::text AS udt_name
         FROM pg_attribute a
         JOIN pg_class c ON a.attrelid = c.oid
@@ -273,27 +363,12 @@ async fn postgres_enum_udt_by_column(
           AND t.typtype = 'e'
           AND a.attnum > 0
           AND NOT a.attisdropped
-        "#,
-    )
-    .bind(table_name);
-
-    let rows = if let Some(conn_arc) = tx_conn {
-        let mut conn = conn_arc.lock().await;
-        query.fetch_all(&mut *conn).await
-    } else {
-        query.fetch_all(pool.as_ref()).await
-    }
-    .map_err(|e| {
-        pyo3::exceptions::PyRuntimeError::new_err(format!(
-            "Failed to inspect enum columns for '{}': {}",
-            table_name, e
-        ))
-    })?;
+        "#;
 
     let mut out = HashMap::new();
-    for row in rows {
-        let column_name: String = row.try_get("column_name").unwrap_or_default();
-        let udt_name: String = row.try_get("udt_name").unwrap_or_default();
+    for row in postgres_catalog_rows(engine, tx_conn, sql, table_name, "enum columns").await? {
+        let column_name = engine_row_string(&row, "column_name").unwrap_or_default();
+        let udt_name = engine_row_string(&row, "udt_name").unwrap_or_default();
         if !column_name.is_empty() && !udt_name.is_empty() {
             out.insert(column_name, udt_name);
         }
@@ -305,60 +380,46 @@
 /// Column names on `table_name` in the current schema whose SQL type is `uuid`.
 async fn postgres_uuid_column_names(
     table_name: &str,
-    pool: &Arc<Pool<Any>>,
-    tx_conn: Option<Arc<Mutex<AnyConnection>>>,
+    engine: &EngineHandle,
+    tx_conn: &Option<TransactionConnection>,
+    backend: SqlDialect,
 ) -> PyResult<HashSet<String>> {
-    if crate::state::sql_dialect() != crate::state::SqlDialect::Postgres {
+    if backend != SqlDialect::Postgres {
         return Ok(HashSet::new());
     }
 
-    let query = sqlx::query(
-        r#"
+    let sql = r#"
         SELECT column_name::text
         FROM information_schema.columns
         WHERE table_schema = current_schema()
           AND table_name = $1
          AND (data_type = 'uuid' OR udt_name = 'uuid')
-        "#,
+        "#;
+
+    Ok(
+        postgres_catalog_rows(engine, tx_conn, sql, table_name, "uuid columns")
+            .await?
+            .into_iter()
+            .filter_map(|row| {
+                engine_row_string(&row, "column_name").filter(|name| !name.is_empty())
+            })
+            .collect(),
     )
-    .bind(table_name);
-
-    let rows = if let Some(conn_arc) = tx_conn {
-        let mut conn = conn_arc.lock().await;
-        query.fetch_all(&mut *conn).await
-    } else {
-        query.fetch_all(pool.as_ref()).await
-    }
-    .map_err(|e| {
-        pyo3::exceptions::PyRuntimeError::new_err(format!(
-            "Failed to inspect uuid columns for '{}': {}",
-            table_name, e
-        ))
-    })?;
-
-    Ok(rows
-        .into_iter()
-        .filter_map(|r| {
-            r.try_get::<String, _>("column_name")
-                .ok()
-                .filter(|n| !n.is_empty())
-        })
-        .collect())
 }
 
 /// For each column whose SQL type is a date or timestamp family, the ``CAST ( … AS … )`` target
 /// (``date``, ``timestamp``, ``timestamptz``) so parameters are not sent as untyped text.
async fn postgres_temporal_cast_by_column( table_name: &str, - pool: &Arc>, - tx_conn: Option>>, + engine: &EngineHandle, + tx_conn: &Option, + backend: SqlDialect, ) -> PyResult> { - if crate::state::sql_dialect() != crate::state::SqlDialect::Postgres { + if backend != SqlDialect::Postgres { return Ok(HashMap::new()); } - let query = sqlx::query( - r#" + let sql = r#" SELECT column_name::text, CASE data_type::text WHEN 'timestamp without time zone' THEN 'timestamp' @@ -374,31 +435,14 @@ async fn postgres_temporal_cast_by_column( 'timestamp with time zone', 'date' ) - "#, - ) - .bind(table_name); - - let rows = if let Some(conn_arc) = tx_conn { - let mut conn = conn_arc.lock().await; - query.fetch_all(&mut *conn).await - } else { - query.fetch_all(pool.as_ref()).await - } - .map_err(|e| { - pyo3::exceptions::PyRuntimeError::new_err(format!( - "Failed to inspect temporal columns for '{}': {}", - table_name, e - )) - })?; + "#; let mut out = HashMap::new(); - for row in rows { - let col: String = row.try_get("column_name").unwrap_or_default(); - let ct: Option = row.try_get("cast_type").ok(); - if let (cn, Some(cast)) = (col, ct) { - if !cn.is_empty() && !cast.is_empty() { - out.insert(cn, cast); - } + for row in postgres_catalog_rows(engine, tx_conn, sql, table_name, "temporal columns").await? { + let column_name = engine_row_string(&row, "column_name").unwrap_or_default(); + let cast_type = engine_row_string(&row, "cast_type").unwrap_or_default(); + if !column_name.is_empty() && !cast_type.is_empty() { + out.insert(column_name, cast_type); } } Ok(out) @@ -434,36 +478,34 @@ fn schema_value_expr( enum_udt: &HashMap, uuid_columns: &HashSet, ts_cast: &HashMap, + backend: SqlDialect, ) -> SimpleExpr { let col_info = schema_property(schema, col_name); if let serde_json::Value::String(s) = value - && crate::state::sql_dialect() == crate::state::SqlDialect::Postgres + && backend == SqlDialect::Postgres && let Some(tn) = postgres_enum_type_name_for_column(col_name, enum_udt, col_info) { return Expr::value(sea_query::Value::String(Some(Box::new(s.clone())))) .cast_as(Alias::new(tn.as_str())); } - if value.is_null() - && crate::state::sql_dialect() == crate::state::SqlDialect::Postgres - && uuid_columns.contains(col_name) - { + if value.is_null() && backend == SqlDialect::Postgres && uuid_columns.contains(col_name) { return Expr::value(sea_query::Value::String(None)).cast_as("uuid"); } if let serde_json::Value::String(s) = value - && crate::state::sql_dialect() == crate::state::SqlDialect::Postgres + && backend == SqlDialect::Postgres && uuid_columns.contains(col_name) && uuid::Uuid::parse_str(s).is_ok() { return Expr::value(sea_query::Value::String(Some(Box::new(s.clone())))).cast_as("uuid"); } if value.is_null() - && crate::state::sql_dialect() == crate::state::SqlDialect::Postgres + && backend == SqlDialect::Postgres && let Some(cast) = ts_cast.get(col_name) { return Expr::value(sea_query::Value::String(None)).cast_as(Alias::new(cast.as_str())); } if let serde_json::Value::String(s) = value - && crate::state::sql_dialect() == crate::state::SqlDialect::Postgres + && backend == SqlDialect::Postgres && let Some(cast) = ts_cast.get(col_name) { return Expr::value(sea_query::Value::String(Some(Box::new(s.clone())))) @@ -478,22 +520,28 @@ fn schema_value_expr( .unwrap_or(false); match value { + value + if backend == SqlDialect::Postgres && matches!(json_type, Some("object" | "array")) => + { + if value.is_null() { + Expr::value(sea_query::Value::String(None)).cast_as("json") + } else { + 
Expr::value(sea_query::Value::String(Some(Box::new(value.to_string())))) + .cast_as("json") + } + } serde_json::Value::String(s) - if crate::state::sql_dialect() == crate::state::SqlDialect::Postgres - && format == Some("uuid") => + if backend == SqlDialect::Postgres && format == Some("uuid") => { Expr::value(sea_query::Value::String(Some(Box::new(s.clone())))).cast_as("uuid") } serde_json::Value::String(s) - if crate::state::sql_dialect() == crate::state::SqlDialect::Postgres - && format == Some("date-time") => + if backend == SqlDialect::Postgres && format == Some("date-time") => { - Expr::value(sea_query::Value::String(Some(Box::new(s.clone())))) - .cast_as("timestamptz") + Expr::value(sea_query::Value::String(Some(Box::new(s.clone())))).cast_as("timestamptz") } serde_json::Value::String(s) - if crate::state::sql_dialect() == crate::state::SqlDialect::Postgres - && format == Some("date") => + if backend == SqlDialect::Postgres && format == Some("date") => { Expr::value(sea_query::Value::String(Some(Box::new(s.clone())))).cast_as("date") } @@ -511,11 +559,9 @@ fn schema_value_expr( Expr::value(sea_query::Value::String(Some(Box::new(s.clone())))) } } - serde_json::Value::String(s) if format == Some("binary") => { - Expr::value(sea_query::Value::Bytes(Some(Box::new( - s.as_bytes().to_vec(), - )))) - } + serde_json::Value::String(s) if format == Some("binary") => Expr::value( + sea_query::Value::Bytes(Some(Box::new(s.as_bytes().to_vec()))), + ), serde_json::Value::String(s) if is_decimal => { if let Ok(parsed) = s.parse::() { Expr::value(sea_query::Value::Double(Some(parsed))) @@ -536,8 +582,7 @@ fn schema_value_expr( Expr::value(sea_query::Value::String(Some(Box::new(s.clone())))) } serde_json::Value::Bool(b) - if json_type == Some("boolean") - && crate::state::sql_dialect() == crate::state::SqlDialect::Sqlite => + if json_type == Some("boolean") && backend == SqlDialect::Sqlite => { Expr::value(sea_query::Value::BigInt(Some(if *b { 1 } else { 0 }))) } @@ -547,6 +592,20 @@ fn schema_value_expr( } } +fn backend_column_value_expr( + col_name: &str, + value: sea_query::Value, + uuid_columns: &HashSet, + backend: SqlDialect, +) -> SimpleExpr { + let expr = Expr::value(value); + if backend == SqlDialect::Postgres && uuid_columns.contains(col_name) { + expr.cast_as("uuid") + } else { + expr + } +} + #[pyfunction] #[pyo3(signature = (parent_tx_id=None))] pub fn begin_transaction( @@ -563,9 +622,7 @@ pub fn begin_transaction( drop(parent); let savepoint_name = format!("sp_{}", tx_id.replace('-', "_")); - let mut locked_conn = conn.lock().await; - sqlx::query(&format!("SAVEPOINT {savepoint_name}")) - .execute(&mut *locked_conn) + execute_transaction_sql(&conn, &format!("SAVEPOINT {savepoint_name}")) .await .map_err(|e| { pyo3::exceptions::PyRuntimeError::new_err(format!( @@ -573,32 +630,18 @@ pub fn begin_transaction( e )) })?; - drop(locked_conn); TRANSACTION_REGISTRY.insert( tx_id.clone(), TransactionHandle::nested(conn, savepoint_name), ); } else { - let pool = engine_pool().ok_or_else(|| { - pyo3::exceptions::PyRuntimeError::new_err("Engine not initialized") + let engine = active_engine()?; + let conn = engine.begin_transaction_connection().await.map_err(|e| { + pyo3::exceptions::PyRuntimeError::new_err(format!("Failed to BEGIN: {}", e)) })?; - let mut conn = pool.acquire().await.map_err(|e| { - pyo3::exceptions::PyRuntimeError::new_err(format!( - "Failed to acquire connection: {}", - e - )) - })?; - - sqlx::query("BEGIN") - .execute(&mut *conn) - .await - .map_err(|e| { - 
pyo3::exceptions::PyRuntimeError::new_err(format!("Failed to BEGIN: {}", e)) - })?; - - TRANSACTION_REGISTRY.insert(tx_id.clone(), TransactionHandle::root(conn.detach())); + TRANSACTION_REGISTRY.insert(tx_id.clone(), TransactionHandle::root(conn)); } Ok(tx_id) @@ -613,20 +656,20 @@ pub fn commit_transaction(py: Python<'_>, tx_id: String) -> PyResult, tx_id: String) -> PyResult( let cls_py = cls.unbind(); pyo3_async_runtimes::tokio::future_into_py(py, async move { - let (pool, tx_conn) = { - let pool = engine_pool().ok_or_else(|| { - pyo3::exceptions::PyRuntimeError::new_err("Engine not initialized") - })?; - let tx_conn = get_conn!(pool, tx_id); - (pool, tx_conn) + let (engine, tx_conn, backend) = { + let engine = active_engine()?; + let backend = engine.backend(); + let tx_conn = get_transaction_connection(tx_id); + (engine, tx_conn, backend) }; let table_name = name.to_lowercase(); let pg_native_enum_cols: HashSet = { - let m = postgres_enum_udt_by_column(&table_name, &pool, tx_conn.clone()).await?; + let m = postgres_enum_udt_by_column(&table_name, &engine, &tx_conn, backend).await?; m.keys().cloned().collect() }; // ... same sql generation ... - let (sql, pk_col) = { + let (sql, pk_col, schema_for_decode) = { let registry = MODEL_REGISTRY.read().map_err(|_| { pyo3::exceptions::PyRuntimeError::new_err("Failed to lock registry") })?; @@ -741,38 +782,38 @@ pub fn fetch_all<'py>( } } let mut stmt = Query::select(); - apply_postgres_text_select_columns(&mut stmt, &table_name, schema, &pg_native_enum_cols); - let s = sea_query_to_string!(stmt.from(Alias::new(&table_name))); - (s, pk) + apply_postgres_text_select_columns( + &mut stmt, + &table_name, + schema, + &pg_native_enum_cols, + backend, + ); + let s = sea_query_to_string_for_backend!(stmt.from(Alias::new(&table_name)), backend); + (s, pk, schema.clone()) }; - let rows = if let Some(conn_arc) = tx_conn { - let mut conn = conn_arc.lock().await; - sqlx::query(&sql).fetch_all(&mut *conn).await - } else { - sqlx::query(&sql).fetch_all(pool.as_ref()).await - } - .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("Fetch failed: {}", e)))?; - - let mut parsed_data = Vec::with_capacity(rows.len()); - for row in rows { - let mut row_pk_val = None; - if let Some(ref pk_name) = pk_col { - if let Ok(val) = row.try_get::(pk_name.as_str()) { - row_pk_val = Some(val.to_string()); - } else if let Ok(val) = row.try_get::(pk_name.as_str()) { - row_pk_val = Some(val); - } + let parsed_data = match tx_conn { + Some(conn_arc) => { + let mut conn = conn_arc.lock().await; + let rows = conn + .fetch_all_sql_with_binds(&sql, &[]) + .await + .map_err(|e| { + pyo3::exceptions::PyRuntimeError::new_err(format!("Fetch failed: {}", e)) + })?; + typed_rows_to_parsed_data(rows, &schema_for_decode, pk_col.as_deref()) } - - let mut fields = Vec::with_capacity(row.columns().len()); - for col in row.columns() { - let col_name = col.name(); - let val = decode_column!(row, &name, col_name); - fields.push((col_name.to_string(), val)); + None => { + let rows = engine + .fetch_all_sql_with_binds(&sql, &[]) + .await + .map_err(|e| { + pyo3::exceptions::PyRuntimeError::new_err(format!("Fetch failed: {}", e)) + })?; + typed_rows_to_parsed_data(rows, &schema_for_decode, pk_col.as_deref()) } - parsed_data.push((row_pk_val, fields)); - } + }; Python::attach(|py| { let results = pyo3::types::PyList::empty(py); @@ -857,21 +898,20 @@ pub fn fetch_one<'py>( } pyo3_async_runtimes::tokio::future_into_py(py, async move { - let (pool, tx_conn) = { - let pool = 
engine_pool().ok_or_else(|| { - pyo3::exceptions::PyRuntimeError::new_err("Engine not initialized") - })?; - let tx_conn = get_conn!(pool, tx_id); - (pool, tx_conn) + let (engine, tx_conn, backend) = { + let engine = active_engine()?; + let backend = engine.backend(); + let tx_conn = get_transaction_connection(tx_id); + (engine, tx_conn, backend) }; let table_name = name.to_lowercase(); let pg_native_enum_cols: HashSet = { - let m = postgres_enum_udt_by_column(&table_name, &pool, tx_conn.clone()).await?; + let m = postgres_enum_udt_by_column(&table_name, &engine, &tx_conn, backend).await?; m.keys().cloned().collect() }; // ... sql logic ... - let (sql, bind_values, _pk_col_name) = { + let (sql, bind_values, _pk_col_name, schema_for_decode) = { let registry = MODEL_REGISTRY.read().map_err(|_| { pyo3::exceptions::PyRuntimeError::new_err("Failed to lock registry") })?; @@ -894,7 +934,13 @@ pub fn fetch_one<'py>( let pk_name = pk.ok_or_else(|| pyo3::exceptions::PyRuntimeError::new_err("No primary key"))?; let mut stmt = Query::select(); - apply_postgres_text_select_columns(&mut stmt, &table_name, schema, &pg_native_enum_cols); + apply_postgres_text_select_columns( + &mut stmt, + &table_name, + schema, + &pg_native_enum_cols, + backend, + ); let no_enum_udt = HashMap::new(); let no_uuid = HashSet::new(); let no_ts: HashMap = HashMap::new(); @@ -905,33 +951,44 @@ pub fn fetch_one<'py>( &no_enum_udt, &no_uuid, &no_ts, + backend, + ); + let (s, values) = sea_query_build_for_backend!( + stmt.from(Alias::new(&table_name)) + .and_where(Expr::col(Alias::new(&pk_name)).eq(pk_expr)), + backend ); - let (s, values) = sea_query_build!(stmt - .from(Alias::new(&table_name)) - .and_where(Expr::col(Alias::new(&pk_name)).eq(pk_expr))); - (s, values, pk_name) + (s, values, pk_name, schema.clone()) }; - let query = bind_query(sqlx::query(&sql), &bind_values.0); - let row = if let Some(conn_arc) = tx_conn { - let mut conn = conn_arc.lock().await; - query.fetch_optional(&mut *conn).await - } else { - query.fetch_optional(pool.as_ref()).await - } - .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("Fetch failed: {}", e)))?; - - let parsed_row = match row { - Some(row) => { - let mut fields = Vec::with_capacity(row.columns().len()); - for col in row.columns() { - let col_name = col.name(); - let val = decode_column!(row, &name, col_name); - fields.push((col_name.to_string(), val)); - } - Some(fields) + let parsed_row = match tx_conn { + Some(conn_arc) => { + let engine_bind_values = engine_bind_values_from_sea(&bind_values.0); + let mut conn = conn_arc.lock().await; + let rows = conn + .fetch_all_sql_with_binds(&sql, &engine_bind_values) + .await + .map_err(|e| { + pyo3::exceptions::PyRuntimeError::new_err(format!("Fetch failed: {}", e)) + })?; + typed_rows_to_parsed_data(rows, &schema_for_decode, None) + .into_iter() + .next() + .map(|(_, fields)| fields) + } + None => { + let engine_bind_values = engine_bind_values_from_sea(&bind_values.0); + let rows = engine + .fetch_all_sql_with_binds(&sql, &engine_bind_values) + .await + .map_err(|e| { + pyo3::exceptions::PyRuntimeError::new_err(format!("Fetch failed: {}", e)) + })?; + typed_rows_to_parsed_data(rows, &schema_for_decode, None) + .into_iter() + .next() + .map(|(_, fields)| fields) } - None => None, }; match parsed_row { @@ -977,12 +1034,11 @@ pub fn save_record( tx_id: Option, ) -> PyResult> { pyo3_async_runtimes::tokio::future_into_py(py, async move { - let (pool, tx_conn) = { - let pool = engine_pool().ok_or_else(|| { - 
pyo3::exceptions::PyRuntimeError::new_err("Engine not initialized") - })?; - let tx_conn = get_conn!(pool, tx_id); - (pool, tx_conn) + let (engine, tx_conn, backend) = { + let engine = active_engine()?; + let backend = engine.backend(); + let tx_conn = get_transaction_connection(tx_id); + (engine, tx_conn, backend) }; // ... schema and record logic ... @@ -1023,9 +1079,11 @@ pub fn save_record( } let table_name = name.to_lowercase(); - let enum_udt = postgres_enum_udt_by_column(&table_name, &pool, tx_conn.clone()).await?; - let uuid_columns = postgres_uuid_column_names(&table_name, &pool, tx_conn.clone()).await?; - let ts_cast = postgres_temporal_cast_by_column(&table_name, &pool, tx_conn.clone()).await?; + let enum_udt = postgres_enum_udt_by_column(&table_name, &engine, &tx_conn, backend).await?; + let uuid_columns = + postgres_uuid_column_names(&table_name, &engine, &tx_conn, backend).await?; + let ts_cast = + postgres_temporal_cast_by_column(&table_name, &engine, &tx_conn, backend).await?; let (sql, bind_values, needs_postgres_returning) = { let mut columns = Vec::new(); let mut values = Vec::new(); @@ -1050,6 +1108,7 @@ pub fn save_record( &enum_udt, &uuid_columns, &ts_cast, + backend, )); } let mut insert_stmt = InsertStatement::new() @@ -1073,84 +1132,68 @@ pub fn save_record( insert_stmt.on_conflict(on_conflict); } } - let needs_postgres_returning = crate::state::sql_dialect() - == crate::state::SqlDialect::Postgres + let needs_postgres_returning = backend == crate::state::SqlDialect::Postgres && pk_col.is_some() && pk_is_auto && !pk_provided; - let (mut sql, values) = sea_query_build!(insert_stmt); - if needs_postgres_returning - && let Some(pk) = pk_col.as_ref() - { + let (mut sql, values) = sea_query_build_for_backend!(insert_stmt, backend); + if needs_postgres_returning && let Some(pk) = pk_col.as_ref() { sql.push_str(&format!(" RETURNING \"{}\"", pk)); } (sql, values, needs_postgres_returning) }; - let query = bind_query(sqlx::query(&sql), &bind_values.0); - if let Some(conn_arc) = tx_conn { - let mut conn = conn_arc.lock().await; - if needs_postgres_returning { - let row = query.fetch_one(&mut *conn).await.map_err(|e| { - pyo3::exceptions::PyRuntimeError::new_err(format!("Save failed: {}", e)) - })?; - let id: i64 = row.try_get(0).unwrap_or(0); - Ok((id > 0).then_some(id)) - } else { - let res = query.execute(&mut *conn).await; - if res.is_err() { - return Err(pyo3::exceptions::PyRuntimeError::new_err(format!( - "Save failed: {}", - res.err().unwrap() - ))); - } - let exec_res = res.unwrap(); - let mut lid = exec_res.last_insert_id(); - if crate::state::sql_dialect() == crate::state::SqlDialect::Sqlite - && (lid.is_none() || lid == Some(0)) - && let Ok(row) = sqlx::query("SELECT last_insert_rowid()") - .fetch_one(&mut *conn) + match tx_conn { + Some(conn_arc) => { + let engine_bind_values = engine_bind_values_from_sea(&bind_values.0); + let mut conn = conn_arc.lock().await; + if needs_postgres_returning { + let rows = conn + .fetch_all_sql_with_binds(&sql, &engine_bind_values) .await - { - let id: i64 = row.try_get(0).unwrap_or(0); - if id > 0 { - lid = Some(id); - } + .map_err(|e| { + pyo3::exceptions::PyRuntimeError::new_err(format!("Save failed: {}", e)) + })?; + let id = rows + .first() + .and_then(|row| row.values.first()) + .and_then(|(_, value)| value.as_i64()) + .unwrap_or(0); + Ok((id > 0).then_some(id)) + } else { + let exec_res = conn + .execute_sql_with_binds_result(&sql, &engine_bind_values) + .await + .map_err(|e| { + 
pyo3::exceptions::PyRuntimeError::new_err(format!("Save failed: {}", e)) + })?; + Ok(exec_res.last_insert_id) } - Ok(lid) } - } else { - let mut conn = pool.acquire().await.map_err(|e| { - pyo3::exceptions::PyRuntimeError::new_err(format!("Pool acquire failed: {}", e)) - })?; - if needs_postgres_returning { - let row = query.fetch_one(&mut *conn).await.map_err(|e| { - pyo3::exceptions::PyRuntimeError::new_err(format!("Save failed: {}", e)) - })?; - let id: i64 = row.try_get(0).unwrap_or(0); - Ok((id > 0).then_some(id)) - } else { - let res = query.execute(&mut *conn).await; - if res.is_err() { - return Err(pyo3::exceptions::PyRuntimeError::new_err(format!( - "Save failed: {}", - res.err().unwrap() - ))); - } - let exec_res = res.unwrap(); - let mut lid = exec_res.last_insert_id(); - if crate::state::sql_dialect() == crate::state::SqlDialect::Sqlite - && (lid.is_none() || lid == Some(0)) - && let Ok(row) = sqlx::query("SELECT last_insert_rowid()") - .fetch_one(&mut *conn) + None => { + let engine_bind_values = engine_bind_values_from_sea(&bind_values.0); + if needs_postgres_returning { + let rows = engine + .fetch_all_sql_with_binds(&sql, &engine_bind_values) .await - { - let id: i64 = row.try_get(0).unwrap_or(0); - if id > 0 { - lid = Some(id); - } + .map_err(|e| { + pyo3::exceptions::PyRuntimeError::new_err(format!("Save failed: {}", e)) + })?; + let id = rows + .first() + .and_then(|row| row.values.first()) + .and_then(|(_, value)| value.as_i64()) + .unwrap_or(0); + Ok((id > 0).then_some(id)) + } else { + let exec_res = engine + .execute_sql_with_binds_result(&sql, &engine_bind_values) + .await + .map_err(|e| { + pyo3::exceptions::PyRuntimeError::new_err(format!("Save failed: {}", e)) + })?; + Ok(exec_res.last_insert_id) } - Ok(lid) } } }) @@ -1165,14 +1208,11 @@ pub fn save_bulk_records( tx_id: Option, ) -> PyResult> { pyo3_async_runtimes::tokio::future_into_py(py, async move { - let (pool, tx_conn) = { - let pool = engine_pool().ok_or_else(|| { - pyo3::exceptions::PyRuntimeError::new_err( - "Engine not initialized. 
Call connect() first.", - ) - })?; - let tx_conn = get_conn!(pool, tx_id); - (pool, tx_conn) + let (engine, tx_conn, backend) = { + let engine = active_engine()?; + let backend = engine.backend(); + let tx_conn = get_transaction_connection(tx_id); + (engine, tx_conn, backend) }; let schema = { @@ -1217,9 +1257,11 @@ pub fn save_bulk_records( } let table_name = name.to_lowercase(); - let enum_udt = postgres_enum_udt_by_column(&table_name, &pool, tx_conn.clone()).await?; - let uuid_columns = postgres_uuid_column_names(&table_name, &pool, tx_conn.clone()).await?; - let ts_cast = postgres_temporal_cast_by_column(&table_name, &pool, tx_conn.clone()).await?; + let enum_udt = postgres_enum_udt_by_column(&table_name, &engine, &tx_conn, backend).await?; + let uuid_columns = + postgres_uuid_column_names(&table_name, &engine, &tx_conn, backend).await?; + let ts_cast = + postgres_temporal_cast_by_column(&table_name, &engine, &tx_conn, backend).await?; let (sql, bind_values) = { let mut insert_stmt = InsertStatement::new() .into_table(Alias::new(&table_name)) @@ -1258,6 +1300,7 @@ pub fn save_bulk_records( &enum_udt, &uuid_columns, &ts_cast, + backend, )); } insert_stmt.values(row_values).map_err(|e| { @@ -1268,25 +1311,21 @@ pub fn save_bulk_records( })?; } - let (s, values) = sea_query_build!(insert_stmt); + let (s, values) = sea_query_build_for_backend!(insert_stmt, backend); (s, values) }; - let query = bind_query(sqlx::query(&sql), &bind_values.0); - let result = if let Some(conn_arc) = tx_conn { - let mut conn = conn_arc.lock().await; - query.execute(&mut *conn).await - } else { - query.execute(pool.as_ref()).await - } - .map_err(|e| { - pyo3::exceptions::PyRuntimeError::new_err(format!( - "Bulk save failed for '{}': {}", - name, e - )) - })?; - - Ok(result.rows_affected()) + let rows_affected = + execute_statement_with_optional_tx(&engine, tx_conn, &sql, &bind_values.0) + .await + .map_err(|e| { + pyo3::exceptions::PyRuntimeError::new_err(format!( + "Bulk save failed for '{}': {}", + name, e + )) + })?; + + Ok(rows_affected) }) } @@ -1314,21 +1353,20 @@ pub fn fetch_filtered<'py>( })?; pyo3_async_runtimes::tokio::future_into_py(py, async move { - let (pool, tx_conn) = { - let pool = engine_pool().ok_or_else(|| { - pyo3::exceptions::PyRuntimeError::new_err("Engine not initialized") - })?; - let tx_conn = get_conn!(pool, tx_id); - (pool, tx_conn) + let (engine, tx_conn, backend) = { + let engine = active_engine()?; + let backend = engine.backend(); + let tx_conn = get_transaction_connection(tx_id); + (engine, tx_conn, backend) }; let table_name = name.to_lowercase(); let pg_native_enum_cols: HashSet = { - let m = postgres_enum_udt_by_column(&table_name, &pool, tx_conn.clone()).await?; + let m = postgres_enum_udt_by_column(&table_name, &engine, &tx_conn, backend).await?; m.keys().cloned().collect() }; // ... 
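
The bulk-save path above (and the delete/update paths below) route execute-style statements through an `execute_statement_with_optional_tx` helper whose definition is not part of this excerpt. A minimal sketch consistent with its call sites — the signature, the `EngineError` name, and the `rows_affected` field are inferred here, not quoted from the final source:

```rust
// Sketch only. Assumes `execute_sql_with_binds_result` (used elsewhere in
// this diff) returns a result struct carrying `rows_affected`, and that
// `EngineError` stands in for the backend layer's real error type.
async fn execute_statement_with_optional_tx(
    engine: &EngineHandle,
    tx_conn: Option<TransactionConnection>,
    sql: &str,
    sea_values: &[sea_query::Value],
) -> Result<u64, EngineError> {
    // Same SeaQuery -> EngineBindValue conversion the other operations do.
    let binds = engine_bind_values_from_sea(sea_values);
    match tx_conn {
        // Transactional: execute on the pinned typed connection.
        Some(conn_arc) => {
            let mut conn = conn_arc.lock().await;
            Ok(conn
                .execute_sql_with_binds_result(sql, &binds)
                .await?
                .rows_affected)
        }
        // Non-transactional: execute through the engine's typed pool.
        None => Ok(engine
            .execute_sql_with_binds_result(sql, &binds)
            .await?
            .rows_affected),
    }
}
```

Whatever its exact shape, the point is that every mutation now shares one transaction-aware execution seam instead of repeating the `Some(conn_arc)`/`None` match at each call site.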
- let (sql, bind_values, pk_col) = { + let (sql, bind_values, pk_col, schema_for_decode) = { let registry = MODEL_REGISTRY.read().map_err(|_| { pyo3::exceptions::PyRuntimeError::new_err("Failed to lock registry") })?; @@ -1355,6 +1393,7 @@ pub fn fetch_filtered<'py>( &table_name, schema, &pg_native_enum_cols, + backend, ); select.from(Alias::new(&table_name)); @@ -1371,13 +1410,17 @@ pub fn fetch_filtered<'py>( Expr::col((Alias::new(&table_name), Alias::new(pk_name))) .equals((join_table.clone(), target_col.clone())), ); - select.and_where( - Expr::col((join_table.clone(), source_col.clone())) - .eq(query_def.value_rhs_simple_expr(&m2m.source_col, &m2m.source_id, true)), - ); + select.and_where(Expr::col((join_table.clone(), source_col.clone())).eq( + query_def.value_rhs_simple_expr_for_backend( + &m2m.source_col, + &m2m.source_id, + true, + backend, + ), + )); } - select.cond_where(query_def.to_condition()); + select.cond_where(query_def.to_condition_for_backend(backend)); if let Some(ref orders) = query_def.order_by { for order in orders { let col = Alias::new(&order.column); @@ -1395,38 +1438,33 @@ pub fn fetch_filtered<'py>( if let Some(offset) = query_def.offset { select.offset(offset); } - let (s, values) = sea_query_build!(select); - (s, values, pk) + let (s, values) = sea_query_build_for_backend!(select, backend); + (s, values, pk, schema.clone()) }; - let query = bind_query(sqlx::query(&sql), &bind_values.0); - let rows = if let Some(conn_arc) = tx_conn { - let mut conn = conn_arc.lock().await; - query.fetch_all(&mut *conn).await - } else { - query.fetch_all(pool.as_ref()).await - } - .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("Fetch failed: {}", e)))?; - - let mut parsed_data = Vec::with_capacity(rows.len()); - for row in rows { - let mut row_pk_val = None; - if let Some(ref pk_name) = pk_col { - if let Ok(val) = row.try_get::(pk_name.as_str()) { - row_pk_val = Some(val.to_string()); - } else if let Ok(val) = row.try_get::(pk_name.as_str()) { - row_pk_val = Some(val); - } + let parsed_data = match tx_conn { + Some(conn_arc) => { + let engine_bind_values = engine_bind_values_from_sea(&bind_values.0); + let mut conn = conn_arc.lock().await; + let rows = conn + .fetch_all_sql_with_binds(&sql, &engine_bind_values) + .await + .map_err(|e| { + pyo3::exceptions::PyRuntimeError::new_err(format!("Fetch failed: {}", e)) + })?; + typed_rows_to_parsed_data(rows, &schema_for_decode, pk_col.as_deref()) } - - let mut fields = Vec::with_capacity(row.columns().len()); - for col in row.columns() { - let col_name = col.name(); - let val = decode_column!(row, &name, col_name); - fields.push((col_name.to_string(), val)); + None => { + let engine_bind_values = engine_bind_values_from_sea(&bind_values.0); + let rows = engine + .fetch_all_sql_with_binds(&sql, &engine_bind_values) + .await + .map_err(|e| { + pyo3::exceptions::PyRuntimeError::new_err(format!("Fetch failed: {}", e)) + })?; + typed_rows_to_parsed_data(rows, &schema_for_decode, pk_col.as_deref()) } - parsed_data.push((row_pk_val, fields)); - } + }; Python::attach(|py| { let results = pyo3::types::PyList::empty(py); @@ -1493,12 +1531,15 @@ pub fn count_filtered( })?; pyo3_async_runtimes::tokio::future_into_py(py, async move { - let (pool, tx_conn) = { - let pool = engine_pool().ok_or_else(|| { - pyo3::exceptions::PyRuntimeError::new_err("Engine not initialized") - })?; - let tx_conn = get_conn!(pool, tx_id); - (pool, tx_conn) + let (engine, tx_conn, backend) = { + let engine = active_engine()?; + let backend = 
engine.backend(); + let tx_conn = tx_id.and_then(|id| { + TRANSACTION_REGISTRY + .get(&id) + .map(|tx| tx.value().conn.clone()) + }); + (engine, tx_conn, backend) }; let table_name = name.to_lowercase(); @@ -1541,28 +1582,51 @@ pub fn count_filtered( Expr::col((Alias::new(&table_name), Alias::new(pk_name))) .equals((join_table.clone(), target_col.clone())), ); - select.and_where( - Expr::col((join_table.clone(), source_col.clone())) - .eq(query_def.value_rhs_simple_expr(&m2m.source_col, &m2m.source_id, true)), - ); + select.and_where(Expr::col((join_table.clone(), source_col.clone())).eq( + query_def.value_rhs_simple_expr_for_backend( + &m2m.source_col, + &m2m.source_id, + true, + backend, + ), + )); } else { select.from(Alias::new(&table_name)); } - select.cond_where(query_def.to_condition()); - sea_query_build!(select) + select.cond_where(query_def.to_condition_for_backend(backend)); + sea_query_build_for_backend!(select, backend) }; - let query = bind_query(sqlx::query(&sql), &bind_values.0); - let row = if let Some(conn_arc) = tx_conn { - let mut conn = conn_arc.lock().await; - query.fetch_one(&mut *conn).await - } else { - query.fetch_one(pool.as_ref()).await - } - .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("Count failed: {}", e)))?; + let engine_bind_values = engine_bind_values_from_sea(&bind_values.0); + let count = match tx_conn { + Some(conn_arc) => { + let mut conn = conn_arc.lock().await; + let rows = conn + .fetch_all_sql_with_binds(&sql, &engine_bind_values) + .await + .map_err(|e| { + pyo3::exceptions::PyRuntimeError::new_err(format!("Count failed: {}", e)) + })?; + rows.first() + .and_then(|row| row.values.first()) + .and_then(|(_, value)| value.as_i64()) + .unwrap_or(0) + } + None => { + let rows = engine + .fetch_all_sql_with_binds(&sql, &engine_bind_values) + .await + .map_err(|e| { + pyo3::exceptions::PyRuntimeError::new_err(format!("Count failed: {}", e)) + })?; + rows.first() + .and_then(|row| row.values.first()) + .and_then(|(_, value)| value.as_i64()) + .unwrap_or(0) + } + }; - let count: i64 = row.try_get(0).unwrap_or(0); Ok(count) }) } @@ -1591,12 +1655,15 @@ pub fn delete_record( tx_id: Option, ) -> PyResult> { pyo3_async_runtimes::tokio::future_into_py(py, async move { - let (pool, tx_conn) = { - let pool = engine_pool().ok_or_else(|| { - pyo3::exceptions::PyRuntimeError::new_err("Engine not initialized") - })?; - let tx_conn = get_conn!(pool, tx_id); - (pool, tx_conn) + let (engine, tx_conn, backend) = { + let engine = active_engine()?; + let backend = engine.backend(); + let tx_conn = tx_id.and_then(|id| { + TRANSACTION_REGISTRY + .get(&id) + .map(|tx| tx.value().conn.clone()) + }); + (engine, tx_conn, backend) }; let table_name = name.to_lowercase(); @@ -1633,21 +1700,22 @@ pub fn delete_record( &no_enum_udt, &no_uuid, &no_ts, + backend, + ); + let (s, values) = sea_query_build_for_backend!( + Query::delete() + .from_table(Alias::new(&table_name)) + .and_where(Expr::col(Alias::new(&pk_name)).eq(pk_expr)), + backend ); - let (s, values) = sea_query_build!(Query::delete() - .from_table(Alias::new(&table_name)) - .and_where(Expr::col(Alias::new(&pk_name)).eq(pk_expr))); (s, values) }; - let query = bind_query(sqlx::query(&sql), &bind_values.0); - if let Some(conn_arc) = tx_conn { - let mut conn = conn_arc.lock().await; - query.execute(&mut *conn).await - } else { - query.execute(pool.as_ref()).await - } - .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("Delete failed: {}", e)))?; + execute_statement_with_optional_tx(&engine, 
tx_conn, &sql, &bind_values.0) + .await + .map_err(|e| { + pyo3::exceptions::PyRuntimeError::new_err(format!("Delete failed: {}", e)) + })?; Ok(true) }) @@ -1667,12 +1735,15 @@ pub fn delete_filtered( })?; pyo3_async_runtimes::tokio::future_into_py(py, async move { - let (pool, tx_conn) = { - let pool = engine_pool().ok_or_else(|| { - pyo3::exceptions::PyRuntimeError::new_err("Engine not initialized") - })?; - let tx_conn = get_conn!(pool, tx_id); - (pool, tx_conn) + let (engine, tx_conn, backend) = { + let engine = active_engine()?; + let backend = engine.backend(); + let tx_conn = tx_id.and_then(|id| { + TRANSACTION_REGISTRY + .get(&id) + .map(|tx| tx.value().conn.clone()) + }); + (engine, tx_conn, backend) }; let table_name = name.to_lowercase(); @@ -1681,23 +1752,21 @@ pub fn delete_filtered( let mut delete = Query::delete(); delete .from_table(Alias::new(&table_name)) - .cond_where(query_def.to_condition()); - sea_query_build!(delete) + .cond_where(query_def.to_condition_for_backend(backend)); + sea_query_build_for_backend!(delete, backend) }; - let query = bind_query(sqlx::query(&sql), &bind_values.0); - let result = if let Some(conn_arc) = tx_conn { - let mut conn = conn_arc.lock().await; - query.execute(&mut *conn).await - } else { - query.execute(pool.as_ref()).await - } - .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("Delete failed: {}", e)))?; + let rows_affected = + execute_statement_with_optional_tx(&engine, tx_conn, &sql, &bind_values.0) + .await + .map_err(|e| { + pyo3::exceptions::PyRuntimeError::new_err(format!("Delete failed: {}", e)) + })?; // After bulk delete, we MUST clear the Identity Map for this model to avoid stale objects IDENTITY_MAP.retain(|(m_name, _), _| m_name != &name); - Ok(result.rows_affected()) + Ok(rows_affected) }) } @@ -1724,18 +1793,19 @@ pub fn update_filtered( })?; pyo3_async_runtimes::tokio::future_into_py(py, async move { - let (pool, tx_conn) = { - let pool = engine_pool().ok_or_else(|| { - pyo3::exceptions::PyRuntimeError::new_err("Engine not initialized") - })?; - let tx_conn = get_conn!(pool, tx_id); - (pool, tx_conn) + let (engine, tx_conn, backend) = { + let engine = active_engine()?; + let backend = engine.backend(); + let tx_conn = get_transaction_connection(tx_id); + (engine, tx_conn, backend) }; let table_name = name.to_lowercase(); - let enum_udt = postgres_enum_udt_by_column(&table_name, &pool, tx_conn.clone()).await?; - let uuid_columns = postgres_uuid_column_names(&table_name, &pool, tx_conn.clone()).await?; - let ts_cast = postgres_temporal_cast_by_column(&table_name, &pool, tx_conn.clone()).await?; + let enum_udt = postgres_enum_udt_by_column(&table_name, &engine, &tx_conn, backend).await?; + let uuid_columns = + postgres_uuid_column_names(&table_name, &engine, &tx_conn, backend).await?; + let ts_cast = + postgres_temporal_cast_by_column(&table_name, &engine, &tx_conn, backend).await?; // ... sql ... 
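
To make the effect of these catalog lookups concrete: on Postgres, `schema_value_expr` binds special values as text and wraps them in an explicit SQL cast, so the server receives typed parameters. A standalone sketch of the same SeaQuery pattern (the table and column names are illustrative; the SQL in the comment is indicative of the builder's output):

```rust
use sea_query::{Alias, Expr, PostgresQueryBuilder, Query, Value};

fn main() {
    // A date-time field arrives from Python as an RFC 3339 string; the cast
    // turns the untyped text parameter into a typed timestamptz.
    let mut update = Query::update();
    update
        .table(Alias::new("event"))
        .value(
            Alias::new("starts_at"),
            Expr::value(Value::String(Some(Box::new(
                "2024-05-01T12:00:00Z".to_string(),
            ))))
            .cast_as("timestamptz"),
        )
        .and_where(Expr::col(Alias::new("id")).eq(1));

    let (sql, values) = update.build(PostgresQueryBuilder);
    // Roughly: UPDATE "event" SET "starts_at" = CAST($1 AS timestamptz) WHERE "id" = $2
    println!("{sql} with binds {values:?}");
}
```

On SQLite the same value is stored as plain text, which is why the `backend` parameter now threads through every expression builder instead of consulting a process-wide dialect.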
let (sql, bind_values) = { let registry = MODEL_REGISTRY.read().map_err(|_| { @@ -1746,7 +1816,7 @@ pub fn update_filtered( })?; let mut update = UpdateStatement::new() .table(Alias::new(&table_name)) - .cond_where(query_def.to_condition()) + .cond_where(query_def.to_condition_for_backend(backend)) .to_owned(); for (key, value) in update_map { update.value( @@ -1758,25 +1828,24 @@ pub fn update_filtered( &enum_udt, &uuid_columns, &ts_cast, + backend, ), ); } - sea_query_build!(update) + sea_query_build_for_backend!(update, backend) }; - let query = bind_query(sqlx::query(&sql), &bind_values.0); - let result = if let Some(conn_arc) = tx_conn { - let mut conn = conn_arc.lock().await; - query.execute(&mut *conn).await - } else { - query.execute(pool.as_ref()).await - } - .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("Update failed: {}", e)))?; + let rows_affected = + execute_statement_with_optional_tx(&engine, tx_conn, &sql, &bind_values.0) + .await + .map_err(|e| { + pyo3::exceptions::PyRuntimeError::new_err(format!("Update failed: {}", e)) + })?; // After bulk update, we MUST clear the Identity Map for this model to avoid stale objects IDENTITY_MAP.retain(|(m_name, _), _| m_name != &name); - Ok(result.rows_affected()) + Ok(rows_affected) }) } @@ -1797,39 +1866,47 @@ pub fn add_m2m_links<'py>( .map(|id| python_to_sea_value(id)) .collect::>>()?; - let (sql, bind_values) = { - let mut insert = InsertStatement::new() - .into_table(Alias::new(&join_table)) - .columns(vec![Alias::new(&source_col), Alias::new(&target_col)]) - .to_owned(); + pyo3_async_runtimes::tokio::future_into_py(py, async move { + let (engine, tx_conn, backend) = { + let engine = active_engine()?; + let backend = engine.backend(); + let tx_conn = tx_id.and_then(|id| { + TRANSACTION_REGISTRY + .get(&id) + .map(|tx| tx.value().conn.clone()) + }); + (engine, tx_conn, backend) + }; + let uuid_columns = + postgres_uuid_column_names(&join_table, &engine, &tx_conn, backend).await?; - for t_id in t_ids { - insert - .values(vec![Expr::value(s_id.clone()), Expr::value(t_id)]) - .unwrap(); - } - sea_query_build!(insert) - }; + let (sql, bind_values) = { + let mut insert = InsertStatement::new() + .into_table(Alias::new(&join_table)) + .columns(vec![Alias::new(&source_col), Alias::new(&target_col)]) + .to_owned(); - pyo3_async_runtimes::tokio::future_into_py(py, async move { - let (pool, tx_conn) = { - let pool = engine_pool().ok_or_else(|| { - pyo3::exceptions::PyRuntimeError::new_err("Engine not initialized") - })?; - let tx_conn = get_conn!(pool, tx_id); - (pool, tx_conn) + for t_id in t_ids { + insert + .values(vec![ + backend_column_value_expr( + &source_col, + s_id.clone(), + &uuid_columns, + backend, + ), + backend_column_value_expr(&target_col, t_id, &uuid_columns, backend), + ]) + .unwrap(); + } + sea_query_build_for_backend!(insert, backend) }; - let query = bind_query(sqlx::query(&sql), &bind_values.0); - if let Some(conn_arc) = tx_conn { - let mut conn = conn_arc.lock().await; - query.execute(&mut *conn).await - } else { - query.execute(pool.as_ref()).await - } - .map_err(|e| { - pyo3::exceptions::PyRuntimeError::new_err(format!("Add M2M links failed: {}", e)) - })?; + execute_statement_with_optional_tx(&engine, tx_conn, &sql, &bind_values.0) + .await + .map_err(|e| { + pyo3::exceptions::PyRuntimeError::new_err(format!("Add M2M links failed: {}", e)) + })?; Ok(()) }) @@ -1852,30 +1929,49 @@ pub fn remove_m2m_links<'py>( .map(|id| python_to_sea_value(id)) .collect::>>()?; - let (sql, bind_values) = 
sea_query_build!(Query::delete() - .from_table(Alias::new(&join_table)) - .and_where(Expr::col(Alias::new(&source_col)).eq(s_id)) - .and_where(Expr::col(Alias::new(&target_col)).is_in(t_ids))); - pyo3_async_runtimes::tokio::future_into_py(py, async move { - let (pool, tx_conn) = { - let pool = engine_pool().ok_or_else(|| { - pyo3::exceptions::PyRuntimeError::new_err("Engine not initialized") - })?; - let tx_conn = get_conn!(pool, tx_id); - (pool, tx_conn) + let (engine, tx_conn, backend) = { + let engine = active_engine()?; + let backend = engine.backend(); + let tx_conn = tx_id.and_then(|id| { + TRANSACTION_REGISTRY + .get(&id) + .map(|tx| tx.value().conn.clone()) + }); + (engine, tx_conn, backend) }; - - let query = bind_query(sqlx::query(&sql), &bind_values.0); - if let Some(conn_arc) = tx_conn { - let mut conn = conn_arc.lock().await; - query.execute(&mut *conn).await - } else { - query.execute(pool.as_ref()).await - } - .map_err(|e| { - pyo3::exceptions::PyRuntimeError::new_err(format!("Remove M2M links failed: {}", e)) - })?; + let uuid_columns = + postgres_uuid_column_names(&join_table, &engine, &tx_conn, backend).await?; + + let (sql, bind_values) = sea_query_build_for_backend!( + Query::delete() + .from_table(Alias::new(&join_table)) + .and_where( + Expr::col(Alias::new(&source_col)).eq(backend_column_value_expr( + &source_col, + s_id, + &uuid_columns, + backend + )) + ) + .and_where( + Expr::col(Alias::new(&target_col)).is_in( + t_ids + .into_iter() + .map(|t_id| { + backend_column_value_expr(&target_col, t_id, &uuid_columns, backend) + }) + .collect::>() + ) + ), + backend + ); + + execute_statement_with_optional_tx(&engine, tx_conn, &sql, &bind_values.0) + .await + .map_err(|e| { + pyo3::exceptions::PyRuntimeError::new_err(format!("Remove M2M links failed: {}", e)) + })?; Ok(()) }) @@ -1892,29 +1988,39 @@ pub fn clear_m2m_links<'py>( ) -> PyResult> { let s_id = python_to_sea_value(source_id)?; - let (sql, bind_values) = sea_query_build!(Query::delete() - .from_table(Alias::new(&join_table)) - .and_where(Expr::col(Alias::new(&source_col)).eq(s_id))); - pyo3_async_runtimes::tokio::future_into_py(py, async move { - let (pool, tx_conn) = { - let pool = engine_pool().ok_or_else(|| { - pyo3::exceptions::PyRuntimeError::new_err("Engine not initialized") - })?; - let tx_conn = get_conn!(pool, tx_id); - (pool, tx_conn) + let (engine, tx_conn, backend) = { + let engine = active_engine()?; + let backend = engine.backend(); + let tx_conn = tx_id.and_then(|id| { + TRANSACTION_REGISTRY + .get(&id) + .map(|tx| tx.value().conn.clone()) + }); + (engine, tx_conn, backend) }; - - let query = bind_query(sqlx::query(&sql), &bind_values.0); - if let Some(conn_arc) = tx_conn { - let mut conn = conn_arc.lock().await; - query.execute(&mut *conn).await - } else { - query.execute(pool.as_ref()).await - } - .map_err(|e| { - pyo3::exceptions::PyRuntimeError::new_err(format!("Clear M2M links failed: {}", e)) - })?; + let uuid_columns = + postgres_uuid_column_names(&join_table, &engine, &tx_conn, backend).await?; + + let (sql, bind_values) = sea_query_build_for_backend!( + Query::delete() + .from_table(Alias::new(&join_table)) + .and_where( + Expr::col(Alias::new(&source_col)).eq(backend_column_value_expr( + &source_col, + s_id, + &uuid_columns, + backend + )) + ), + backend + ); + + execute_statement_with_optional_tx(&engine, tx_conn, &sql, &bind_values.0) + .await + .map_err(|e| { + pyo3::exceptions::PyRuntimeError::new_err(format!("Clear M2M links failed: {}", e)) + })?; Ok(()) }) diff --git 
a/src/query.rs b/src/query.rs index 31efc59..2871f72 100644 --- a/src/query.rs +++ b/src/query.rs @@ -1,4 +1,4 @@ -use crate::state::{sql_dialect, SqlDialect, MODEL_REGISTRY}; +use crate::state::{MODEL_REGISTRY, SqlDialect}; use sea_query::{Alias, Condition, Expr, SimpleExpr}; use serde::Deserialize; use serde_json::Value; @@ -41,18 +41,20 @@ pub struct QueryDef { } impl QueryDef { - pub fn to_condition(&self) -> Condition { + pub fn to_condition_for_backend(&self, backend: SqlDialect) -> Condition { let mut condition = Condition::all(); for node in &self.where_clause { - condition = condition.add(self.node_to_condition(node)); + condition = condition.add(self.node_to_condition_for_backend(node, backend)); } condition } - fn node_to_condition(&self, node: &QueryNode) -> Condition { + fn node_to_condition_for_backend(&self, node: &QueryNode, backend: SqlDialect) -> Condition { if node.is_compound { - let left_cond = self.node_to_condition(node.left.as_ref().unwrap()); - let right_cond = self.node_to_condition(node.right.as_ref().unwrap()); + let left_cond = + self.node_to_condition_for_backend(node.left.as_ref().unwrap(), backend); + let right_cond = + self.node_to_condition_for_backend(node.right.as_ref().unwrap(), backend); match node.operator.as_str() { "OR" => Condition::any().add(left_cond).add(right_cond), @@ -64,33 +66,46 @@ impl QueryDef { let val = node.value.as_ref().unwrap(); let col = Expr::col(Alias::new(col_name)); - let expr: SimpleExpr = match node.operator.as_str() { - "==" => col.eq(self.value_rhs_simple_expr(col_name, val, false)), - "!=" => col.ne(self.value_rhs_simple_expr(col_name, val, false)), - "<" => col.lt(self.value_rhs_simple_expr(col_name, val, false)), - "<=" => col.lte(self.value_rhs_simple_expr(col_name, val, false)), - ">" => col.gt(self.value_rhs_simple_expr(col_name, val, false)), - ">=" => col.gte(self.value_rhs_simple_expr(col_name, val, false)), - "IN" => { - if let Some(vals) = val.as_array() { - let rhs: Vec = vals - .iter() - .map(|v| self.value_rhs_simple_expr(col_name, v, false)) - .collect(); - col.is_in(rhs) - } else { - col.eq(self.value_rhs_simple_expr(col_name, val, false)) + let expr: SimpleExpr = + match node.operator.as_str() { + "==" => col + .eq(self.value_rhs_simple_expr_for_backend(col_name, val, false, backend)), + "!=" => col + .ne(self.value_rhs_simple_expr_for_backend(col_name, val, false, backend)), + "<" => col + .lt(self.value_rhs_simple_expr_for_backend(col_name, val, false, backend)), + "<=" => col + .lte(self.value_rhs_simple_expr_for_backend(col_name, val, false, backend)), + ">" => col + .gt(self.value_rhs_simple_expr_for_backend(col_name, val, false, backend)), + ">=" => col + .gte(self.value_rhs_simple_expr_for_backend(col_name, val, false, backend)), + "IN" => { + if let Some(vals) = val.as_array() { + let rhs: Vec = vals + .iter() + .map(|v| { + self.value_rhs_simple_expr_for_backend( + col_name, v, false, backend, + ) + }) + .collect(); + col.is_in(rhs) + } else { + col.eq(self + .value_rhs_simple_expr_for_backend(col_name, val, false, backend)) + } } - } - "LIKE" => { - let pattern = match val { - Value::String(s) => s.clone(), - _ => val.to_string(), - }; - col.like(pattern) - } - _ => col.eq(self.value_rhs_simple_expr(col_name, val, false)), - }; + "LIKE" => { + let pattern = match val { + Value::String(s) => s.clone(), + _ => val.to_string(), + }; + col.like(pattern) + } + _ => col + .eq(self.value_rhs_simple_expr_for_backend(col_name, val, false, backend)), + }; Condition::all().add(expr) } } @@ -98,19 +113,19 @@ 
impl QueryDef { /// Right-hand side for a filter comparison. /// /// On Postgres, UUID columns compared to JSON string parameters need an explicit - /// `CAST(... AS uuid)` so the bind stays text-typed (compatible with `sqlx::Any`) - /// while the comparison is `uuid = uuid`. + /// `CAST(... AS uuid)` so the comparison is `uuid = uuid`. /// /// `infer_uuid_without_schema` is used for M2M join filters where the RHS is a UUID /// string but the join column is not described on the queried model's schema. - pub fn value_rhs_simple_expr( + pub fn value_rhs_simple_expr_for_backend( &self, col_name: &str, val: &Value, infer_uuid_without_schema: bool, + backend: SqlDialect, ) -> SimpleExpr { if let Value::String(s) = val { - if sql_dialect() == SqlDialect::Postgres { + if backend == SqlDialect::Postgres { if uuid::Uuid::parse_str(s).is_ok() { let schema_uuid = model_column_is_uuid(&self.model_name, col_name); if schema_uuid || infer_uuid_without_schema { @@ -131,15 +146,11 @@ impl QueryDef { return Expr::value(sea_query::Value::String(Some(Box::new(s.clone())))) .cast_as("date"); } - if model_column_format(&self.model_name, col_name).as_deref() - == Some("date-time") - { + if model_column_format(&self.model_name, col_name).as_deref() == Some("date-time") { return Expr::value(sea_query::Value::String(Some(Box::new(s.clone())))) .cast_as("timestamptz"); } - if model_column_format(&self.model_name, col_name).as_deref() - == Some("binary") - { + if model_column_format(&self.model_name, col_name).as_deref() == Some("binary") { return Expr::value(sea_query::Value::String(Some(Box::new(s.clone())))) .cast_as("bytea"); } @@ -280,3 +291,53 @@ pub(crate) fn property_schema_is_uuid(col_info: &Value) -> bool { }); json_type == Some("string") && format == Some("uuid") } + +#[cfg(test)] +mod tests { + use super::QueryDef; + use crate::backend::BackendKind; + use sea_query::{PostgresQueryBuilder, Query, SqliteQueryBuilder}; + use serde_json::json; + + #[test] + fn uuid_rhs_cast_uses_explicit_backend_not_global_default() { + let query_def = QueryDef { + model_name: "Widget".to_string(), + where_clause: Vec::new(), + order_by: None, + limit: None, + offset: None, + m2m: None, + }; + + let postgres_rhs = query_def.value_rhs_simple_expr_for_backend( + "widget_id", + &json!("3f4c4ca7-a7e7-40d6-8d83-8f4ddf3285e6"), + true, + BackendKind::Postgres, + ); + let postgres_sql = Query::select() + .expr(postgres_rhs) + .to_string(PostgresQueryBuilder); + + assert!( + postgres_sql.contains("AS uuid"), + "unexpected SQL: {postgres_sql}" + ); + + let sqlite_rhs = query_def.value_rhs_simple_expr_for_backend( + "widget_id", + &json!("3f4c4ca7-a7e7-40d6-8d83-8f4ddf3285e6"), + true, + BackendKind::Sqlite, + ); + let sqlite_sql = Query::select() + .expr(sqlite_rhs) + .to_string(SqliteQueryBuilder); + + assert!( + !sqlite_sql.contains("AS uuid"), + "unexpected SQL: {sqlite_sql}" + ); + } +} diff --git a/src/schema.rs b/src/schema.rs index 5749984..5ba731d 100644 --- a/src/schema.rs +++ b/src/schema.rs @@ -3,13 +3,13 @@ //! This module handles converting Pydantic JSON schemas into Sea-Query //! table definitions and managing the model registry. 
-use crate::state::{engine_pool, sql_dialect, MODEL_REGISTRY, SqlDialect}; +use crate::backend::EngineHandle; +use crate::state::{MODEL_REGISTRY, SqlDialect, engine_handle}; use pyo3::prelude::*; use sea_query::{ - Alias, ColumnDef, ForeignKey, ForeignKeyAction, Index, PostgresQueryBuilder, SqliteQueryBuilder, - Table, + Alias, ColumnDef, ForeignKey, ForeignKeyAction, Index, PostgresQueryBuilder, + SqliteQueryBuilder, Table, }; -use sqlx::{Any, Pool}; use std::collections::HashSet; use std::sync::Arc; @@ -29,7 +29,7 @@ fn resolve_ref<'a>( fn property_json_type_and_format(col_info: &serde_json::Value) -> (Option<&str>, Option<&str>) { let top_type = col_info.get("type").and_then(|t| t.as_str()); let top_format = col_info.get("format").and_then(|f| f.as_str()); - if top_type.is_some() || top_format.is_some() { + if top_type.is_some() { return (top_type, top_format); } @@ -40,11 +40,33 @@ fn property_json_type_and_format(col_info: &serde_json::Value) -> (Option<&str>, continue; } let item_format = item.get("format").and_then(|f| f.as_str()); - return (item_type, item_format); + return (item_type, item_format.or(top_format)); } } - (None, None) + (None, top_format) +} + +fn column_bool_metadata( + raw_col_info: &serde_json::Value, + resolved_col_info: &serde_json::Value, + key: &str, +) -> Option { + raw_col_info + .get(key) + .or_else(|| resolved_col_info.get(key)) + .and_then(|value| value.as_bool()) +} + +fn column_object_metadata<'a>( + raw_col_info: &'a serde_json::Value, + resolved_col_info: &'a serde_json::Value, + key: &str, +) -> Option<&'a serde_json::Map> { + raw_col_info + .get(key) + .or_else(|| resolved_col_info.get(key)) + .and_then(|value| value.as_object()) } fn schema_dependencies(schema: &serde_json::Value) -> Vec { @@ -76,8 +98,10 @@ fn order_schemas_for_creation( let mut created = HashSet::new(); while !remaining.is_empty() { - let available_names: HashSet = - remaining.iter().map(|(name, _)| name.to_lowercase()).collect(); + let available_names: HashSet = remaining + .iter() + .map(|(name, _)| name.to_lowercase()) + .collect(); let mut progress = false; let mut index = 0; @@ -104,8 +128,11 @@ fn order_schemas_for_creation( ordered } -/// Maps a JSON schema type string to a Sea-Query `ColumnDef`. -pub fn json_type_to_sea_query(col_def: &mut ColumnDef, json_type: &str) { +fn json_type_to_sea_query_for_backend( + col_def: &mut ColumnDef, + json_type: &str, + backend: SqlDialect, +) { match json_type { "integer" => { col_def.integer(); @@ -117,7 +144,7 @@ pub fn json_type_to_sea_query(col_def: &mut ColumnDef, json_type: &str) { col_def.double(); } "boolean" => { - match sql_dialect() { + match backend { SqlDialect::Sqlite => { // SQLite stores booleans as integers. 
col_def.integer(); @@ -128,7 +155,7 @@ pub fn json_type_to_sea_query(col_def: &mut ColumnDef, json_type: &str) { } } "object" | "array" => { - col_def.text(); // SQLite stores JSON as text + col_def.json(); } _ => { col_def.string(); @@ -150,6 +177,7 @@ fn append_composite_unique_index_sqls( table_lower: &str, schema: &serde_json::Value, index_sqls: &mut Vec, + backend: SqlDialect, ) { let Some(groups) = schema .get("ferro_composite_uniques") @@ -180,7 +208,7 @@ fn append_composite_unique_index_sqls( for c in &cols { stmt.col(Alias::new(*c)); } - let sql = match sql_dialect() { + let sql = match backend { SqlDialect::Sqlite => stmt.to_string(SqliteQueryBuilder), SqlDialect::Postgres => stmt.to_string(PostgresQueryBuilder), }; @@ -188,6 +216,127 @@ fn append_composite_unique_index_sqls( } } +fn build_create_table_sqls( + name: &str, + schema: &serde_json::Value, + backend: SqlDialect, +) -> (String, Vec) { + let table_lower = name.to_lowercase(); + let mut table_stmt = Table::create() + .table(Alias::new(&table_lower)) + .if_not_exists() + .to_owned(); + + let mut index_sqls = Vec::new(); + + if let Some(properties) = schema.get("properties").and_then(|p| p.as_object()) { + for (col_name, raw_col_info) in properties { + let mut col_def = ColumnDef::new(Alias::new(col_name)); + let col_info = resolve_ref(schema, raw_col_info); + + let (json_type, format) = property_json_type_and_format(col_info); + + if let Some(t) = json_type { + match (t, format) { + ("string", Some("date-time")) => { + col_def.timestamp_with_time_zone(); + } + ("string", Some("date")) => { + col_def.date(); + } + ("string", Some("uuid")) => { + col_def.uuid(); + } + (_, Some("decimal")) => { + col_def.decimal(); + } + ("string", Some("binary")) => { + col_def.blob(); + } + _ => json_type_to_sea_query_for_backend(&mut col_def, t, backend), + } + } else { + col_def.string(); + } + + // Check for primary key and autoincrement from our custom metadata + let is_pk = + column_bool_metadata(raw_col_info, col_info, "primary_key").unwrap_or(false); + + let is_auto = + column_bool_metadata(raw_col_info, col_info, "autoincrement").unwrap_or(true); + + if is_pk { + col_def.primary_key(); + if is_auto { + col_def.auto_increment(); + } + } + + if column_bool_metadata(raw_col_info, col_info, "ferro_nullable") == Some(false) { + col_def.not_null(); + } + + if column_bool_metadata(raw_col_info, col_info, "unique").unwrap_or(false) { + col_def.unique_key(); + } + + if column_bool_metadata(raw_col_info, col_info, "index").unwrap_or(false) { + let index_name = format!("idx_{}_{}", table_lower, col_name); + let index_stmt = Index::create() + .name(&index_name) + .table(Alias::new(&table_lower)) + .col(Alias::new(col_name)) + .if_not_exists() + .to_owned(); + let index_sql = match backend { + SqlDialect::Sqlite => index_stmt.to_string(SqliteQueryBuilder), + SqlDialect::Postgres => index_stmt.to_string(PostgresQueryBuilder), + }; + index_sqls.push(index_sql); + } + + table_stmt.col(&mut col_def); + + // Check for Foreign Key from metadata + if let Some(fk_info) = column_object_metadata(raw_col_info, col_info, "foreign_key") { + let to_table = fk_info + .get("to_table") + .and_then(|t| t.as_str()) + .unwrap_or(""); + let on_delete_str = fk_info + .get("on_delete") + .and_then(|o| o.as_str()) + .unwrap_or("CASCADE"); + + let action = match on_delete_str.to_uppercase().as_str() { + "RESTRICT" => ForeignKeyAction::Restrict, + "SET NULL" => ForeignKeyAction::SetNull, + "SET DEFAULT" => ForeignKeyAction::SetDefault, + "NO ACTION" => 
ForeignKeyAction::NoAction, + _ => ForeignKeyAction::Cascade, // Default + }; + + let mut fk_stmt = ForeignKey::create(); + fk_stmt + .from(Alias::new(&table_lower), Alias::new(col_name)) + .to(Alias::new(to_table), Alias::new("id")) // CX Choice: Assume target PK is 'id' for now + .on_delete(action); + + table_stmt.foreign_key(&mut fk_stmt); + } + } + } + + append_composite_unique_index_sqls(&table_lower, schema, &mut index_sqls, backend); + + let table_sql = match backend { + SqlDialect::Sqlite => table_stmt.build(SqliteQueryBuilder), + SqlDialect::Postgres => table_stmt.build(PostgresQueryBuilder), + }; + (table_sql, index_sqls) +} + /// Internal utility to create all registered tables in the database. /// /// This is used by both the `connect(auto_migrate=True)` flow and the @@ -195,7 +344,7 @@ fn append_composite_unique_index_sqls( /// /// # Errors /// Returns a `PyErr` if the SQL execution fails. -pub async fn internal_create_tables(pool: Arc>) -> PyResult<()> { +pub async fn internal_create_tables(engine: Arc) -> PyResult<()> { let schemas = { let registry = MODEL_REGISTRY.read().map_err(|_| { pyo3::exceptions::PyRuntimeError::new_err("Failed to lock Model Registry") @@ -203,143 +352,12 @@ pub async fn internal_create_tables(pool: Arc>) -> PyResult<()> { registry.clone() }; - let mut conn = pool.acquire().await.map_err(|e| { - pyo3::exceptions::PyRuntimeError::new_err(format!("Failed to acquire connection: {}", e)) - })?; + let backend = engine.backend(); for (name, schema) in order_schemas_for_creation(schemas) { - let (sql, index_sqls) = { - let table_lower = name.to_lowercase(); - let mut table_stmt = Table::create() - .table(Alias::new(&table_lower)) - .if_not_exists() - .to_owned(); - - let mut index_sqls = Vec::new(); - - if let Some(properties) = schema.get("properties").and_then(|p| p.as_object()) { - for (col_name, col_info) in properties { - let mut col_def = ColumnDef::new(Alias::new(col_name)); - let col_info = resolve_ref(&schema, col_info); - - let (json_type, format) = property_json_type_and_format(col_info); - - if let Some(t) = json_type { - match (t, format) { - ("string", Some("date-time")) => { - col_def.timestamp_with_time_zone(); - } - ("string", Some("date")) => { - col_def.date(); - } - ("string", Some("uuid")) => { - col_def.uuid(); - } - ("string", Some("binary")) => { - col_def.blob(); - } - _ => json_type_to_sea_query(&mut col_def, t), - } - } else { - col_def.string(); - } - - // Check for primary key and autoincrement from our custom metadata - let is_pk = col_info - .get("primary_key") - .and_then(|pk| pk.as_bool()) - .unwrap_or(false); - - let is_auto = col_info - .get("autoincrement") - .and_then(|auto| auto.as_bool()) - .unwrap_or(true); - - if is_pk { - col_def.primary_key(); - if is_auto { - col_def.auto_increment(); - } - } - - if col_info - .get("ferro_nullable") - .and_then(|nullable| nullable.as_bool()) - == Some(false) - { - col_def.not_null(); - } - - if col_info - .get("unique") - .and_then(|u| u.as_bool()) - .unwrap_or(false) - { - col_def.unique_key(); - } - - if col_info - .get("index") - .and_then(|i| i.as_bool()) - .unwrap_or(false) - { - let index_name = format!("idx_{}_{}", table_lower, col_name); - let index_stmt = Index::create() - .name(&index_name) - .table(Alias::new(&table_lower)) - .col(Alias::new(col_name)) - .if_not_exists() - .to_owned(); - let index_sql = match sql_dialect() { - SqlDialect::Sqlite => index_stmt.to_string(SqliteQueryBuilder), - SqlDialect::Postgres => index_stmt.to_string(PostgresQueryBuilder), - }; - 
diff --git a/src/state.rs b/src/state.rs
index b38bfe5..6425a19 100644
--- a/src/state.rs
+++ b/src/state.rs
@@ -3,12 +3,11 @@
 //! This module holds the global connection pool, the model registry,
 //! and the Identity Map used for object tracking.
 
-use crate::backend::{BackendKind, EngineHandle};
+use crate::backend::{BackendKind, EngineConnection, EngineHandle};
 use dashmap::DashMap;
 use once_cell::sync::Lazy;
 use pyo3::IntoPyObjectExt;
 use pyo3::prelude::*;
-use sqlx::{Any, AnyConnection, Pool};
 use std::collections::HashMap;
 use std::sync::{Arc, RwLock};
 use tokio::sync::Mutex;
@@ -16,16 +15,6 @@ use tokio::sync::Mutex;
 /// Backward-compatible name for query/DDL builder selection.
 pub type SqlDialect = BackendKind;
 
-/// Returns the dialect selected at [`crate::connection::connect`] time.
-#[inline]
-pub fn sql_dialect() -> SqlDialect {
-    ENGINE
-        .read()
-        .ok()
-        .and_then(|engine| engine.as_ref().map(|engine| engine.backend()))
-        .unwrap_or_default()
-}
-
 /// Global registry mapping model names to their Pydantic-generated JSON schemas.
 pub static MODEL_REGISTRY: Lazy<RwLock<HashMap<String, serde_json::Value>>> =
     Lazy::new(|| RwLock::new(HashMap::new()));
@@ -33,30 +22,29 @@ pub static MODEL_REGISTRY: Lazy<RwLock<HashMap<String, serde_json::Value>>> =
 /// The global runtime engine, initialized via `connect()`.
 pub static ENGINE: Lazy<RwLock<Option<Arc<EngineHandle>>>> = Lazy::new(|| RwLock::new(None));
 
-/// Returns a clone of the active pool when the engine is initialized.
-pub fn engine_pool() -> Option<Arc<Pool<Any>>> {
-    ENGINE
-        .read()
-        .ok()
-        .and_then(|engine| engine.as_ref().map(|engine| engine.pool()))
+/// Returns a clone of the active engine handle when initialized.
+pub fn engine_handle() -> Option<Arc<EngineHandle>> {
+    ENGINE.read().ok().and_then(|engine| engine.clone())
 }
 
 /// Active transaction handle.
 #[derive(Clone)]
 pub struct TransactionHandle {
-    pub conn: Arc<Mutex<AnyConnection>>,
+    pub conn: TransactionConnection,
     pub savepoint_name: Option<String>,
 }
 
+pub type TransactionConnection = Arc<Mutex<EngineConnection>>;
+
 impl TransactionHandle {
-    pub fn root(conn: AnyConnection) -> Self {
+    pub fn root(conn: EngineConnection) -> Self {
         Self {
             conn: Arc::new(Mutex::new(conn)),
             savepoint_name: None,
         }
     }
 
-    pub fn nested(conn: Arc<Mutex<AnyConnection>>, savepoint_name: String) -> Self {
+    pub fn nested(conn: TransactionConnection, savepoint_name: String) -> Self {
         Self {
             conn,
             savepoint_name: Some(savepoint_name),
@@ -66,8 +54,7 @@ impl TransactionHandle {
 
 /// Global registry for active transactions.
 /// Maps Transaction ID -> backend connection plus optional savepoint.
-pub static TRANSACTION_REGISTRY: Lazy<DashMap<String, TransactionHandle>> =
-    Lazy::new(DashMap::new);
+pub static TRANSACTION_REGISTRY: Lazy<DashMap<String, TransactionHandle>> = Lazy::new(DashMap::new);
 
 /// Identity Map used for object tracking and deduplication.
 ///
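From Python, these typed transaction connections surface through the `transaction()` context manager used in the test suite below. A minimal sketch; the model is illustrative, and the comment about nested blocks mapping to named savepoints is an assumption based on `TransactionHandle::nested` above:

```python
# Sketch: root vs. nested transactions over one typed EngineConnection.
from typing import Annotated

from ferro import FerroField, Model, connect
from ferro.models import transaction


class Account(Model):
    id: Annotated[int | None, FerroField(primary_key=True)] = None
    balance: int = 0


async def transfer_demo() -> None:
    await connect("sqlite::memory:", auto_migrate=True)
    acct = await Account.create(balance=100)

    async with transaction():  # root: one EngineConnection from the typed pool
        await Account.where(Account.id == acct.id).update(balance=90)
        async with transaction():  # nested: same connection, a named SAVEPOINT
            await Account.where(Account.id == acct.id).update(balance=80)
```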
diff --git a/tests/conftest.py b/tests/conftest.py
index 2c9c88f..016df13 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -8,8 +8,10 @@ from ferro import version
 from tests.db_backends import (
     backends_for_test,
+    build_postgres_url_from_connection_params,
     build_postgres_test_url,
-    get_supabase_url,
+    get_postgres_url,
+    has_pytest_postgresql,
     parse_backend_option,
 )
@@ -49,7 +51,11 @@ def _selected_backends(config: pytest.Config) -> tuple[str, ...]:
 
 def _available_postgres_url() -> str | None:
-    return get_supabase_url(dict(os.environ), ENV_FILE)
+    return get_postgres_url(dict(os.environ), ENV_FILE)
+
+
+def _has_postgres_provider() -> bool:
+    return bool(_available_postgres_url()) or has_pytest_postgresql()
 
 
 def pytest_generate_tests(metafunc: pytest.Metafunc) -> None:
@@ -63,7 +69,7 @@ def pytest_generate_tests(metafunc: pytest.Metafunc) -> None:
         is not None,
         is_sqlite_only=metafunc.definition.get_closest_marker("sqlite_only") is not None,
         is_postgres_only=metafunc.definition.get_closest_marker("postgres_only") is not None,
-        has_postgres_url=bool(_available_postgres_url()),
+        has_postgres_url=_has_postgres_provider(),
     )
 
     if not test_backends:
@@ -73,7 +79,7 @@ def pytest_generate_tests(metafunc: pytest.Metafunc) -> None:
             pytest.param(
                 None,
                 marks=pytest.mark.skip(
-                    reason="FERRO_SUPABASE_URL is not configured for Postgres-backed tests.",
+                    reason="No Postgres provider is configured for Postgres-backed tests.",
                 ),
             )
         ],
@@ -102,6 +108,19 @@ def _drop_postgres_schema(base_url: str, schema_name: str) -> None:
         conn.execute(f'DROP SCHEMA IF EXISTS "{schema_name}" CASCADE')
 
 
+def _pytest_postgresql_base_url(request: pytest.FixtureRequest) -> str:
+    try:
+        conn = request.getfixturevalue("postgresql")
+    except Exception as exc:
+        pytest.skip(
+            "pytest-postgresql could not start local Postgres. "
+            "Install Postgres server binaries or set FERRO_POSTGRES_URL. "
+            f"Original error: {exc}"
+        )
+
+    return build_postgres_url_from_connection_params(conn.info.get_parameters())
+
+
 # This fixture ensures the Rust binary is actually loaded and working
 @pytest.fixture(scope="session", autouse=True)
 def check_engine():
@@ -131,13 +150,14 @@ def db_url(request: pytest.FixtureRequest, tmp_path: Path):
         yield f"sqlite:{db_file}?mode=rwc"
         return
 
-    base_url = _available_postgres_url()
+    base_url = _available_postgres_url() or _pytest_postgresql_base_url(request)
     if not base_url:
-        pytest.skip("FERRO_SUPABASE_URL is not configured for Postgres-backed tests.")
+        pytest.skip("No Postgres provider is configured for Postgres-backed tests.")
 
     schema_name = f"ferro_{uuid.uuid4().hex[:16]}"
     _create_postgres_schema(base_url, schema_name)
     request.node._ferro_db_schema = schema_name
+    request.node._ferro_postgres_base_url = base_url
 
     try:
         yield build_postgres_test_url(base_url, schema_name)
@@ -148,13 +168,21 @@ def db_url(request: pytest.FixtureRequest, tmp_path: Path):
         _drop_postgres_schema(base_url, schema_name)
 
 
-@pytest.fixture(scope="session")
-def postgres_base_url() -> str | None:
-    return _available_postgres_url()
+@pytest.fixture(scope="function")
+def postgres_base_url(request: pytest.FixtureRequest, db_url: str | None) -> str | None:
+    if db_url is None or not (
+        db_url.startswith("postgres://") or db_url.startswith("postgresql://")
+    ):
+        return None
+    return getattr(request.node, "_ferro_postgres_base_url", _available_postgres_url())
 
 
 @pytest.fixture(scope="function")
-def db_schema_name(request: pytest.FixtureRequest) -> str | None:
+def db_schema_name(request: pytest.FixtureRequest, db_url: str | None) -> str | None:
+    if db_url is None or not (
+        db_url.startswith("postgres://") or db_url.startswith("postgresql://")
+    ):
+        return None
     return getattr(request.node, "_ferro_db_schema", None)
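For context, this is how a test opts into the backend matrix these fixtures drive. A minimal sketch; the model and assertions are illustrative, while the markers and the `db_url` fixture are the ones this conftest defines:

```python
# Sketch: backend_matrix runs the test once per selected backend;
# sqlite_only / postgres_only pin a test to one engine.
from typing import Annotated

import pytest
from ferro import FerroField, Model, connect

pytestmark = pytest.mark.backend_matrix


@pytest.mark.asyncio
async def test_roundtrip(db_url):
    # db_url is either "sqlite:<tmpfile>?mode=rwc" or a Postgres URL whose
    # ferro_search_path points at a throwaway per-test schema.
    class Note(Model):
        id: Annotated[int | None, FerroField(primary_key=True)] = None
        body: str = ""

    await connect(db_url, auto_migrate=True)
    note = await Note.create(body="hello")
    assert (await Note.get(note.id)).body == "hello"
```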
" + f"Original error: {exc}" + ) + + return build_postgres_url_from_connection_params(conn.info.get_parameters()) + + # This fixture ensures the Rust binary is actually loaded and working @pytest.fixture(scope="session", autouse=True) def check_engine(): @@ -131,13 +150,14 @@ def db_url(request: pytest.FixtureRequest, tmp_path: Path): yield f"sqlite:{db_file}?mode=rwc" return - base_url = _available_postgres_url() + base_url = _available_postgres_url() or _pytest_postgresql_base_url(request) if not base_url: - pytest.skip("FERRO_SUPABASE_URL is not configured for Postgres-backed tests.") + pytest.skip("No Postgres provider is configured for Postgres-backed tests.") schema_name = f"ferro_{uuid.uuid4().hex[:16]}" _create_postgres_schema(base_url, schema_name) request.node._ferro_db_schema = schema_name + request.node._ferro_postgres_base_url = base_url try: yield build_postgres_test_url(base_url, schema_name) @@ -148,13 +168,21 @@ def db_url(request: pytest.FixtureRequest, tmp_path: Path): _drop_postgres_schema(base_url, schema_name) -@pytest.fixture(scope="session") -def postgres_base_url() -> str | None: - return _available_postgres_url() +@pytest.fixture(scope="function") +def postgres_base_url(request: pytest.FixtureRequest, db_url: str | None) -> str | None: + if db_url is None or not ( + db_url.startswith("postgres://") or db_url.startswith("postgresql://") + ): + return None + return getattr(request.node, "_ferro_postgres_base_url", _available_postgres_url()) @pytest.fixture(scope="function") -def db_schema_name(request: pytest.FixtureRequest) -> str | None: +def db_schema_name(request: pytest.FixtureRequest, db_url: str | None) -> str | None: + if db_url is None or not ( + db_url.startswith("postgres://") or db_url.startswith("postgresql://") + ): + return None return getattr(request.node, "_ferro_db_schema", None) diff --git a/tests/db_backends.py b/tests/db_backends.py index 6586d86..5b67f4d 100644 --- a/tests/db_backends.py +++ b/tests/db_backends.py @@ -1,10 +1,12 @@ from __future__ import annotations +import importlib.util from pathlib import Path -from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse +from urllib.parse import parse_qsl, quote, urlencode, urlparse, urlunparse SUPPORTED_BACKENDS = ("sqlite", "postgres") +LOCAL_POSTGRES_PROVIDER = "local" def load_env_value(env_file: Path, key: str) -> str | None: @@ -25,8 +27,26 @@ def load_env_value(env_file: Path, key: str) -> str | None: return None +def get_postgres_url(env: dict[str, str], env_file: Path) -> str | None: + """Return an externally managed Postgres URL, if one is configured.""" + if env.get("FERRO_POSTGRES_PROVIDER") == LOCAL_POSTGRES_PROVIDER: + return None + + return ( + env.get("FERRO_POSTGRES_URL") + or load_env_value(env_file, "FERRO_POSTGRES_URL") + or env.get("FERRO_SUPABASE_URL") + or load_env_value(env_file, "FERRO_SUPABASE_URL") + ) + + def get_supabase_url(env: dict[str, str], env_file: Path) -> str | None: - return env.get("FERRO_SUPABASE_URL") or load_env_value(env_file, "FERRO_SUPABASE_URL") + """Backward-compatible alias for the old Supabase-only test setting.""" + return get_postgres_url(env, env_file) + + +def has_pytest_postgresql() -> bool: + return importlib.util.find_spec("pytest_postgresql") is not None def parse_backend_option(raw_value: str) -> tuple[str, ...]: @@ -65,3 +85,31 @@ def build_postgres_test_url(base_url: str, schema_name: str) -> str: params = parse_qsl(parsed.query, keep_blank_values=True) params.append(("ferro_search_path", schema_name)) return 
diff --git a/tests/test_alembic_bridge.py b/tests/test_alembic_bridge.py
index a5763f4..98a7334 100644
--- a/tests/test_alembic_bridge.py
+++ b/tests/test_alembic_bridge.py
@@ -135,6 +135,7 @@ class UuidMember(Model):
     assert isinstance(col.type, sa.Uuid) or (
         isinstance(col.type, sa.String) and getattr(col.type, "length", None) == 36
     )
+    assert col.nullable is False
 
 
 def test_uuid_foreign_key_shadow_column_type():
diff --git a/tests/test_alembic_nullability.py b/tests/test_alembic_nullability.py
index 5d10fd5..08ca662 100644
--- a/tests/test_alembic_nullability.py
+++ b/tests/test_alembic_nullability.py
@@ -44,6 +44,7 @@ class Row(Model):
         id: Annotated[int, FerroField(primary_key=True)]
         field_a: int = Field(default=0)
 
+    assert Row.__ferro_schema__["properties"]["field_a"]["ferro_nullable"] is False
     t = get_metadata().tables["row"]
     assert t.c.field_a.nullable is False
 
@@ -53,6 +54,7 @@ class Row(Model):
         id: Annotated[int, FerroField(primary_key=True)]
         field_a: int | None = None
 
+    assert Row.__ferro_schema__["properties"]["field_a"]["ferro_nullable"] is True
     t = get_metadata().tables["row"]
     assert t.c.field_a.nullable is True
 
@@ -131,6 +133,7 @@ class Row(Model):
         id: Annotated[int, FerroField(primary_key=True)]
         status: Status = Status.DRAFT
 
+    assert Row.__ferro_schema__["properties"]["status"]["ferro_nullable"] is False
     t = get_metadata().tables["row"]
     assert t.c.status.nullable is False
 
@@ -145,6 +148,7 @@ class ChildReq(Model):
         id: Annotated[int, FerroField(primary_key=True)]
         parent: Annotated[Parent, ForeignKey(related_name="children")]
 
+    assert ChildReq.__ferro_schema__["properties"]["parent_id"]["ferro_nullable"] is False
     t = get_metadata().tables["childreq"]
     assert t.c.parent_id.nullable is False
 
@@ -179,6 +183,7 @@ class ChildOpt(Model):
         id: Annotated[int, FerroField(primary_key=True)]
         parent: Annotated[Parent | None, ForeignKey(related_name="children")] = None
 
+    assert ChildOpt.__ferro_schema__["properties"]["parent_id"]["ferro_nullable"] is True
     t = get_metadata().tables["childopt"]
     assert t.c.parent_id.nullable is True
diff --git a/tests/test_alembic_type_mapping.py b/tests/test_alembic_type_mapping.py
index ba606f8..edfb873 100644
--- a/tests/test_alembic_type_mapping.py
+++ b/tests/test_alembic_type_mapping.py
@@ -47,7 +47,8 @@ class ComplexModel(Model):
     assert set(table.c.status.type.enums) == {"active", "inactive"}
 
     # Numeric/Decimal
-    assert isinstance(table.c.price.type, (sa.Numeric, sa.Float))
+    assert ComplexModel.__ferro_schema__["properties"]["price"]["format"] == "decimal"
+    assert isinstance(table.c.price.type, sa.Numeric)
 
     # UUID
     assert isinstance(table.c.token.type, (sa.Uuid, sa.String))
diff --git a/tests/test_auto_migrate.py b/tests/test_auto_migrate.py
index 4d39f9c..9589d5c 100644
--- a/tests/test_auto_migrate.py
+++ b/tests/test_auto_migrate.py
@@ -1,4 +1,5 @@
 from typing import Annotated
+from uuid import UUID, uuid4
 
 import pytest
 from pydantic import Field
@@ -78,3 +79,111 @@ class Movie(Model):
     assert len(linked) == 1
     assert linked[0].id == movie.id
     assert linked[0].title == "Matrix"
+    assert await actor.movies.count() == 1
+
+    reverse_linked = await movie.actors.all()
+    assert [row.id for row in reverse_linked] == [actor.id]
+
+    await actor.movies.remove(movie)
+    assert await actor.movies.count() == 0
+
+    movie_2 = await Movie.create(title="Reloaded")
+    await actor.movies.add(movie, movie_2)
+    assert await actor.movies.count() == 2
+    await actor.movies.clear()
+    assert await actor.movies.count() == 0
+
+
+@pytest.mark.asyncio
+@pytest.mark.sqlite_only
+async def test_uuid_m2m_join_table_columns_inherit_pk_type_and_nullability(db_url):
+    """Runtime join-table DDL should derive FK column metadata from source PKs."""
+    from ferro import clear_registry, connect, reset_engine
+    from ferro.state import _JOIN_TABLE_REGISTRY, _MODEL_REGISTRY_PY, _PENDING_RELATIONS
+
+    reset_engine()
+    clear_registry()
+    _MODEL_REGISTRY_PY.clear()
+    _PENDING_RELATIONS.clear()
+    _JOIN_TABLE_REGISTRY.clear()
+
+    class UuidActor(Model):
+        id: Annotated[UUID, FerroField(primary_key=True)] = Field(default_factory=uuid4)
+        name: str
+        movies: Annotated[list["UuidMovie"], ManyToManyField(related_name="actors")] = None
+
+    class UuidMovie(Model):
+        id: Annotated[UUID, FerroField(primary_key=True)] = Field(default_factory=uuid4)
+        title: str
+        actors: BackRef[UuidActor] = None
+
+    await connect(db_url, auto_migrate=True)
+
+    import sqlite3
+
+    db_path = db_url.removeprefix("sqlite:").split("?", 1)[0]
+    conn = sqlite3.connect(db_path)
+    rows = conn.execute("PRAGMA table_info(uuidactor_movies)").fetchall()
+    conn.close()
+
+    columns = {row[1]: row for row in rows}
+    assert columns["uuidactor_id"][2].upper() in {"UUID", "UUID_TEXT", "TEXT", "CHAR", "VARCHAR"}
+    assert columns["uuidmovie_id"][2].upper() in {"UUID", "UUID_TEXT", "TEXT", "CHAR", "VARCHAR"}
+    assert columns["uuidactor_id"][3] == 1
+    assert columns["uuidmovie_id"][3] == 1
+
+
+@pytest.mark.asyncio
+async def test_uuid_m2m_relationship_query_serializes_source_id(db_url):
+    """UUID source PKs in M2M contexts should serialize for all query operations."""
+    from ferro import Field as FerroFieldFn
+    from ferro import clear_registry, connect, reset_engine
+    from ferro.models import transaction
+    from ferro.state import _JOIN_TABLE_REGISTRY, _MODEL_REGISTRY_PY, _PENDING_RELATIONS
+
+    reset_engine()
+    clear_registry()
+    _MODEL_REGISTRY_PY.clear()
+    _PENDING_RELATIONS.clear()
+    _JOIN_TABLE_REGISTRY.clear()
+
+    class UuidTag(Model):
+        id: UUID = FerroFieldFn(default_factory=uuid4, primary_key=True)
+        name: str = ""
+        posts: BackRef[list["UuidPost"]] | None = None
+
+    class UuidPost(Model):
+        id: UUID = FerroFieldFn(default_factory=uuid4, primary_key=True)
+        title: str = ""
+        tags: Annotated[list[UuidTag], ManyToManyField(related_name="posts")] = None
+
+    await connect(db_url, auto_migrate=True)
+
+    post = await UuidPost.create(title="Hello")
+    tag = await UuidTag.create(name="python")
+
+    await post.tags.add(tag)
+
+    linked = await post.tags.all()
+    assert [row.id for row in linked] == [tag.id]
+    assert await post.tags.count() == 1
+
+    reverse_linked = await tag.posts.all()
+    assert [row.id for row in reverse_linked] == [post.id]
+
+    await post.tags.remove(tag)
+    assert await post.tags.count() == 0
+
+    tag_2 = await UuidTag.create(name="orm")
+    await post.tags.add(tag, tag_2)
+    assert await post.tags.count() == 2
+    await post.tags.clear()
+    assert await post.tags.count() == 0
+
+    async with transaction():
+        await post.tags.add(tag)
+        assert await post.tags.count() == 1
+        await post.tags.remove(tag)
+        assert await post.tags.count() == 0
+
+    assert await post.tags.count() == 0
diff --git a/tests/test_db_backends.py b/tests/test_db_backends.py
index c967a03..4064705 100644
--- a/tests/test_db_backends.py
+++ b/tests/test_db_backends.py
@@ -34,6 +34,41 @@ def test_get_supabase_url_prefers_environment_over_dotenv(tmp_path: Path):
     )
 
 
+def test_get_postgres_url_prefers_generic_setting_over_supabase(tmp_path: Path):
+    env_file = tmp_path / ".env"
+    env_file.write_text(
+        "FERRO_SUPABASE_URL=postgresql://dotenv-supabase.example/postgres\n",
+        encoding="utf-8",
+    )
+
+    assert (
+        db_backends.get_postgres_url(
+            {
+                "FERRO_POSTGRES_URL": "postgresql://generic.example/postgres",
+                "FERRO_SUPABASE_URL": "postgresql://env-supabase.example/postgres",
+            },
+            env_file,
+        )
+        == "postgresql://generic.example/postgres"
+    )
+
+
+def test_get_postgres_url_can_force_local_provider(tmp_path: Path):
+    env_file = tmp_path / ".env"
+    env_file.write_text(
+        "FERRO_POSTGRES_URL=postgresql://dotenv.example/postgres\n",
+        encoding="utf-8",
+    )
+
+    assert (
+        db_backends.get_postgres_url(
+            {"FERRO_POSTGRES_PROVIDER": "local"},
+            env_file,
+        )
+        is None
+    )
+
+
 def test_parse_backend_option_validates_backend_names():
     assert db_backends.parse_backend_option("sqlite,postgres") == ("sqlite", "postgres")
 
@@ -75,3 +110,17 @@ def test_build_postgres_test_url_sets_search_path():
 
     assert "sslmode=require" in url
     assert "ferro_search_path=ferro_test_schema" in url
+
+
+def test_build_postgres_url_from_connection_params():
+    url = db_backends.build_postgres_url_from_connection_params(
+        {
+            "host": "127.0.0.1",
+            "port": "55432",
+            "user": "postgres",
+            "password": "secret value",
+            "dbname": "test_db",
+        }
+    )
+
+    assert url == "postgresql://postgres:secret%20value@127.0.0.1:55432/test_db"
serialized["nested"]["amounts"] == ["1.25"] + assert serialized["nested"]["unique_ids"] == [str(uid)] + json.dumps(serialized) + + +def test_query_def_to_json_serializes_m2m_context_without_mutating_query_state(): + source_id = uuid.uuid4() + query = Query(Model)._m2m( + "post_tags", + "post_id", + "tag_id", + source_id, + ) + query_def = { + "model_name": "Tag", + "where_clause": [], + "order_by": [], + "limit": None, + "offset": None, + "m2m": query._m2m_context, + } + + query_json = _query_def_to_json(query_def) + + assert query._m2m_context["source_id"] == source_id + assert isinstance(query._m2m_context["source_id"], uuid.UUID) + assert json.loads(query_json)["m2m"]["source_id"] == str(source_id) + + +def test_query_node_to_dict_serializes_uuid_values_inside_in_filters(): + uid1 = uuid.uuid4() + uid2 = uuid.uuid4() + node = QueryNode(column="run_id", operator="IN", value=[uid1, uid2]) + + assert node.to_dict()["value"] == [str(uid1), str(uid2)] + + def test_field_proxy_operator_overloading(): """ Test that accessing a field on the Model class returns a FieldProxy diff --git a/tests/test_schema.py b/tests/test_schema.py index a81a56c..0275a04 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -1,7 +1,11 @@ import pytest import ferro -from ferro import Model +from ferro import FerroField, Model from pydantic import Field +from typing import Annotated +from enum import StrEnum +from decimal import Decimal +from typing import Dict, List pytestmark = pytest.mark.backend_matrix @@ -34,3 +38,84 @@ async def test_create_tables_no_connection(): with pytest.raises(RuntimeError) as excinfo: await ferro.create_tables() assert "Engine not initialized" in str(excinfo.value) + + +@pytest.mark.asyncio +@pytest.mark.sqlite_only +async def test_auto_migrate_runtime_ddl_infers_required_field_not_null(db_url): + """Runtime DDL should use the same nullability metadata as Alembic.""" + + class NullabilityRow(Model): + id: Annotated[int | None, FerroField(primary_key=True)] = None + required_name: str + optional_note: str | None = None + + await ferro.connect(db_url, auto_migrate=True) + + import sqlite3 + + db_path = db_url.removeprefix("sqlite:").split("?", 1)[0] + conn = sqlite3.connect(db_path) + rows = conn.execute("PRAGMA table_info(nullabilityrow)").fetchall() + conn.close() + + not_null_by_column = {row[1]: row[3] for row in rows} + assert not_null_by_column["required_name"] == 1 + assert not_null_by_column["optional_note"] == 0 + + +@pytest.mark.asyncio +@pytest.mark.sqlite_only +async def test_auto_migrate_runtime_ddl_preserves_ref_field_nullability(db_url): + """Runtime DDL should not lose Ferro metadata when resolving JSON-schema refs.""" + + class RowStatus(StrEnum): + DRAFT = "draft" + ACTIVE = "active" + + class RefNullabilityRow(Model): + id: Annotated[int | None, FerroField(primary_key=True)] = None + status: RowStatus = RowStatus.DRAFT + + await ferro.connect(db_url, auto_migrate=True) + + import sqlite3 + + db_path = db_url.removeprefix("sqlite:").split("?", 1)[0] + conn = sqlite3.connect(db_path) + rows = conn.execute("PRAGMA table_info(refnullabilityrow)").fetchall() + conn.close() + + not_null_by_column = {row[1]: row[3] for row in rows} + assert not_null_by_column["status"] == 1 + + +@pytest.mark.asyncio +@pytest.mark.sqlite_only +async def test_auto_migrate_runtime_ddl_uses_logical_decimal_and_json_types(db_url): + """Runtime DDL should preserve Decimal/JSON logical type intent.""" + + class LogicalTypeRow(Model): + id: Annotated[int | None, FerroField(primary_key=True)] 
diff --git a/tests/test_schema.py b/tests/test_schema.py
index a81a56c..0275a04 100644
--- a/tests/test_schema.py
+++ b/tests/test_schema.py
@@ -1,7 +1,11 @@
 import pytest
 
 import ferro
-from ferro import Model
+from ferro import FerroField, Model
 from pydantic import Field
+from typing import Annotated
+from enum import StrEnum
+from decimal import Decimal
+from typing import Dict, List
 
 pytestmark = pytest.mark.backend_matrix
 
@@ -34,3 +38,84 @@ async def test_create_tables_no_connection():
     with pytest.raises(RuntimeError) as excinfo:
         await ferro.create_tables()
     assert "Engine not initialized" in str(excinfo.value)
+
+
+@pytest.mark.asyncio
+@pytest.mark.sqlite_only
+async def test_auto_migrate_runtime_ddl_infers_required_field_not_null(db_url):
+    """Runtime DDL should use the same nullability metadata as Alembic."""
+
+    class NullabilityRow(Model):
+        id: Annotated[int | None, FerroField(primary_key=True)] = None
+        required_name: str
+        optional_note: str | None = None
+
+    await ferro.connect(db_url, auto_migrate=True)
+
+    import sqlite3
+
+    db_path = db_url.removeprefix("sqlite:").split("?", 1)[0]
+    conn = sqlite3.connect(db_path)
+    rows = conn.execute("PRAGMA table_info(nullabilityrow)").fetchall()
+    conn.close()
+
+    not_null_by_column = {row[1]: row[3] for row in rows}
+    assert not_null_by_column["required_name"] == 1
+    assert not_null_by_column["optional_note"] == 0
+
+
+@pytest.mark.asyncio
+@pytest.mark.sqlite_only
+async def test_auto_migrate_runtime_ddl_preserves_ref_field_nullability(db_url):
+    """Runtime DDL should not lose Ferro metadata when resolving JSON-schema refs."""
+
+    class RowStatus(StrEnum):
+        DRAFT = "draft"
+        ACTIVE = "active"
+
+    class RefNullabilityRow(Model):
+        id: Annotated[int | None, FerroField(primary_key=True)] = None
+        status: RowStatus = RowStatus.DRAFT
+
+    await ferro.connect(db_url, auto_migrate=True)
+
+    import sqlite3
+
+    db_path = db_url.removeprefix("sqlite:").split("?", 1)[0]
+    conn = sqlite3.connect(db_path)
+    rows = conn.execute("PRAGMA table_info(refnullabilityrow)").fetchall()
+    conn.close()
+
+    not_null_by_column = {row[1]: row[3] for row in rows}
+    assert not_null_by_column["status"] == 1
+
+
+@pytest.mark.asyncio
+@pytest.mark.sqlite_only
+async def test_auto_migrate_runtime_ddl_uses_logical_decimal_and_json_types(db_url):
+    """Runtime DDL should preserve Decimal/JSON logical type intent."""
+
+    class LogicalTypeRow(Model):
+        id: Annotated[int | None, FerroField(primary_key=True)] = None
+        price: Decimal
+        metadata: Dict[str, str]
+        tags: List[str]
+
+    props = LogicalTypeRow.__ferro_schema__["properties"]
+    assert props["price"]["format"] == "decimal"
+    assert props["metadata"]["type"] == "object"
+    assert props["tags"]["type"] == "array"
+
+    await ferro.connect(db_url, auto_migrate=True)
+
+    import sqlite3
+
+    db_path = db_url.removeprefix("sqlite:").split("?", 1)[0]
+    conn = sqlite3.connect(db_path)
+    rows = conn.execute("PRAGMA table_info(logicaltyperow)").fetchall()
+    conn.close()
+
+    types_by_column = {row[1]: row[2].upper() for row in rows}
+    assert types_by_column["price"] == "REAL"
+    assert types_by_column["metadata"] == "JSON_TEXT"
+    assert types_by_column["tags"] == "JSON_TEXT"
diff --git a/tests/test_static_contracts.py b/tests/test_static_contracts.py
new file mode 100644
index 0000000..c598400
--- /dev/null
+++ b/tests/test_static_contracts.py
@@ -0,0 +1,8 @@
+from pathlib import Path
+
+
+def test_query_methods_use_query_def_serializer_instead_of_raw_json_dumps():
+    source = Path("src/ferro/query/builder.py").read_text(encoding="utf-8")
+
+    assert source.count("json.dumps(_serialize_query_value(query_def))") == 1
+    assert "json.dumps(query_def)" not in source
diff --git a/tests/test_structural_types.py b/tests/test_structural_types.py
index 3809b5c..f148dac 100644
--- a/tests/test_structural_types.py
+++ b/tests/test_structural_types.py
@@ -128,6 +128,76 @@ class ComplexModel(Model):
     assert {row.user_id for row in results} == {uid1, uid3}
 
 
+@pytest.mark.asyncio
+async def test_uuid_filter_serializes_for_update_and_delete_queries(db_url):
+    """UUID filters should serialize for mutating query payloads too."""
+
+    class UuidMutationModel(Model):
+        id: Annotated[int | None, FerroField(primary_key=True)] = None
+        run_id: uuid.UUID
+        label: str
+
+    await connect(db_url, auto_migrate=True)
+
+    uid1 = uuid.uuid4()
+    uid2 = uuid.uuid4()
+    await UuidMutationModel.create(run_id=uid1, label="old")
+    await UuidMutationModel.create(run_id=uid2, label="keep")
+
+    updated = await UuidMutationModel.where(
+        UuidMutationModel.run_id == uid1
+    ).update(label="new")
+    assert updated == 1
+
+    fetched = await UuidMutationModel.where(UuidMutationModel.run_id == uid1).first()
+    assert fetched is not None
+    assert fetched.label == "new"
+
+    deleted = await UuidMutationModel.where(UuidMutationModel.run_id == uid1).delete()
+    assert deleted == 1
+    remaining = await UuidMutationModel.all()
+    assert [row.run_id for row in remaining] == [uid2]
+
+
+@pytest.mark.asyncio
+@pytest.mark.postgres_only
+async def test_postgres_json_and_decimal_updates_keep_typed_hydration(db_url):
+    """Postgres JSON and numeric updates need casts plus text hydration."""
+
+    class PgTypedMutation(Model):
+        id: Annotated[int | None, FerroField(primary_key=True)] = None
+        metadata: Dict[str, str]
+        tags: List[str]
+        balance: Decimal
+
+    await connect(db_url, auto_migrate=True)
+
+    row = await PgTypedMutation.create(
+        metadata={"old": "value"},
+        tags=["a"],
+        balance=Decimal("1.50"),
+    )
+
+    updated = await PgTypedMutation.where(PgTypedMutation.id == row.id).update(
+        metadata={"new": "value"},
+        tags=["b", "c"],
+        balance=Decimal("2.75"),
+    )
+    assert updated == 1
+
+    fetched = await PgTypedMutation.get(row.id)
+    assert fetched is not None
+    assert fetched.metadata == {"new": "value"}
+    assert fetched.tags == ["b", "c"]
+    assert fetched.balance == Decimal("2.75")
+
+    filtered = await PgTypedMutation.where(
+        PgTypedMutation.balance > Decimal("2.00")
+    ).first()
+    assert filtered is not None
+    assert filtered.id == row.id
 
 @pytest.mark.asyncio
 @pytest.mark.postgres_only
 async def test_native_postgres_enum_column_decodes_via_text_cast(
diff --git a/uv.lock b/uv.lock
index 8af9e2a..cbb50b4 100644
--- a/uv.lock
+++ b/uv.lock
@@ -528,7 +528,7 @@ wheels = [
 
 [[package]]
 name = "ferro-orm"
-version = "0.3.2"
+version = "0.3.3"
 source = { editable = "." }
 dependencies = [
     { name = "pydantic" },
@@ -551,6 +551,7 @@ ci-test = [
     { name = "pytest-asyncio" },
     { name = "pytest-cov" },
     { name = "pytest-examples" },
+    { name = "pytest-postgresql" },
 ]
 dev = [
     { name = "alembic" },
@@ -567,6 +568,7 @@ dev = [
     { name = "pytest-asyncio" },
     { name = "pytest-cov" },
     { name = "pytest-examples" },
+    { name = "pytest-postgresql" },
     { name = "python-semantic-release" },
     { name = "rich" },
     { name = "sqlalchemy" },
@@ -599,6 +601,7 @@ ci-test = [
     { name = "pytest-asyncio", specifier = ">=0.23.0" },
     { name = "pytest-cov", specifier = ">=7.0.0" },
     { name = "pytest-examples", specifier = ">=0.0.18" },
+    { name = "pytest-postgresql", specifier = ">=8.0.0" },
 ]
 dev = [
     { name = "alembic", specifier = ">=1.18.1" },
@@ -615,6 +618,7 @@ dev = [
     { name = "pytest-asyncio", specifier = ">=0.23.0" },
     { name = "pytest-cov", specifier = ">=7.0.0" },
     { name = "pytest-examples", specifier = ">=0.0.18" },
+    { name = "pytest-postgresql", specifier = ">=8.0.0" },
     { name = "python-semantic-release", specifier = ">=9.0.0" },
     { name = "rich", specifier = ">=14.2.0" },
     { name = "sqlalchemy", specifier = ">=2.0.46" },
@@ -1330,6 +1334,18 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/2c/19/04f9b178c2d8a15b076c8b5140708fa6ffc5601fb6f1e975537072df5b2a/mergedeep-1.3.4-py3-none-any.whl", hash = "sha256:70775750742b25c0d8f36c55aed03d24c3384d17c951b3175d898bd778ef0307", size = 6354, upload-time = "2021-02-05T18:55:29.583Z" },
 ]
 
+[[package]]
+name = "mirakuru"
+version = "3.0.2"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "psutil", marker = "sys_platform != 'cygwin'" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/53/23/db9034ba28c7d89a540ffb8ca789f70dc12079108ece1cd1762295d5c807/mirakuru-3.0.2.tar.gz", hash = "sha256:21192186a8680ea7567ca68170261df3785768b12962dd19fe8cccab15ad3441", size = 29338, upload-time = "2026-02-11T19:41:15.42Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/49/5f/a3f1a7f1f6e55de9285b03ae7e0d3c2a15e044b6e3f9b53bef5609ca05f2/mirakuru-3.0.2-py3-none-any.whl", hash = "sha256:10e5dac4a8f26872c63e9cdfdc01b775aaa2beb3ced98abc497279d2dc525b8f", size = 27583, upload-time = "2026-02-11T19:41:13.578Z" },
+]
+
 [[package]]
 name = "mistune"
 version = "3.2.0"
@@ -1649,6 +1665,15 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" },
 ]
 
+[[package]]
+name = "port-for"
+version = "1.0.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/88/a0/80a64e8cc096c7a9d0f546a28994af849b4775afc5e4ee44bf2739a55115/port_for-1.0.0.tar.gz", hash = "sha256:404d161b1b2c82e2f6b31d8646396b4847d02bf5ee10068c92b7263657a14582", size = 21681, upload-time = "2025-09-30T10:22:51.149Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/70/2c/b1faca65b9728b4ac43f0bee4bb9e7294bd0a62cc2ee59fd59403bf575f6/port_for-1.0.0-py3-none-any.whl", hash = "sha256:35a848b98cf4cc075fe80dc49ae5c3a78e3ca345a23bd39bf5252277b4eef5c2", size = 17544, upload-time = "2025-09-30T10:22:49.878Z" },
+]
+
"sha256:35a848b98cf4cc075fe80dc49ae5c3a78e3ca345a23bd39bf5252277b4eef5c2", size = 17544, upload-time = "2025-09-30T10:22:49.878Z" }, +] + [[package]] name = "prek" version = "0.3.0" @@ -1941,6 +1966,22 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/09/52/7bbfb6e987d9a8a945f22941a8da63e3529465f1b106ef0e26f5df7c780d/pytest_examples-0.0.18-py3-none-any.whl", hash = "sha256:86c195b98c4e55049a0df3a0a990ca89123b7280473ab57608eecc6c47bcfe9c", size = 18169, upload-time = "2025-05-06T07:46:09.349Z" }, ] +[[package]] +name = "pytest-postgresql" +version = "8.0.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "mirakuru" }, + { name = "packaging" }, + { name = "port-for" }, + { name = "psycopg" }, + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/bf/c8/b9607675904b3e4004c76109741aac6479daecfe6fb38ad89f330c6c5adc/pytest_postgresql-8.0.0.tar.gz", hash = "sha256:26cbd44a0adef76cf4a82a3a2263f0e029bc54b5863556a9bd86ca3edfa91cce", size = 49737, upload-time = "2026-01-23T21:14:34.554Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7b/cf/2c962ce63904e2b051b644a0036bdeeae8ae05d91e11e829b220c3db9935/pytest_postgresql-8.0.0-py3-none-any.whl", hash = "sha256:125b63b16d630c2dea19807062ed4c96e6123f06058ce65b82b4b14174d6c4b8", size = 40228, upload-time = "2026-01-23T21:14:33.08Z" }, +] + [[package]] name = "python-dateutil" version = "2.9.0.post0"