diff --git a/Cargo.lock b/Cargo.lock
index 9eb2441..4eb3675 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -3945,6 +3945,7 @@ dependencies = [
  "time",
  "tokio",
  "tracing",
+ "uuid",
 ]
 
 [[package]]
diff --git a/Cargo.toml b/Cargo.toml
index ef8e366..caabc3f 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -24,6 +24,7 @@ tokio = { version = "1.48.0", features = ["rt", "sync", "time"] }
 indexmap = { version = "2.12.1", features = ["serde"] }
 base64 = "0.22.1"
 tracing = { version = "0.1.41", default-features = false, features = ["std", "release_max_level_off"] }
+uuid = { version = "1.11.0", features = ["v4"] }
 
 # SQLx for types and queries (time feature enables datetime type decoding)
 sqlx = { version = "0.8.6", features = ["sqlite", "json", "time", "runtime-tokio"] }
diff --git a/README.md b/README.md
index 9c6b37e..0b5d305 100644
--- a/README.md
+++ b/README.md
@@ -238,7 +238,7 @@ const user = await db.fetchOne(
 
 ### Transactions
 
-Execute multiple statements atomically:
+For most cases, use `executeTransaction()` to run multiple statements atomically:
 
 ```typescript
 const results = await db.executeTransaction([
@@ -250,6 +250,49 @@ const results = await db.executeTransaction([
 
 Transactions use `BEGIN IMMEDIATE`, commit on success, and rollback on any failure.
 
+#### Interruptible Transactions
+
+**Use interruptible transactions when you need to read data mid-transaction to
+decide how to proceed.** For example: insert a record, read back its generated ID
+or other computed values, then use that data in subsequent writes.
+
+```typescript
+// Begin transaction with initial insert
+const tx = await db.executeInterruptibleTransaction([
+  ['INSERT INTO orders (user_id, total) VALUES ($1, $2)', [userId, 0]]
+])
+
+// Read the uncommitted data to get the generated order ID
+const orders = await tx.read<Array<{ id: number }>>(
+  'SELECT id FROM orders WHERE user_id = $1 ORDER BY id DESC LIMIT 1',
+  [userId]
+)
+const orderId = orders[0].id
+
+// Continue the transaction with the order ID
+const tx2 = await tx.continue([
+  ['INSERT INTO order_items (order_id, product_id) VALUES ($1, $2)', [orderId, productId]],
+  ['UPDATE orders SET total = $1 WHERE id = $2', [itemTotal, orderId]]
+])
+
+// Commit the transaction
+await tx2.commit()
+```
+
+**Important:**
+
+ * Only one interruptible transaction can be active per database at a time
+ * The write lock is held for the entire duration, so keep transactions short
+ * Uncommitted writes are visible only through the transaction's `read()` method
+ * Always commit or roll back; abandoned transactions are rolled back automatically
+   on app exit
+
+To roll back instead of committing:
+
+```typescript
+await tx.rollback()
+```
+
 ### Error Handling
 
 ```typescript
@@ -296,12 +339,22 @@ await db.remove() // Close and DELETE database file(s) - irreversible!
| Method | Description | | ------ | ----------- | | `execute(query, values?)` | Execute write query, returns `{ rowsAffected, lastInsertId }` | -| `executeTransaction(statements)` | Execute statements atomically | +| `executeTransaction(statements)` | Execute statements atomically (use for batch writes) | +| `executeInterruptibleTransaction(statements)` | Begin interruptible transaction, returns `InterruptibleTransaction` | | `fetchAll(query, values?)` | Execute SELECT, return all rows | | `fetchOne(query, values?)` | Execute SELECT, return single row or `undefined` | | `close()` | Close connection, returns `true` if was loaded | | `remove()` | Close and delete database file(s), returns `true` if was loaded | +### InterruptibleTransaction Methods + +| Method | Description | +| ------ | ----------- | +| `read(query, values?)` | Read uncommitted data within this transaction | +| `continue(statements)` | Execute additional statements, returns new `InterruptibleTransaction` | +| `commit()` | Commit transaction and release write lock | +| `rollback()` | Rollback transaction and release write lock | + ### Types ```typescript diff --git a/api-iife.js b/api-iife.js index 2bae4bb..7a028d1 100644 --- a/api-iife.js +++ b/api-iife.js @@ -1 +1 @@ -if("__TAURI__"in window){var __TAURI_PLUGIN_SQLITE__=function(){"use strict";async function t(t,e={},n){return window.__TAURI_INTERNALS__.invoke(t,e,n)}"function"==typeof SuppressedError&&SuppressedError;class e{constructor(t){this.path=t}static async load(n,s){const a=await t("plugin:sqlite|load",{db:n,customConfig:s});return new e(a)}static get(t){return new e(t)}async execute(e,n){const[s,a]=await t("plugin:sqlite|execute",{db:this.path,query:e,values:n??[]});return{lastInsertId:a,rowsAffected:s}}async executeTransaction(e){return await t("plugin:sqlite|execute_transaction",{db:this.path,statements:e.map(([t,e])=>({query:t,values:e??[]}))})}async fetchAll(e,n){return await t("plugin:sqlite|fetch_all",{db:this.path,query:e,values:n??[]})}async fetchOne(e,n){return await t("plugin:sqlite|fetch_one",{db:this.path,query:e,values:n??[]})}async close(){return await t("plugin:sqlite|close",{db:this.path})}static async closeAll(){await t("plugin:sqlite|close_all")}async remove(){return await t("plugin:sqlite|remove",{db:this.path})}async getMigrationEvents(){return await t("plugin:sqlite|get_migration_events",{db:this.path})}}return e}();Object.defineProperty(window.__TAURI__,"sqlite",{value:__TAURI_PLUGIN_SQLITE__})} +if("__TAURI__"in window){var __TAURI_PLUGIN_SQLITE__=function(t){"use strict";async function a(t,a={},n){return window.__TAURI_INTERNALS__.invoke(t,a,n)}"function"==typeof SuppressedError&&SuppressedError;class n{constructor(t,a){this.dbPath=t,this.transactionId=a}async read(t,n){return await a("plugin:sqlite|transaction_read",{token:{dbPath:this.dbPath,transactionId:this.transactionId},query:t,values:n??[]})}async continue(t){const e=await a("plugin:sqlite|transaction_continue",{token:{dbPath:this.dbPath,transactionId:this.transactionId},action:{type:"Continue",statements:t.map(([t,a])=>({query:t,values:a??[]}))}});return new n(e.dbPath,e.transactionId)}async commit(){await a("plugin:sqlite|transaction_continue",{token:{dbPath:this.dbPath,transactionId:this.transactionId},action:{type:"Commit"}})}async rollback(){await a("plugin:sqlite|transaction_continue",{token:{dbPath:this.dbPath,transactionId:this.transactionId},action:{type:"Rollback"}})}}class e{constructor(t){this.path=t}static async load(t,n){const i=await 
a("plugin:sqlite|load",{db:t,customConfig:n});return new e(i)}static get(t){return new e(t)}async execute(t,n){const[e,i]=await a("plugin:sqlite|execute",{db:this.path,query:t,values:n??[]});return{lastInsertId:i,rowsAffected:e}}async executeTransaction(t){return await a("plugin:sqlite|execute_transaction",{db:this.path,statements:t.map(([t,a])=>({query:t,values:a??[]}))})}async fetchAll(t,n){return await a("plugin:sqlite|fetch_all",{db:this.path,query:t,values:n??[]})}async fetchOne(t,n){return await a("plugin:sqlite|fetch_one",{db:this.path,query:t,values:n??[]})}async close(){return await a("plugin:sqlite|close",{db:this.path})}static async closeAll(){await a("plugin:sqlite|close_all")}async remove(){return await a("plugin:sqlite|remove",{db:this.path})}async executeInterruptibleTransaction(t){const e=await a("plugin:sqlite|execute_interruptible_transaction",{db:this.path,initialStatements:t.map(([t,a])=>({query:t,values:a??[]}))});return new n(e.dbPath,e.transactionId)}async getMigrationEvents(){return await a("plugin:sqlite|get_migration_events",{db:this.path})}}return t.InterruptibleTransaction=n,t.default=e,Object.defineProperty(t,"__esModule",{value:!0}),t}({});Object.defineProperty(window.__TAURI__,"sqlite",{value:__TAURI_PLUGIN_SQLITE__})} diff --git a/guest-js/index.test.ts b/guest-js/index.test.ts index 8d3c352..f68be2d 100644 --- a/guest-js/index.test.ts +++ b/guest-js/index.test.ts @@ -15,11 +15,23 @@ beforeEach(() => { if (cmd === 'plugin:sqlite|load') return (args as { db: string }).db if (cmd === 'plugin:sqlite|execute') return [1, 1] if (cmd === 'plugin:sqlite|execute_transaction') return [] + if (cmd === 'plugin:sqlite|execute_interruptible_transaction') { + return { dbPath: (args as { db: string }).db, transactionId: 'test-tx-id' } + } + if (cmd === 'plugin:sqlite|transaction_continue') { + const action = (args as { action: { type: string } }).action + if (action.type === 'Continue') { + return { dbPath: 'test.db', transactionId: 'test-tx-id' } + } + return undefined + } + if (cmd === 'plugin:sqlite|transaction_read') return [] if (cmd === 'plugin:sqlite|fetch_all') return [] if (cmd === 'plugin:sqlite|fetch_one') return null if (cmd === 'plugin:sqlite|close') return true if (cmd === 'plugin:sqlite|close_all') return undefined if (cmd === 'plugin:sqlite|remove') return true + if (cmd === 'plugin:sqlite|get_migration_events') return [] return undefined }) }) @@ -92,6 +104,69 @@ describe('Database commands', () => { expect(events).toEqual(mockEvents) }) + it('getMigrationEvents - empty array', async () => { + const events = await Database.get('test.db').getMigrationEvents() + expect(lastCmd).toBe('plugin:sqlite|get_migration_events') + expect(lastArgs.db).toBe('test.db') + expect(events).toEqual([]) + }) + + it('executeInterruptibleTransaction', async () => { + const tx = await Database.get('t.db').executeInterruptibleTransaction([ + ['INSERT INTO users (name) VALUES ($1)', ['Alice']] + ]) + expect(lastCmd).toBe('plugin:sqlite|execute_interruptible_transaction') + expect(lastArgs.db).toBe('t.db') + expect(lastArgs.initialStatements).toEqual([ + { query: 'INSERT INTO users (name) VALUES ($1)', values: ['Alice'] } + ]) + expect(tx).toBeInstanceOf(Object) + }) + + it('InterruptibleTransaction.continue()', async () => { + const tx = await Database.get('test.db').executeInterruptibleTransaction([ + ['INSERT INTO users (name) VALUES ($1)', ['Alice']] + ]) + const tx2 = await tx.continue([ + ['INSERT INTO users (name) VALUES ($1)', ['Bob']] + ]) + 
expect(lastCmd).toBe('plugin:sqlite|transaction_continue')
+    expect(lastArgs.token).toEqual({ dbPath: 'test.db', transactionId: 'test-tx-id' })
+    expect((lastArgs.action as { type: string }).type).toBe('Continue')
+    expect(tx2).toBeInstanceOf(Object)
+  })
+
+  it('InterruptibleTransaction.commit()', async () => {
+    const tx = await Database.get('test.db').executeInterruptibleTransaction([
+      ['INSERT INTO users (name) VALUES ($1)', ['Alice']]
+    ])
+    await tx.commit()
+    expect(lastCmd).toBe('plugin:sqlite|transaction_continue')
+    expect(lastArgs.token).toEqual({ dbPath: 'test.db', transactionId: 'test-tx-id' })
+    expect((lastArgs.action as { type: string }).type).toBe('Commit')
+  })
+
+  it('InterruptibleTransaction.rollback()', async () => {
+    const tx = await Database.get('test.db').executeInterruptibleTransaction([
+      ['INSERT INTO users (name) VALUES ($1)', ['Alice']]
+    ])
+    await tx.rollback()
+    expect(lastCmd).toBe('plugin:sqlite|transaction_continue')
+    expect(lastArgs.token).toEqual({ dbPath: 'test.db', transactionId: 'test-tx-id' })
+    expect((lastArgs.action as { type: string }).type).toBe('Rollback')
+  })
+
+  it('InterruptibleTransaction.read()', async () => {
+    const tx = await Database.get('test.db').executeInterruptibleTransaction([
+      ['INSERT INTO users (name) VALUES ($1)', ['Alice']]
+    ])
+    await tx.read('SELECT * FROM users WHERE name = $1', ['Alice'])
+    expect(lastCmd).toBe('plugin:sqlite|transaction_read')
+    expect(lastArgs.token).toEqual({ dbPath: 'test.db', transactionId: 'test-tx-id' })
+    expect(lastArgs.query).toBe('SELECT * FROM users WHERE name = $1')
+    expect(lastArgs.values).toEqual(['Alice'])
+  })
+
   it('handles errors from backend', async () => {
     mockIPC(() => {
       throw new Error('Database error')
diff --git a/guest-js/index.ts b/guest-js/index.ts
index 5a606d8..f9f2816 100644
--- a/guest-js/index.ts
+++ b/guest-js/index.ts
@@ -38,6 +38,122 @@ export interface SqliteError {
   message: string
 }
 
+/**
+ * **InterruptibleTransaction**
+ *
+ * Represents an active interruptible transaction that can be continued, committed, or rolled back.
+ * Provides methods to read uncommitted data and execute additional statements.
+ */
+export class InterruptibleTransaction {
+  constructor(
+    private readonly dbPath: string,
+    private readonly transactionId: string
+  ) {}
+
+  /**
+   * **read**
+   *
+   * Read data from the database within this transaction context.
+   * This allows you to see uncommitted writes from the current transaction.
+   *
+   * The query executes on the same connection as the transaction, so you can
+   * read data that hasn't been committed yet.
+   *
+   * @param query - SELECT query to execute
+   * @param bindValues - Optional parameter values
+   * @returns Promise that resolves with query results
+   *
+   * @example
+   * ```ts
+   * const tx = await db.executeInterruptibleTransaction([
+   *   ['INSERT INTO users (name) VALUES ($1)', ['Alice']]
+   * ]);
+   *
+   * const users = await tx.read(
+   *   'SELECT * FROM users WHERE name = $1',
+   *   ['Alice']
+   * );
+   * ```
+   */
+  async read<T>(query: string, bindValues?: SqlValue[]): Promise<T> {
+    return await invoke<T>('plugin:sqlite|transaction_read', {
+      token: { dbPath: this.dbPath, transactionId: this.transactionId },
+      query,
+      values: bindValues ?? []
+    })
+  }
+
+  /**
+   * **continue**
+   *
+   * Execute additional statements within this transaction and return a new transaction handle.
+   *
+   * @param statements - Array of [query, values?] tuples to execute
+   * @returns Promise that resolves with a new transaction handle
+   *
+   * @example
+   * ```ts
+   * const tx = await db.executeInterruptibleTransaction([...]);
+   * const tx2 = await tx.continue([
+   *   ['INSERT INTO users (name) VALUES ($1)', ['Bob']]
+   * ]);
+   * await tx2.commit();
+   * ```
+   */
+  async continue(statements: Array<[string, SqlValue[]?]>): Promise<InterruptibleTransaction> {
+    const token = await invoke<{ dbPath: string; transactionId: string }>(
+      'plugin:sqlite|transaction_continue',
+      {
+        token: { dbPath: this.dbPath, transactionId: this.transactionId },
+        action: {
+          type: 'Continue',
+          statements: statements.map(([query, values]) => ({
+            query,
+            values: values ?? []
+          }))
+        }
+      }
+    )
+    return new InterruptibleTransaction(token.dbPath, token.transactionId)
+  }
+
+  /**
+   * **commit**
+   *
+   * Commit this transaction and release the write lock.
+   *
+   * @example
+   * ```ts
+   * const tx = await db.executeInterruptibleTransaction([...]);
+   * await tx.commit();
+   * ```
+   */
+  async commit(): Promise<void> {
+    await invoke('plugin:sqlite|transaction_continue', {
+      token: { dbPath: this.dbPath, transactionId: this.transactionId },
+      action: { type: 'Commit' }
+    })
+  }
+
+  /**
+   * **rollback**
+   *
+   * Roll back this transaction and release the write lock.
+   *
+   * @example
+   * ```ts
+   * const tx = await db.executeInterruptibleTransaction([...]);
+   * await tx.rollback();
+   * ```
+   */
+  async rollback(): Promise<void> {
+    await invoke('plugin:sqlite|transaction_continue', {
+      token: { dbPath: this.dbPath, transactionId: this.transactionId },
+      action: { type: 'Rollback' }
+    })
+  }
+}
+
 /**
  * Custom configuration for SQLite database connection
  */
@@ -196,6 +312,10 @@ export default class Database {
    * Executes multiple write statements atomically within a transaction.
    * All statements either succeed together or fail together.
    *
+   * **Use this method** when you have a batch of writes to execute and don't need to
+   * read data mid-transaction. For transactions that require reading uncommitted data
+   * to decide how to proceed, use `executeInterruptibleTransaction()` instead.
+   *
    * The function automatically:
    * - Begins a transaction (BEGIN)
    * - Executes all statements in order
@@ -365,6 +485,69 @@
     return success
   }
 
+  /**
+   * **executeInterruptibleTransaction**
+   *
+   * Begins an interruptible transaction for cases where you need to **read data mid-transaction
+   * to decide how to proceed**. For example, inserting a record and then reading its
+   * generated ID or computed values before continuing with related writes.
+   *
+   * The transaction remains open, holding a write lock on the database, until you
+   * call `commit()` or `rollback()` on the returned transaction handle.
+   *
+   * **Use this method when:**
+   * - You need to read back generated IDs (e.g., AUTOINCREMENT columns)
+   * - You need to see computed values (e.g., triggers, default values)
+   * - Your next writes depend on data from earlier writes in the same transaction
+   *
+   * **Use `executeTransaction()` instead when:**
+   * - You just need to execute a batch of writes atomically
+   * - You know all the data upfront and don't need to read mid-transaction
+   *
+   * **Important:** Only one transaction can be active per database at a time. The
+   * writer connection is held for the entire duration, so keep transactions short.
+   *
+   * @param initialStatements - Array of [query, values?] tuples to execute initially
+   * @returns Promise that resolves with an InterruptibleTransaction handle
+   *
+   * @example
+   * ```ts
+   * // Insert an order and read back its ID
+   * const tx = await db.executeInterruptibleTransaction([
+   *   ['INSERT INTO orders (user_id, total) VALUES ($1, $2)', [userId, 0]]
+   * ]);
+   *
+   * // Read the generated order ID
+   * const orders = await tx.read<Array<{ id: number }>>(
+   *   'SELECT id FROM orders WHERE user_id = $1 ORDER BY id DESC LIMIT 1',
+   *   [userId]
+   * );
+   * const orderId = orders[0].id;
+   *
+   * // Use the ID in subsequent writes
+   * const tx2 = await tx.continue([
+   *   ['INSERT INTO order_items (order_id, product_id) VALUES ($1, $2)', [orderId, productId]]
+   * ]);
+   *
+   * await tx2.commit();
+   * ```
+   */
+  async executeInterruptibleTransaction(
+    initialStatements: Array<[string, SqlValue[]?]>
+  ): Promise<InterruptibleTransaction> {
+    const token = await invoke<{ dbPath: string; transactionId: string }>(
+      'plugin:sqlite|execute_interruptible_transaction',
+      {
+        db: this.path,
+        initialStatements: initialStatements.map(([query, values]) => ({
+          query,
+          values: values ?? []
+        }))
+      }
+    )
+    return new InterruptibleTransaction(token.dbPath, token.transactionId)
+  }
+
   /**
    * **getMigrationEvents**
    *
diff --git a/rollup.config.js b/rollup.config.js
index 5b5a8c4..a558aad 100644
--- a/rollup.config.js
+++ b/rollup.config.js
@@ -25,11 +25,13 @@
     output: [
       {
         file: pkg.exports.import,
-        format: 'esm'
+        format: 'esm',
+        exports: 'named'
       },
       {
         file: pkg.exports.require,
-        format: 'cjs'
+        format: 'cjs',
+        exports: 'named'
       }
     ],
     plugins: [
@@ -60,7 +62,8 @@
       banner: "if ('__TAURI__' in window) {",
       // the last `}` closes the if in the banner
      footer: `Object.defineProperty(window.__TAURI__, '${pluginJsName}', { value: ${iifeVarName} }) }`,
-      file: 'api-iife.js'
+      file: 'api-iife.js',
+      exports: 'named'
     },
     plugins: [
       typescript({
diff --git a/src/commands.rs b/src/commands.rs
index 0ed6dcd..bdb706e 100644
--- a/src/commands.rs
+++ b/src/commands.rs
@@ -4,21 +4,36 @@
 //! Each command manages database connections through the DbInstances state.
 
 use indexmap::IndexMap;
-use serde::Deserialize;
+use serde::{Deserialize, Serialize};
 use serde_json::Value as JsonValue;
 use sqlx_sqlite_conn_mgr::SqliteDatabaseConfig;
 use tauri::{AppHandle, Runtime, State};
+use uuid::Uuid;
 
 use crate::{
     DbInstances, Error, MigrationEvent, MigrationStates, MigrationStatus, Result, WriteQueryResult,
+    transactions::{
+        ActiveInterruptibleTransaction, ActiveInterruptibleTransactions, ActiveRegularTransactions,
+        Statement,
+    },
     wrapper::DatabaseWrapper,
 };
 
-/// Statement in a transaction with query and bind values
+/// Token representing an active interruptible transaction
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct TransactionToken {
+    pub db_path: String,
+    pub transaction_id: String,
+}
+
+/// Actions that can be taken on an interruptible transaction
 #[derive(Debug, Deserialize)]
-pub struct Statement {
-    query: String,
-    values: Vec<JsonValue>,
+#[serde(tag = "type")]
+pub enum TransactionAction {
+    Continue { statements: Vec<Statement> },
+    Commit,
+    Rollback,
 }
 
 /// Load/connect to a database and store it in plugin state.
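For reference, the `TransactionToken` struct and the tag-style `TransactionAction` enum above define the JSON that the transaction commands expect from the frontend; the guest-js code earlier in this diff builds exactly these shapes. A minimal TypeScript sketch of that wire format (the type names and sample values here are illustrative; only the field names, casing, and the `type` tag come from the serde derives):

```typescript
// Illustrative mirror of the Rust types above (camelCase per #[serde(rename_all)]).
type TransactionToken = { dbPath: string; transactionId: string }

type TransactionAction =
  | { type: 'Continue'; statements: Array<{ query: string; values: unknown[] }> }
  | { type: 'Commit' }
  | { type: 'Rollback' }

// What InterruptibleTransaction.continue() ends up sending through invoke():
const args: { token: TransactionToken; action: TransactionAction } = {
  token: { dbPath: 'app.db', transactionId: 'generated-uuid' },
  action: { type: 'Continue', statements: [{ query: 'UPDATE t SET x = $1', values: [1] }] }
}
```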
@@ -133,6 +148,7 @@ pub async fn execute(
 #[tauri::command]
 pub async fn execute_transaction(
     db_instances: State<'_, DbInstances>,
+    regular_txs: State<'_, ActiveRegularTransactions>,
     db: String,
     statements: Vec<Statement>,
 ) -> Result<Vec<WriteQueryResult>> {
@@ -148,9 +164,42 @@
         .map(|s| (s.query, s.values))
         .collect();
 
-    let results = wrapper.execute_transaction(stmt_tuples).await?;
+    // Generate unique key for tracking this transaction
+    let tx_key = format!("{}:{}", db, Uuid::new_v4());
+
+    // Spawn transaction execution with abort handle for cleanup on exit
+    let wrapper_clone = wrapper.clone();
+    let tx_key_clone = tx_key.clone();
+    let regular_txs_clone = regular_txs.inner().clone();
+
+    let handle = tokio::spawn(async move {
+        let result = wrapper_clone.execute_transaction(stmt_tuples).await;
+
+        // Remove from tracking when complete (even if result is Err)
+        regular_txs_clone.remove(&tx_key_clone).await;
+
+        result
+    });
 
-    Ok(results)
+    // Track abort handle for cleanup on app exit
+    regular_txs
+        .insert(tx_key.clone(), handle.abort_handle())
+        .await;
+
+    // Wait for transaction to complete
+    match handle.await {
+        Ok(result) => result,
+        Err(e) => {
+            // Task panicked or was aborted - ensure cleanup
+            regular_txs.remove(&tx_key).await;
+
+            if e.is_cancelled() {
+                Err(Error::Other("Transaction aborted due to app exit".into()))
+            } else {
+                Err(Error::Other(format!("Transaction task panicked: {}", e)))
+            }
+        }
+    }
 }
 
 /// Execute a SELECT query returning all matching rows
@@ -257,3 +306,148 @@ pub async fn get_migration_events(
         None => Ok(Vec::new()),
     }
 }
+
+/// Execute initial statements in an interruptible transaction and return a token.
+///
+/// This begins a transaction, executes the initial statements, and returns a token
+/// that can be used to continue, commit, or roll back the transaction.
+/// The writer connection is held for the entire transaction duration.
+#[tauri::command]
+pub async fn execute_interruptible_transaction(
+    db_instances: State<'_, DbInstances>,
+    active_txs: State<'_, ActiveInterruptibleTransactions>,
+    db: String,
+    initial_statements: Vec<Statement>,
+) -> Result<TransactionToken> {
+    let instances = db_instances.0.read().await;
+
+    let wrapper = instances
+        .get(&db)
+        .ok_or_else(|| Error::DatabaseNotLoaded(db.clone()))?;
+
+    // Generate unique transaction ID
+    let transaction_id = Uuid::new_v4().to_string();
+
+    // Acquire writer for the entire transaction
+    let mut writer = wrapper.acquire_writer().await?;
+
+    // Begin transaction
+    sqlx::query("BEGIN IMMEDIATE").execute(&mut *writer).await?;
+
+    // Execute initial statements
+    for statement in initial_statements {
+        let mut q = sqlx::query(&statement.query);
+        for value in statement.values {
+            q = crate::wrapper::bind_value(q, value);
+        }
+        q.execute(&mut *writer).await?;
+    }
+
+    // Create abort handle for transaction cleanup on app exit
+    let abort_handle = tokio::spawn(std::future::pending::<()>()).abort_handle();
+
+    // Store transaction state
+    let tx =
+        ActiveInterruptibleTransaction::new(db.clone(), transaction_id.clone(), writer, abort_handle);
+
+    active_txs.insert(db.clone(), tx).await?;
+
+    Ok(TransactionToken {
+        db_path: db,
+        transaction_id,
+    })
+}
+
+/// Continue, commit, or roll back an interruptible transaction.
+///
+/// Returns a new token if continuing with more statements, or None if committed/rolled back.
+#[tauri::command]
+pub async fn transaction_continue(
+    active_txs: State<'_, ActiveInterruptibleTransactions>,
+    token: TransactionToken,
+    action: TransactionAction,
+) -> Result<Option<TransactionToken>> {
+    match action {
+        TransactionAction::Continue { statements } => {
+            // Remove transaction to get mutable access
+            let mut tx = active_txs
+                .remove(&token.db_path, &token.transaction_id)
+                .await?;
+
+            // Execute statements on the transaction
+            match tx.execute_statements(statements).await {
+                Ok(()) => {
+                    // Re-insert transaction - if this fails, tx is dropped and auto-rolled back
+                    match active_txs.insert(token.db_path.clone(), tx).await {
+                        Ok(()) => Ok(Some(token)),
+                        Err(e) => {
+                            // Transaction lost but will auto-rollback via Drop
+                            Err(e)
+                        }
+                    }
+                }
+                Err(e) => {
+                    // Execution failed, explicitly roll back before returning the error
+                    let _ = tx.rollback().await;
+                    Err(e)
+                }
+            }
+        }
+
+        TransactionAction::Commit => {
+            // Remove transaction and commit
+            let tx = active_txs
+                .remove(&token.db_path, &token.transaction_id)
+                .await?;
+
+            tx.commit().await?;
+            Ok(None)
+        }
+
+        TransactionAction::Rollback => {
+            // Remove transaction and roll back
+            let tx = active_txs
+                .remove(&token.db_path, &token.transaction_id)
+                .await?;
+
+            tx.rollback().await?;
+            Ok(None)
+        }
+    }
+}
+
+/// Read from the database within an interruptible transaction to see uncommitted writes.
+///
+/// This executes a SELECT query on the same connection as the transaction,
+/// allowing you to see uncommitted data.
+#[tauri::command]
+pub async fn transaction_read(
+    active_txs: State<'_, ActiveInterruptibleTransactions>,
+    token: TransactionToken,
+    query: String,
+    values: Vec<JsonValue>,
+) -> Result<Vec<IndexMap<String, JsonValue>>> {
+    // Remove transaction to get mutable access
+    let mut tx = active_txs
+        .remove(&token.db_path, &token.transaction_id)
+        .await?;
+
+    // Execute read on the transaction
+    match tx.read(query, values).await {
+        Ok(results) => {
+            // Re-insert transaction - if this fails, tx is dropped and auto-rolled back
+            match active_txs.insert(token.db_path.clone(), tx).await {
+                Ok(()) => Ok(results),
+                Err(e) => {
+                    // Transaction lost but will auto-rollback via Drop
+                    Err(e)
+                }
+            }
+        }
+        Err(e) => {
+            // Read failed, explicitly roll back before returning the error
+            let _ = tx.rollback().await;
+            Err(e)
+        }
+    }
+}
diff --git a/src/error.rs b/src/error.rs
index ae15e36..a6b6405 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -51,6 +51,22 @@ pub enum Error {
         transaction_error: String,
         rollback_error: String,
     },
+
+    /// Transaction already active for this database.
+    #[error("transaction already active for database: {0}")]
+    TransactionAlreadyActive(String),
+
+    /// No active transaction for this database.
+    #[error("no active transaction for database: {0}")]
+    NoActiveTransaction(String),
+
+    /// Invalid transaction token provided.
+    #[error("invalid transaction token")]
+    InvalidTransactionToken,
+
+    /// Generic error for operations that don't fit other categories.
+    #[error("{0}")]
+    Other(String),
 }
 
 impl Error {
@@ -74,6 +90,10 @@ impl Error {
             Error::Io(_) => "IO_ERROR".to_string(),
             Error::MultipleRowsReturned(_) => "MULTIPLE_ROWS_RETURNED".to_string(),
             Error::TransactionRollbackFailed { .. } => "TRANSACTION_ROLLBACK_FAILED".to_string(),
+            Error::TransactionAlreadyActive(_) => "TRANSACTION_ALREADY_ACTIVE".to_string(),
+            Error::NoActiveTransaction(_) => "NO_ACTIVE_TRANSACTION".to_string(),
+            Error::InvalidTransactionToken => "INVALID_TRANSACTION_TOKEN".to_string(),
+            Error::Other(_) => "ERROR".to_string(),
         }
     }
 }
diff --git a/src/lib.rs b/src/lib.rs
index c09773d..11dc7fc 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -10,10 +10,12 @@ use tracing::{debug, error, info, trace, warn};
 mod commands;
 mod decode;
 mod error;
+mod transactions;
 mod wrapper;
 
 pub use error::{Error, Result};
 pub use sqlx_sqlite_conn_mgr::Migrator as SqliteMigrator;
+pub use transactions::{ActiveInterruptibleTransactions, ActiveRegularTransactions};
 pub use wrapper::{DatabaseWrapper, WriteQueryResult};
 
 /// Database instances managed by the plugin.
@@ -159,6 +161,9 @@ impl Builder {
                 commands::load,
                 commands::execute,
                 commands::execute_transaction,
+                commands::execute_interruptible_transaction,
+                commands::transaction_continue,
+                commands::transaction_read,
                 commands::fetch_all,
                 commands::fetch_one,
                 commands::close,
@@ -169,6 +174,8 @@
             .setup(move |app, _api| {
                 app.manage(DbInstances::default());
                 app.manage(MigrationStates::default());
+                app.manage(ActiveInterruptibleTransactions::default());
+                app.manage(ActiveRegularTransactions::default());
 
                 // Initialize migration states as Pending for all registered databases
                 let migration_states = app.state::<MigrationStates>();
@@ -200,7 +207,7 @@
             .on_event(|app, event| {
                 match event {
                     RunEvent::ExitRequested { api, code, .. } => {
-                        info!("App exit requested (code: {:?}) - closing databases before exit", code);
+                        info!("App exit requested (code: {:?}) - cleaning up transactions and databases", code);
 
                         // Prevent immediate exit so we can close connections and checkpoint WAL
                         api.prevent_exit();
@@ -210,19 +217,26 @@
                         let handle = match tokio::runtime::Handle::try_current() {
                             Ok(h) => h,
                             Err(_) => {
-                                warn!("No tokio runtime available for database cleanup");
+                                warn!("No tokio runtime available for cleanup");
                                 app_handle.exit(code.unwrap_or(0));
                                 return;
                             }
                         };
 
-                        let instances = app.state::<DbInstances>().inner().clone();
+                        let instances_clone = app.state::<DbInstances>().inner().clone();
+                        let interruptible_txs_clone = app.state::<ActiveInterruptibleTransactions>().inner().clone();
+                        let regular_txs_clone = app.state::<ActiveRegularTransactions>().inner().clone();
 
-                        // Spawn a blocking thread to close databases
+                        // Spawn a blocking thread to abort transactions and close databases
                         // (block_in_place panics on current_thread runtime)
                         let cleanup_result = std::thread::spawn(move || {
                             handle.block_on(async {
-                                let mut guard = instances.0.write().await;
+                                // First, abort all active transactions
+                                debug!("Aborting active transactions");
+                                transactions::cleanup_all_transactions(&interruptible_txs_clone, &regular_txs_clone).await;
+
+                                // Then close databases
+                                let mut guard = instances_clone.0.write().await;
                                 let wrappers: Vec<DatabaseWrapper> = guard.drain().map(|(_, v)| v).collect();
diff --git a/src/transactions.rs b/src/transactions.rs
new file mode 100644
index 0000000..e3a99ba
--- /dev/null
+++ b/src/transactions.rs
@@ -0,0 +1,237 @@
+//! Transaction management for interruptible transactions
+
+use std::collections::HashMap;
+use std::sync::Arc;
+use std::time::Instant;
+
+use indexmap::IndexMap;
+use serde::Deserialize;
+use serde_json::Value as JsonValue;
+use sqlx::{Column, Row};
+use sqlx_sqlite_conn_mgr::WriteGuard;
+use tokio::sync::RwLock;
+use tokio::task::AbortHandle;
+use tracing::debug;
+
+use crate::{Error, Result};
+
+/// Active transaction state holding the writer and metadata
+pub struct ActiveInterruptibleTransaction {
+    db_path: String,
+    transaction_id: String,
+    writer: WriteGuard,
+    abort_handle: AbortHandle,
+    created_at: Instant,
+}
+
+impl ActiveInterruptibleTransaction {
+    pub fn new(
+        db_path: String,
+        transaction_id: String,
+        writer: WriteGuard,
+        abort_handle: AbortHandle,
+    ) -> Self {
+        Self {
+            db_path,
+            transaction_id,
+            writer,
+            abort_handle,
+            created_at: Instant::now(),
+        }
+    }
+
+    pub fn db_path(&self) -> &str {
+        &self.db_path
+    }
+
+    pub fn transaction_id(&self) -> &str {
+        &self.transaction_id
+    }
+
+    pub fn created_at(&self) -> Instant {
+        self.created_at
+    }
+
+    pub fn validate_token(&self, token_id: &str) -> Result<()> {
+        if self.transaction_id != token_id {
+            return Err(Error::InvalidTransactionToken);
+        }
+        Ok(())
+    }
+
+    /// Execute a read query within this transaction and return decoded results
+    pub async fn read(
+        &mut self,
+        query: String,
+        values: Vec<JsonValue>,
+    ) -> Result<Vec<IndexMap<String, JsonValue>>> {
+        let mut q = sqlx::query(&query);
+        for value in values {
+            q = crate::wrapper::bind_value(q, value);
+        }
+
+        let rows = q.fetch_all(&mut *self.writer).await?;
+
+        let mut results = Vec::new();
+        for row in rows {
+            let mut value = IndexMap::default();
+            for (i, column) in row.columns().iter().enumerate() {
+                let v = row.try_get_raw(i)?;
+                let v = crate::decode::to_json(v)?;
+                value.insert(column.name().to_string(), v);
+            }
+            results.push(value);
+        }
+
+        Ok(results)
+    }
+
+    /// Execute statements on this transaction
+    pub async fn execute_statements(&mut self, statements: Vec<Statement>) -> Result<()> {
+        for statement in statements {
+            let mut q = sqlx::query(&statement.query);
+            for value in statement.values {
+                q = crate::wrapper::bind_value(q, value);
+            }
+            q.execute(&mut *self.writer).await?;
+        }
+        Ok(())
+    }
+
+    /// Commit this transaction
+    pub async fn commit(mut self) -> Result<()> {
+        sqlx::query("COMMIT").execute(&mut *self.writer).await?;
+        debug!("Transaction committed for db: {}", self.db_path);
+        Ok(())
+    }
+
+    /// Roll back this transaction
+    pub async fn rollback(mut self) -> Result<()> {
+        sqlx::query("ROLLBACK").execute(&mut *self.writer).await?;
+        debug!("Transaction rolled back for db: {}", self.db_path);
+        Ok(())
+    }
+}
+
+/// Statement in a transaction with query and bind values
+#[derive(Debug, Deserialize)]
+pub struct Statement {
+    pub query: String,
+    pub values: Vec<JsonValue>,
+}
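Because `ActiveInterruptibleTransaction::read` above runs on the same writer connection that executed the inserts, SQLite's connection-scoped `last_insert_rowid()` is visible to it as well. A small illustrative sketch of how a caller could use that to fetch a generated key, reusing the `db`/`userId`/`orders` names from the README example (this query is a suggestion, not something added by this patch):

```typescript
// Illustrative only: read the rowid generated by the initial INSERT.
// last_insert_rowid() is per-connection, and tx.read() runs on the same
// connection that holds the open transaction, so the value is visible here.
const tx = await db.executeInterruptibleTransaction([
  ['INSERT INTO orders (user_id, total) VALUES ($1, $2)', [userId, 0]]
])

const rows = await tx.read<Array<{ id: number }>>('SELECT last_insert_rowid() AS id')
const orderId = rows[0].id
```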
+
+impl Drop for ActiveInterruptibleTransaction {
+    fn drop(&mut self) {
+        // On drop, the WriteGuard is dropped which returns the connection to the pool.
+        // SQLite will automatically ROLLBACK the transaction when the connection
+        // is returned to the pool if no explicit COMMIT was issued.
+        debug!(
+            "Dropping transaction for db: {}, tx_id: {} (will auto-rollback)",
+            self.db_path, self.transaction_id
+        );
+    }
+}
+
+/// Global state tracking all active interruptible transactions
+#[derive(Clone, Default)]
+pub struct ActiveInterruptibleTransactions(
+    Arc<RwLock<HashMap<String, ActiveInterruptibleTransaction>>>,
+);
+
+impl ActiveInterruptibleTransactions {
+    pub async fn insert(&self, db_path: String, tx: ActiveInterruptibleTransaction) -> Result<()> {
+        use std::collections::hash_map::Entry;
+        let mut txs = self.0.write().await;
+
+        // Ensure only one transaction per database using the Entry API
+        match txs.entry(db_path.clone()) {
+            Entry::Vacant(e) => {
+                e.insert(tx);
+                Ok(())
+            }
+            Entry::Occupied(_) => Err(Error::TransactionAlreadyActive(db_path)),
+        }
+    }
+
+    pub async fn abort_all(&self) {
+        let mut txs = self.0.write().await;
+        debug!("Aborting {} active interruptible transaction(s)", txs.len());
+
+        for (db_path, tx) in txs.iter() {
+            debug!(
+                "Aborting interruptible transaction for database: {}",
+                db_path
+            );
+            tx.abort_handle.abort();
+        }
+
+        // Clear all transactions to drop WriteGuards and release locks
+        txs.clear();
+    }
+
+    /// Remove and return the transaction for commit/rollback
+    pub async fn remove(
+        &self,
+        db_path: &str,
+        token_id: &str,
+    ) -> Result<ActiveInterruptibleTransaction> {
+        let mut txs = self.0.write().await;
+
+        // Validate the token before removal
+        let tx = txs
+            .get(db_path)
+            .ok_or_else(|| Error::NoActiveTransaction(db_path.to_string()))?;
+
+        tx.validate_token(token_id)?;
+
+        // Safe unwrap: we just confirmed the key exists above
+        Ok(txs.remove(db_path).unwrap())
+    }
+}
+
+/// Tracking for regular (non-interruptible) transactions that are in flight.
+/// This allows us to abort them on app exit.
+#[derive(Clone, Default)]
+pub struct ActiveRegularTransactions(Arc<RwLock<HashMap<String, AbortHandle>>>);
+
+impl ActiveRegularTransactions {
+    pub async fn insert(&self, key: String, abort_handle: AbortHandle) {
+        let mut txs = self.0.write().await;
+        txs.insert(key, abort_handle);
+    }
+
+    pub async fn remove(&self, key: &str) {
+        let mut txs = self.0.write().await;
+        txs.remove(key);
+    }
+
+    pub async fn abort_all(&self) {
+        let mut txs = self.0.write().await;
+        debug!("Aborting {} active regular transaction(s)", txs.len());
+
+        for (key, abort_handle) in txs.iter() {
+            debug!("Aborting regular transaction: {}", key);
+            abort_handle.abort();
+        }
+
+        // Clear all tracked transactions to prevent a memory leak
+        txs.clear();
+    }
+}
+
+/// Clean up all transactions on app exit
+pub async fn cleanup_all_transactions(
+    interruptible: &ActiveInterruptibleTransactions,
+    regular: &ActiveRegularTransactions,
+) {
+    debug!("Cleaning up all active transactions");
+
+    // Abort all transaction tasks
+    interruptible.abort_all().await;
+    regular.abort_all().await;
+
+    // Interruptible transactions will auto-rollback when dropped.
+    // Regular transactions will also auto-rollback when the aborted task cleans up.
+
+    debug!("Transaction cleanup initiated");
+}
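`ActiveInterruptibleTransactions::insert` above is what enforces the one-transaction-per-database rule from the README: a second `executeInterruptibleTransaction` call while one is open fails with the new `TRANSACTION_ALREADY_ACTIVE` code from `error.rs`. A rough sketch of how a caller might handle that; the shape of the rejected value (a `code` field alongside `message`) is an assumption about how the plugin error reaches JS, and `db`/`userId` are reused from the README example:

```typescript
// Rough sketch (assumed error shape): detect the "second transaction" case and retry later.
try {
  const tx = await db.executeInterruptibleTransaction([
    ['INSERT INTO orders (user_id, total) VALUES ($1, $2)', [userId, 0]]
  ])
  // ... read / continue / commit as usual ...
} catch (err) {
  if ((err as { code?: string }).code === 'TRANSACTION_ALREADY_ACTIVE') {
    // Another interruptible transaction is still open on this database.
    // Wait for it to commit or roll back, then try again.
  } else {
    throw err
  }
}
```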
diff --git a/src/wrapper.rs b/src/wrapper.rs
index 1b60766..ffb1856 100644
--- a/src/wrapper.rs
+++ b/src/wrapper.rs
@@ -24,11 +24,17 @@ pub struct WriteQueryResult {
 }
 
 /// Wrapper around SqliteDatabase that adapts it for the plugin interface
+#[derive(Clone)]
 pub struct DatabaseWrapper {
     inner: Arc<SqliteDatabase>,
 }
 
 impl DatabaseWrapper {
+    /// Acquire the writer connection (for interruptible transactions)
+    pub async fn acquire_writer(&self) -> Result<WriteGuard> {
+        Ok(self.inner.acquire_writer().await?)
+    }
+
     /// Connect to a SQLite database via the connection manager
     pub async fn connect(
         path: &str,
@@ -239,7 +245,7 @@
 }
 
 /// Helper function to bind a JSON value to a SQLx query
-fn bind_value<'a>(
+pub(crate) fn bind_value<'a>(
     query: sqlx::query::Query<'a, sqlx::Sqlite, sqlx::sqlite::SqliteArguments<'a>>,
     value: JsonValue,
 ) -> sqlx::query::Query<'a, sqlx::Sqlite, sqlx::sqlite::SqliteArguments<'a>> {
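Taken together, the README's "always commit or roll back" guidance and the Drop-based auto-rollback in `transactions.rs` suggest wrapping interruptible work in a small guard on the caller side. A sketch of one possible helper built on the public API (the module specifier and the exported `SqlValue` type are assumptions; the helper itself is not part of this patch):

```typescript
// Illustrative helper, not part of this patch: run work inside an interruptible
// transaction and guarantee a rollback if anything throws before the commit.
import Database, { InterruptibleTransaction, type SqlValue } from 'tauri-plugin-sqlite' // hypothetical specifier

async function withInterruptibleTransaction(
  db: Database,
  initial: Array<[string, SqlValue[]?]>,
  work: (tx: InterruptibleTransaction) => Promise<InterruptibleTransaction>
): Promise<void> {
  const tx = await db.executeInterruptibleTransaction(initial)
  try {
    // work() may call tx.continue() any number of times and must return the latest handle.
    const finalTx = await work(tx)
    await finalTx.commit()
  } catch (err) {
    // transaction_continue already rolls back when a statement fails, so a second
    // rollback here may reject with NO_ACTIVE_TRANSACTION; ignore that and rethrow.
    await tx.rollback().catch(() => undefined)
    throw err
  }
}
```

Rolling back through the original handle works because the Continue action returns the token unchanged, so every handle in the chain addresses the same open transaction.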