diff --git a/src/spanner/src/client.rs b/src/spanner/src/client.rs index 86390b2224..a04a274551 100644 --- a/src/spanner/src/client.rs +++ b/src/spanner/src/client.rs @@ -1093,7 +1093,7 @@ mod tests { .await?; // 5. Verify success after retry - assert_eq!(result, 1); + assert_eq!(result.result, 1); Ok(()) } diff --git a/src/spanner/src/read_write_transaction.rs b/src/spanner/src/read_write_transaction.rs index f01ca5a1f4..659edeea59 100644 --- a/src/spanner/src/read_write_transaction.rs +++ b/src/spanner/src/read_write_transaction.rs @@ -18,6 +18,7 @@ use crate::database_client::DatabaseClient; use crate::error::internal_error; use crate::model::BeginTransactionRequest; use crate::model::CommitRequest; +use crate::model::CommitResponse; use crate::model::ExecuteBatchDmlRequest; use crate::model::RollbackRequest; use crate::model::TransactionOptions; @@ -53,6 +54,7 @@ pub(crate) struct ReadWriteTransactionBuilder { transaction_tag: Option, max_commit_delay: Option, pub(crate) session_name: String, + return_commit_stats: bool, commit_priority: Priority, } @@ -65,6 +67,7 @@ impl ReadWriteTransactionBuilder { transaction_tag: None, max_commit_delay: None, session_name, + return_commit_stats: false, commit_priority: Priority::Unspecified, } } @@ -114,6 +117,11 @@ impl ReadWriteTransactionBuilder { self } + pub(crate) fn with_return_commit_stats(mut self, return_stats: bool) -> Self { + self.return_commit_stats = return_stats; + self + } + pub(crate) async fn begin_transaction( &self, deadline: Option, @@ -156,6 +164,7 @@ impl ReadWriteTransactionBuilder { }, seqno: Arc::new(AtomicI64::new(1)), max_commit_delay: self.max_commit_delay, + return_commit_stats: self.return_commit_stats, deadline, commit_priority: self.commit_priority.clone(), }) @@ -169,6 +178,7 @@ pub struct ReadWriteTransaction { pub(crate) deadline: Option, seqno: Arc, max_commit_delay: Option, + return_commit_stats: bool, commit_priority: Priority, } @@ -360,15 +370,16 @@ impl ReadWriteTransaction { } /// Commits the transaction. - pub(crate) async fn commit(self) -> crate::Result { + pub(crate) async fn commit(self) -> crate::Result { let transaction_id = self.transaction_id()?; let precommit_token = self.context.precommit_token_tracker.get(); let request = CommitRequest::default() .set_session(self.context.session_name.clone()) .set_transaction_id(transaction_id.clone()) .set_or_clear_precommit_token(precommit_token) - .set_or_clear_request_options(self.commit_request_options()) - .set_or_clear_max_commit_delay(self.max_commit_delay); + .set_or_clear_request_options(self.context.amend_request_options(None)) + .set_or_clear_max_commit_delay(self.max_commit_delay) + .set_return_commit_stats(self.return_commit_stats); // TODO(#4972): make request options configurable let mut gax_options = GaxRequestOptions::default(); @@ -402,10 +413,7 @@ impl ReadWriteTransaction { response }; - let timestamp = response - .commit_timestamp - .ok_or_else(|| internal_error("No commit timestamp returned"))?; - Ok(timestamp) + Ok(response) } /// Rolls back the transaction. @@ -570,7 +578,13 @@ mod tests { assert_eq!(count, 1); let timestamp = tx.commit().await.unwrap(); - assert_eq!(timestamp.seconds(), 1001); + assert_eq!( + timestamp + .commit_timestamp + .expect("Commit timestamp should be present") + .seconds(), + 1001 + ); } #[tokio::test] @@ -636,7 +650,12 @@ mod tests { assert_eq!(count, 1); let ts = tx.commit().await.expect("Failed to commit"); - assert_eq!(ts.seconds(), 123456789); + assert_eq!( + ts.commit_timestamp + .expect("Commit timestamp should be present") + .seconds(), + 123456789 + ); } #[tokio::test] @@ -1066,7 +1085,12 @@ mod tests { .expect("Failed to execute update"); } let ts = tx.commit().await.expect("Failed to commit transaction"); - assert_eq!(ts.seconds(), 12345); + assert_eq!( + ts.commit_timestamp + .expect("Commit timestamp should be present") + .seconds(), + 12345 + ); } #[tokio::test] @@ -1135,7 +1159,12 @@ mod tests { .expect("Failed to build transaction"); let ts = tx.commit().await.expect("Failed to commit transaction"); - assert_eq!(ts.seconds(), 9999); + assert_eq!( + ts.commit_timestamp + .expect("Commit timestamp should be present") + .seconds(), + 9999 + ); } #[tokio::test] @@ -1176,6 +1205,11 @@ mod tests { .expect("Failed to build transaction"); let ts = tx.commit().await.expect("Failed to commit"); - assert_eq!(ts.seconds(), 123456789); + assert_eq!( + ts.commit_timestamp + .expect("Commit timestamp should be present") + .seconds(), + 123456789 + ); } } diff --git a/src/spanner/src/transaction_runner.rs b/src/spanner/src/transaction_runner.rs index 505678e318..3fb7ac2a7f 100644 --- a/src/spanner/src/transaction_runner.rs +++ b/src/spanner/src/transaction_runner.rs @@ -13,6 +13,7 @@ // limitations under the License. use crate::database_client::DatabaseClient; +use crate::model::CommitResponse; use crate::model::request_options::Priority; use crate::model::transaction_options::IsolationLevel; use crate::model::transaction_options::read_write::ReadLockMode; @@ -228,6 +229,37 @@ impl TransactionRunnerBuilder { self } + /// Sets whether to return commit stats for the transaction. + /// + /// # Example + /// ``` + /// # use google_cloud_spanner::client::{Spanner, Statement}; + /// # async fn run_tx(client: Spanner) -> Result<(), google_cloud_spanner::Error> { + /// # let db_client = client.database_client("projects/p/instances/i/databases/d").build().await?; + /// let runner = db_client.read_write_transaction() + /// .with_return_commit_stats(true) + /// .build() + /// .await?; + /// + /// let result = runner.run(async |transaction| { + /// let statement = Statement::builder("UPDATE MyTable SET MyColumn = 'MyValue' WHERE Id = 1").build(); + /// transaction.execute_update(statement).await?; + /// Ok(42) + /// }).await?; + /// + /// if let Some(stats) = result.commit_response.commit_stats { + /// println!("Mutation count: {}", stats.mutation_count); + /// } + /// # Ok(()) + /// # } + /// ``` + /// + /// See also: + pub fn with_return_commit_stats(mut self, return_stats: bool) -> Self { + self.builder = self.builder.with_return_commit_stats(return_stats); + self + } + /// Sets the retry policy for the transaction. /// /// # Example @@ -283,6 +315,16 @@ impl TransactionRunnerBuilder { } } +/// Result of a read/write transaction executed by a [TransactionRunner]. +#[derive(Debug)] +#[non_exhaustive] +pub struct TransactionResult { + /// The result returned by the closure executed within the transaction. + pub result: T, + /// The response from the commit RPC. + pub commit_response: CommitResponse, +} + /// A runner for read/write transactions. Aborted transactions are automatically retried. pub struct TransactionRunner { builder: ReadWriteTransactionBuilder, @@ -304,7 +346,7 @@ impl TransactionRunner { /// let result = runner.run(async |transaction| { /// let statement = Statement::builder("UPDATE MyTable SET MyColumn = 'MyValue' WHERE Id = 1").build(); /// transaction.execute_update(statement).await?; - /// Ok(42) // This will be returned by runner.run() + /// Ok(42) /// }).await?; /// # Ok(()) /// # } @@ -316,7 +358,7 @@ impl TransactionRunner { /// The transaction is automatically committed if the closure returns `Ok`. /// If the closure returns `Err`, the transaction will be rolled back and /// the error will be propagated. - pub async fn run(mut self, mut work: F) -> crate::Result + pub async fn run(mut self, mut work: F) -> crate::Result> where F: std::ops::AsyncFnMut(ReadWriteTransaction) -> crate::Result, { @@ -344,8 +386,11 @@ impl TransactionRunner { } }; - transaction.commit().await?; - Ok::(result) + let commit_response = transaction.commit().await?; + Ok::, crate::Error>(TransactionResult { + result, + commit_response, + }) } .await; @@ -377,6 +422,8 @@ mod tests { use crate::transaction_retry_policy::tests::create_aborted_status; use gaxi::grpc::tonic; use spanner_grpc_mock::google::spanner::v1; + use spanner_grpc_mock::google::spanner::v1::CommitResponse; + use spanner_grpc_mock::google::spanner::v1::commit_response::CommitStats; use spanner_grpc_mock::google::spanner::v1::transaction_options::Mode; fn expect_begin_transaction( @@ -413,6 +460,7 @@ mod tests { Ok(count) }) .await + .map(|res| res.result) } fn commit_response() -> Result, tonic::Status> { @@ -471,6 +519,57 @@ mod tests { assert_eq!(res, 1); } + #[tokio::test] + async fn run_success_with_commit_stats() { + let mut mock = create_session_mock(); + + expect_begin_transaction(&mut mock, 1, vec![1, 2, 3]); + + mock.expect_execute_sql().once().returning(|req| { + let req = req.into_inner(); + assert_eq!(req.sql, "UPDATE Users SET active = true"); + row_count_exact_response(1) + }); + + mock.expect_commit().once().returning(|req| { + let req = req.into_inner(); + assert!(req.return_commit_stats); + Ok(tonic::Response::new(CommitResponse { + commit_timestamp: Some(prost_types::Timestamp { + seconds: 123456789, + nanos: 0, + }), + commit_stats: Some(CommitStats { mutation_count: 5 }), + ..Default::default() + })) + }); + + let (db_client, _server) = setup_db_client(mock).await; + let runner = TransactionRunnerBuilder::new(db_client) + .with_return_commit_stats(true) + .build() + .await + .unwrap(); + + let res = runner + .run(async |tx| { + let count = tx.execute_update("UPDATE Users SET active = true").await?; + Ok(count) + }) + .await + .unwrap(); + + assert_eq!(res.result, 1); + assert!(res.commit_response.commit_stats.is_some()); + assert_eq!( + res.commit_response + .commit_stats + .expect("Commit stats should be present") + .mutation_count, + 5 + ); + } + #[tokio::test] async fn run_with_aborted_retry() -> anyhow::Result<()> { let mut mock = create_session_mock(); @@ -743,7 +842,7 @@ mod tests { .await .expect("transaction failed"); - assert_eq!(res, vec![5]); + assert_eq!(res.result, vec![5]); assert_eq!(attempt_counter, 2); } @@ -803,7 +902,7 @@ mod tests { }) .await?; - assert_eq!(res, 5); + assert_eq!(res.result, 5); Ok(()) } @@ -844,7 +943,7 @@ mod tests { }) .await?; - assert_eq!(res, 5); + assert_eq!(res.result, 5); Ok(()) } @@ -883,7 +982,7 @@ mod tests { Ok(count) }) .await?; - assert_eq!(res, 1); + assert_eq!(res.result, 1); Ok(()) } } diff --git a/src/spanner/src/write_only_transaction.rs b/src/spanner/src/write_only_transaction.rs index df381770e0..4a45442f5c 100644 --- a/src/spanner/src/write_only_transaction.rs +++ b/src/spanner/src/write_only_transaction.rs @@ -32,6 +32,7 @@ pub struct WriteOnlyTransactionBuilder { max_commit_delay: Option, retry_policy: Box, exclude_txn_from_change_streams: bool, + return_commit_stats: bool, commit_priority: Priority, } @@ -43,6 +44,7 @@ impl WriteOnlyTransactionBuilder { max_commit_delay: None, retry_policy: Box::new(BasicTransactionRetryPolicy::default()), exclude_txn_from_change_streams: false, + return_commit_stats: false, commit_priority: Priority::Unspecified, } } @@ -137,6 +139,37 @@ impl WriteOnlyTransactionBuilder { self } + /// Sets whether to return commit stats for the transaction. + /// + /// # Example + /// ``` + /// # use google_cloud_spanner::client::{Mutation, Spanner}; + /// # async fn test_doc() -> Result<(), Box> { + /// # let client = Spanner::builder().build().await?; + /// # let db = client.database_client("projects/p/instances/i/databases/d").build().await?; + /// let mutation = Mutation::new_insert_builder("Users") + /// .set("UserId").to(&1) + /// .build(); + /// + /// let response = db.write_only_transaction() + /// .with_return_commit_stats(true) + /// .build() + /// .write(vec![mutation]) + /// .await?; + /// + /// if let Some(stats) = response.commit_stats { + /// println!("Mutation count: {}", stats.mutation_count); + /// } + /// # Ok(()) + /// # } + /// ``` + /// + /// See also: + pub fn with_return_commit_stats(mut self, return_stats: bool) -> Self { + self.return_commit_stats = return_stats; + self + } + /// Sets the retry policy for the transaction. /// /// # Example @@ -186,6 +219,7 @@ impl WriteOnlyTransactionBuilder { max_commit_delay: self.max_commit_delay, retry_policy: self.retry_policy, exclude_txn_from_change_streams: self.exclude_txn_from_change_streams, + return_commit_stats: self.return_commit_stats, commit_priority: self.commit_priority, } } @@ -201,6 +235,7 @@ pub struct WriteOnlyTransaction { max_commit_delay: Option, retry_policy: Box, exclude_txn_from_change_streams: bool, + return_commit_stats: bool, commit_priority: Priority, } @@ -285,7 +320,8 @@ impl WriteOnlyTransaction { .set_transaction_id(tx.id.clone()) .set_request_options(req_options.clone()) .set_or_clear_precommit_token(tx.precommit_token) - .set_or_clear_max_commit_delay(self.max_commit_delay); + .set_or_clear_max_commit_delay(self.max_commit_delay) + .set_return_commit_stats(self.return_commit_stats); let response = client .spanner @@ -359,7 +395,8 @@ impl WriteOnlyTransaction { .set_mutations(mutations.into_iter().map(|m| m.build_proto())) .set_single_use_transaction(Box::new(single_use)) .set_request_options(req_options) - .set_or_clear_max_commit_delay(self.max_commit_delay); + .set_or_clear_max_commit_delay(self.max_commit_delay) + .set_return_commit_stats(self.return_commit_stats); let client = self.client; retry_aborted(&*self.retry_policy, || { @@ -388,6 +425,7 @@ mod tests { use spanner_grpc_mock::google::spanner::v1::CommitResponse; use spanner_grpc_mock::google::spanner::v1::Session; use spanner_grpc_mock::google::spanner::v1::Transaction; + use spanner_grpc_mock::google::spanner::v1::commit_response::CommitStats; use spanner_grpc_mock::google::spanner::v1::transaction_options::Mode; use wkt::Duration; @@ -569,6 +607,99 @@ mod tests { ); } + #[tokio::test] + async fn write_at_least_once_with_commit_stats() -> anyhow::Result<()> { + let mut mock = spanner_grpc_mock::MockSpanner::new(); + mock.expect_create_session().returning(|_| { + Ok(Response::new(Session { + name: "projects/p/instances/i/databases/d/sessions/123".to_string(), + ..Default::default() + })) + }); + + mock.expect_commit().once().returning(|req| { + let req = req.into_inner(); + assert!(req.return_commit_stats); + + Ok(Response::new(CommitResponse { + commit_timestamp: Some(prost_types::Timestamp { + seconds: 1234, + nanos: 0, + }), + commit_stats: Some(CommitStats { mutation_count: 5 }), + ..Default::default() + })) + }); + + let (db_client, _server) = setup_db_client(mock).await; + + let mutation = Mutation::new_insert_or_update_builder("Users") + .set("UserId") + .to(&1) + .build(); + + let res = db_client + .write_only_transaction() + .with_return_commit_stats(true) + .build() + .write_at_least_once(vec![mutation]) + .await?; + + let stats = res.commit_stats.expect("Commit stats should be present"); + assert_eq!(stats.mutation_count, 5); + Ok(()) + } + + #[tokio::test] + async fn write_with_commit_stats() -> anyhow::Result<()> { + let mut mock = spanner_grpc_mock::MockSpanner::new(); + mock.expect_create_session().returning(|_| { + Ok(Response::new(Session { + name: "projects/p/instances/i/databases/d/sessions/123".to_string(), + ..Default::default() + })) + }); + + mock.expect_begin_transaction().once().returning(|_| { + Ok(Response::new(Transaction { + id: vec![42], + ..Default::default() + })) + }); + + mock.expect_commit().once().returning(|req| { + let req = req.into_inner(); + assert!(req.return_commit_stats); + + Ok(Response::new(CommitResponse { + commit_timestamp: Some(prost_types::Timestamp { + seconds: 5678, + nanos: 0, + }), + commit_stats: Some(CommitStats { mutation_count: 10 }), + ..Default::default() + })) + }); + + let (db_client, _server) = setup_db_client(mock).await; + + let mutation = Mutation::new_insert_or_update_builder("Users") + .set("UserId") + .to(&1) + .build(); + + let res = db_client + .write_only_transaction() + .with_return_commit_stats(true) + .build() + .write(vec![mutation]) + .await?; + + let stats = res.commit_stats.expect("Commit stats should be present"); + assert_eq!(stats.mutation_count, 10); + Ok(()) + } + #[tokio::test] async fn write_at_least_once_with_exclude_txn_from_change_streams() { let mut mock = spanner_grpc_mock::MockSpanner::new(); diff --git a/tests/spanner/src/directed_read.rs b/tests/spanner/src/directed_read.rs index 789ee8967e..738ada2839 100644 --- a/tests/spanner/src/directed_read.rs +++ b/tests/spanner/src/directed_read.rs @@ -60,7 +60,8 @@ pub async fn read_write_with_directed_read_error(db_client: &DatabaseClient) -> let _ = rs.next().await; Ok(()) }) - .await; + .await + .map(|res| res.result); assert!( result.is_err(), diff --git a/tests/spanner/src/read_write_transaction.rs b/tests/spanner/src/read_write_transaction.rs index d81e69de43..3ff109ac4b 100644 --- a/tests/spanner/src/read_write_transaction.rs +++ b/tests/spanner/src/read_write_transaction.rs @@ -134,7 +134,8 @@ pub async fn rolled_back_read_write_transaction(db_client: &DatabaseClient) -> a "Simulated error to trigger rollback", ))) }) - .await; + .await + .map(|res| res.result); assert!(res.is_err(), "Transaction should return an error"); @@ -245,7 +246,8 @@ pub async fn concurrent_read_write_transaction_retries( Ok(()) }) - .await; + .await + .map(|res| res.result); res.expect("Transaction failed"); }); handles.push(handle);