diff --git a/src/spanner/src/client.rs b/src/spanner/src/client.rs index ac65953e7e..86390b2224 100644 --- a/src/spanner/src/client.rs +++ b/src/spanner/src/client.rs @@ -250,6 +250,7 @@ mod tests { use super::*; use crate::model::CreateSessionRequest; use crate::result_set::tests::adapt; + use gaxi::grpc::tonic::MetadataMap; use gaxi::grpc::tonic::{Code as GrpcCode, Response, Status}; use google_cloud_auth::credentials::anonymous::Builder as Anonymous; use google_cloud_gax::backoff_policy::BackoffPolicy; @@ -258,9 +259,15 @@ mod tests { use google_cloud_test_macros::tokio_test_no_panics; use spanner_grpc_mock::google::rpc as mock_rpc; use spanner_grpc_mock::google::spanner::v1 as mock_v1; + use spanner_grpc_mock::google::spanner::v1::CommitResponse; + use spanner_grpc_mock::google::spanner::v1::ResultSet; + use spanner_grpc_mock::google::spanner::v1::ResultSetStats; use spanner_grpc_mock::google::spanner::v1::Session; + use spanner_grpc_mock::google::spanner::v1::result_set_stats::RowCount; use spanner_grpc_mock::{MockSpanner, start}; use static_assertions::{assert_impl_all, assert_not_impl_any}; + use std::sync::Arc; + use std::sync::atomic::{AtomicU64, Ordering}; use std::time::Duration; mockall::mock! 
{ @@ -1091,6 +1098,275 @@ mod tests { Ok(()) } + /// Returns the `grpc-timeout` request header converted to whole microseconds. + /// + /// The value is an integer followed by a one-letter unit (`H`, `M`, `S`, `m`, `u` or `n`); + /// nanosecond values are truncated towards zero. + /// + /// # Panics + /// Panics if the header is missing, is not a valid string, or is malformed. + fn parse_timeout(metadata: &MetadataMap) -> u64 { + let timeout = metadata + .get("grpc-timeout") + .expect("grpc-timeout header should be present"); + let timeout_str = timeout + .to_str() + .expect("grpc-timeout should be a valid string"); + // Microseconds per unit as a (multiplier, divisor) pair; only `n` is sub-microsecond. + let (mul, div): (u64, u64) = match timeout_str.chars().last() { + Some('n') => (1, 1_000), + Some('u') => (1, 1), + Some('m') => (1_000, 1), + Some('S') => (1_000_000, 1), + Some('M') => (60_000_000, 1), + Some('H') => (3_600_000_000, 1), + _ => panic!("Unknown timeout unit in {}", timeout_str), + }; + // The matched unit is a single ASCII byte, so the digits are everything before it. + let digits = &timeout_str[..timeout_str.len() - 1]; + digits.parse::<u64>().expect("valid u64") * mul / div + } + + #[tokio_test_no_panics] + async fn transaction_timeout_respected() -> anyhow::Result<()> { + use google_cloud_gax::retry_policy::{Aip194Strict, RetryPolicyExt}; + use spanner_grpc_mock::google::spanner::v1::Transaction; + + // 1. Setup Mock Server + let mut mock = MockSpanner::new(); + + mock.expect_create_session().returning(|_| { + Ok(Response::new(Session { + name: "projects/p/instances/i/databases/d/sessions/123".to_string(), + ..Default::default() + })) + }); + + mock.expect_begin_transaction().returning(|_| { + Ok(Response::new(Transaction { + id: vec![1, 2, 3], + ..Default::default() + })) + }); + + mock.expect_commit().once().returning(|_| { + Ok(Response::new(CommitResponse { + commit_timestamp: Some(prost_types::Timestamp { + seconds: 12345, + nanos: 0, + }), + ..Default::default() + })) + }); + + // Mock execute_sql to first fail and then succeed, checking timeout header on both + let mut seq = mockall::Sequence::new(); + + mock.expect_execute_sql() + .once() + .in_sequence(&mut seq) + .returning(|req| { + let timeout_val = parse_timeout(req.metadata()); + assert!( + timeout_val <= 100000, + "Expected timeout to be <= 100ms, got {}", + timeout_val + ); + Err(Status::new(GrpcCode::ResourceExhausted, "quota exceeded")) + }); + +
mock.expect_execute_sql() + .once() + .in_sequence(&mut seq) + .returning(|req| { + let timeout_val = parse_timeout(req.metadata()); + assert!( + timeout_val <= 100000, + "Expected timeout to be <= 100ms, got {}", + timeout_val + ); + + let res = ResultSet { + stats: Some(ResultSetStats { + row_count: Some(RowCount::RowCountExact(1)), + ..Default::default() + }), + ..Default::default() + }; + Ok(Response::new(res)) + }); + + // 2. Initialize Client + let (address, _server) = start("127.0.0.1:0", mock).await?; + let client = Spanner::builder() + .with_endpoint(address) + .with_credentials(Anonymous::new().build()) + .build() + .await?; + let db = client + .database_client("projects/p/instances/i/databases/d") + .build() + .await?; + + // 3. Setup Transaction Runner with 100ms timeout + let runner = db + .read_write_transaction() + .with_transaction_timeout(Duration::from_millis(100)) + .build() + .await?; + + // 4. Run transaction and expect success after retry + let result = runner + .run(async |tx| { + let mut mock_backoff = MockBackoffPolicy::new(); + mock_backoff + .expect_on_failure() + .times(1) + .returning(|_| Duration::from_nanos(1)); + + let retry_policy = Aip194Strict.continue_on_too_many_requests(); + + let stmt = Statement::builder("SELECT 1") + .with_retry_policy(retry_policy) + .with_backoff_policy(mock_backoff) + .build(); + tx.execute_update(stmt).await?; + Ok(()) + }) + .await; + + result.expect("Transaction should have succeeded"); + + Ok(()) + } + + #[tokio::test] + async fn transaction_timeout_ticks_down() -> anyhow::Result<()> { + use spanner_grpc_mock::google::spanner::v1::Transaction; + + let mut mock = MockSpanner::new(); + + mock.expect_create_session().returning(|_| { + Ok(Response::new(Session { + name: "projects/p/instances/i/databases/d/sessions/123".to_string(), + ..Default::default() + })) + }); + + let mut seq = mockall::Sequence::new(); + + mock.expect_begin_transaction() + .once() + .in_sequence(&mut seq) + .returning(|_| { + 
Ok(Response::new(Transaction { + id: vec![1], + ..Default::default() + })) + }); + + let previous_timeout = Arc::new(AtomicU64::new(0)); + let prev_clone1 = previous_timeout.clone(); + mock.expect_execute_sql() + .once() + .in_sequence(&mut seq) + .returning(move |req| { + let timeout_val = parse_timeout(req.metadata()); + assert!( + timeout_val <= 500000, + "Expected timeout to be <= 500ms, got {}", + timeout_val + ); + prev_clone1.store(timeout_val, Ordering::SeqCst); + Err(Status::new(GrpcCode::Aborted, "Aborted")) + }); + + // Second attempt: Checks that timeout is <= previous + mock.expect_begin_transaction() + .once() + .in_sequence(&mut seq) + .returning(|_| { + Ok(Response::new(Transaction { + id: vec![2], + ..Default::default() + })) + }); + + let prev_clone2 = previous_timeout.clone(); + mock.expect_execute_sql() + .once() + .in_sequence(&mut seq) + .returning(move |req| { + let timeout_val = parse_timeout(req.metadata()); + let prev = prev_clone2.load(Ordering::SeqCst); + assert!( + timeout_val <= prev, + "Timeout should tick down between attempts or be equal, got {} and {}", + timeout_val, + prev + ); + prev_clone2.store(timeout_val, Ordering::SeqCst); // store for next check + + let res = ResultSet { + stats: Some(ResultSetStats { + row_count: Some(RowCount::RowCountExact(1)), + ..Default::default() + }), + ..Default::default() + }; + Ok(Response::new(res)) + }); + + let prev_clone3 = previous_timeout.clone(); + mock.expect_commit().once().returning(move |req| { + let timeout_val = parse_timeout(req.metadata()); + let prev = prev_clone3.load(Ordering::SeqCst); + assert!( + timeout_val < prev, + "Timeout should be smaller for commit, got {} and {}", + timeout_val, + prev + ); + + Ok(Response::new(CommitResponse { + commit_timestamp: Some(prost_types::Timestamp { + seconds: 12345, + nanos: 0, + }), + ..Default::default() + })) + }); + + let (address, _server) = start("127.0.0.1:0", mock).await?; + let client = Spanner::builder() + .with_endpoint(address) 
+ .with_credentials(Anonymous::new().build()) + .build() + .await?; + let db = client + .database_client("projects/p/instances/i/databases/d") + .build() + .await?; + + let runner = db + .read_write_transaction() + .with_transaction_timeout(Duration::from_millis(500)) + .build() + .await?; + + let result = runner + .run(async |tx| { + let stmt = Statement::builder("SELECT 1").build(); + tx.execute_update(stmt).await?; + Ok(()) + }) + .await; + + result.expect("Transaction should have succeeded"); + + Ok(()) + } + #[test] fn test_parse_emulator_endpoint() { assert_eq!( diff --git a/src/spanner/src/read_write_transaction.rs b/src/spanner/src/read_write_transaction.rs index 9c021a7b80..f01ca5a1f4 100644 --- a/src/spanner/src/read_write_transaction.rs +++ b/src/spanner/src/read_write_transaction.rs @@ -34,8 +34,15 @@ use crate::precommit::PrecommitTokenTracker; use crate::read_only_transaction::ReadContext; use crate::result_set::ResultSet; use crate::statement::Statement; +use google_cloud_gax::error::Error as GaxError; +use google_cloud_gax::options::RequestOptions as GaxRequestOptions; +use google_cloud_gax::retry_policy::RetryPolicy; +use google_cloud_gax::retry_result::RetryResult; +use google_cloud_gax::retry_state::RetryState; use std::sync::Arc; use std::sync::atomic::{AtomicI64, Ordering}; +use std::time::Duration as StdDuration; +use tokio::time::Instant; use wkt::Duration; /// A builder for [ReadWriteTransaction]. 
@@ -107,7 +114,10 @@ impl ReadWriteTransactionBuilder { self } - pub(crate) async fn begin_transaction(&self) -> crate::Result { + pub(crate) async fn begin_transaction( + &self, + deadline: Option, + ) -> crate::Result { let session_name = self.session_name.clone(); let mut request = BeginTransactionRequest::default() .set_session(session_name.clone()) @@ -119,10 +129,16 @@ impl ReadWriteTransactionBuilder { } // TODO(#4972): make request options configurable + let mut options = RequestOptions::default(); + if let Some(d) = deadline { + let remaining = d.saturating_duration_since(Instant::now()); + options.set_attempt_timeout(remaining); + } + let response = self .client .spanner - .begin_transaction(request, RequestOptions::default()) + .begin_transaction(request, options) .await?; let transaction_selector = @@ -140,6 +156,7 @@ impl ReadWriteTransactionBuilder { }, seqno: Arc::new(AtomicI64::new(1)), max_commit_delay: self.max_commit_delay, + deadline, commit_priority: self.commit_priority.clone(), }) } @@ -149,6 +166,7 @@ impl ReadWriteTransactionBuilder { #[derive(Clone, Debug)] pub struct ReadWriteTransaction { pub(crate) context: ReadContext, + pub(crate) deadline: Option, seqno: Arc, max_commit_delay: Option, commit_priority: Priority, @@ -160,7 +178,14 @@ impl ReadWriteTransaction { &self, statement: T, ) -> crate::Result { - self.context.execute_query(statement).await + if self.deadline.is_none() { + return self.context.execute_query(statement).await; + } + let stmt = statement.into(); + let mut gax_options = stmt.gax_options().clone(); + self.apply_transaction_timeout(&mut gax_options); + let stmt = stmt.with_gax_options(gax_options); + self.context.execute_query(stmt).await } /// Reads rows from the database using key lookups and scans, as a simple key/value style alternative to execute_query. 
@@ -168,14 +193,22 @@ impl ReadWriteTransaction { &self, read: T, ) -> crate::Result { - self.context.execute_read(read).await + if self.deadline.is_none() { + return self.context.execute_read(read).await; + } + let mut req = read.into(); + self.apply_transaction_timeout(&mut req.gax_options); + self.context.execute_read(req).await } /// Executes an update using this transaction. pub async fn execute_update>(&self, statement: T) -> crate::Result { - let seqno = self.seqno.fetch_add(1, Ordering::SeqCst); let statement = statement.into(); - let gax_options = statement.gax_options().clone(); + let mut gax_options = statement.gax_options().clone(); + if self.deadline.is_some() { + self.apply_transaction_timeout(&mut gax_options); + } + let seqno = self.seqno.fetch_add(1, Ordering::SeqCst); let mut request = statement .into_request() .set_session(self.context.session_name.clone()) @@ -270,6 +303,10 @@ impl ReadWriteTransaction { /// # } /// ``` pub async fn execute_batch_update(&self, batch: BatchDml) -> crate::Result> { + let mut batch = batch; + if self.deadline.is_some() { + self.apply_transaction_timeout(&mut batch.gax_options); + } let seqno = self.seqno.fetch_add(1, Ordering::SeqCst); let statements: Vec = batch @@ -333,11 +370,15 @@ impl ReadWriteTransaction { .set_or_clear_request_options(self.commit_request_options()) .set_or_clear_max_commit_delay(self.max_commit_delay); + // TODO(#4972): make request options configurable + let mut gax_options = GaxRequestOptions::default(); + self.apply_transaction_timeout(&mut gax_options); + let response = self .context .client .spanner - .commit(request, RequestOptions::default()) + .commit(request, gax_options) .await?; let response = @@ -348,10 +389,14 @@ impl ReadWriteTransaction { .set_precommit_token(*new_precommit_token) .set_or_clear_request_options(self.commit_request_options()); + // TODO(#4972): make request options configurable + let mut gax_options = GaxRequestOptions::default(); + 
self.apply_transaction_timeout(&mut gax_options); + self.context .client .spanner - .commit(retry_commit_req, RequestOptions::default()) + .commit(retry_commit_req, gax_options) .await? } else { response @@ -379,6 +424,40 @@ impl ReadWriteTransaction { Ok(()) } + + /// Bounds the retry loop configured in `options` by the transaction deadline, if one is set. + fn apply_transaction_timeout(&self, options: &mut GaxRequestOptions) { + if let Some(deadline) = self.deadline { + let inner = options + .retry_policy() + .clone() + .unwrap_or_else(|| Arc::new(google_cloud_gax::retry_policy::Aip194Strict)); + options.set_retry_policy(TransactionBoundedRetryPolicy { inner, deadline }); + } + } +} + +/// A retry policy that wraps another policy and bounds the total execution time +/// by a specific transaction deadline. +/// +/// This policy delegates `on_error` to the inner policy and reports as `remaining_time` the +/// smaller of the inner policy's own remaining time (if it has one) and the time left until +/// the transaction deadline. +#[derive(Debug)] +struct TransactionBoundedRetryPolicy { + inner: Arc<dyn RetryPolicy>, + deadline: Instant, +} + +impl RetryPolicy for TransactionBoundedRetryPolicy { + fn on_error(&self, state: &RetryState, error: GaxError) -> RetryResult { + self.inner.on_error(state, error) + } + fn remaining_time(&self, state: &RetryState) -> Option<StdDuration> { + // Keep a tighter limit set on the wrapped policy instead of replacing it. + let left = self.deadline.saturating_duration_since(Instant::now()); + Some(self.inner.remaining_time(state).map_or(left, |t| t.min(left))) + } } #[cfg(test)] @@ -480,7 +559,7 @@ mod tests { let (db_client, _server) = setup_db_client(mock).await; let tx = ReadWriteTransactionBuilder::new(db_client.clone()) - .begin_transaction() + .begin_transaction(None) .await .expect("Failed to build transaction"); @@ -547,7 +626,7 @@ mod tests { let (db_client, _server) = setup_db_client(mock).await; let tx = ReadWriteTransactionBuilder::new(db_client.clone()) - .begin_transaction() + .begin_transaction(None) .await .expect("Failed to build transaction"); let count = tx @@ -584,7 +663,7 @@ mod tests { let (db_client, _server) = setup_db_client(mock).await; let tx =
ReadWriteTransactionBuilder::new(db_client.clone()) - .begin_transaction() + .begin_transaction(None) .await .expect("Failed to build transaction"); @@ -629,7 +708,7 @@ mod tests { let (db_client, _server) = setup_db_client(mock).await; let tx = ReadWriteTransactionBuilder::new(db_client.clone()) - .begin_transaction() + .begin_transaction(None) .await .expect("Failed to build transaction"); @@ -688,7 +767,7 @@ mod tests { let (db_client, _server) = setup_db_client(mock).await; let tx = ReadWriteTransactionBuilder::new(db_client) - .begin_transaction() + .begin_transaction(None) .await?; let batch = BatchDml::builder() @@ -733,7 +812,7 @@ mod tests { let (db_client, _server) = setup_db_client(mock).await; let tx = ReadWriteTransactionBuilder::new(db_client) - .begin_transaction() + .begin_transaction(None) .await?; let batch = BatchDml::builder() @@ -791,7 +870,7 @@ mod tests { let (db_client, _server) = setup_db_client(mock).await; let tx = ReadWriteTransactionBuilder::new(db_client.clone()) - .begin_transaction() + .begin_transaction(None) .await .expect("Failed to build transaction"); @@ -841,7 +920,7 @@ mod tests { let (db_client, _server) = setup_db_client(mock).await; let tx = ReadWriteTransactionBuilder::new(db_client.clone()) - .begin_transaction() + .begin_transaction(None) .await .expect("Failed to build transaction"); @@ -893,7 +972,7 @@ mod tests { let _tx = ReadWriteTransactionBuilder::new(db_client.clone()) .with_isolation_level(IsolationLevel::Serializable) .with_read_lock_mode(ReadLockMode::Pessimistic) - .begin_transaction() + .begin_transaction(None) .await .expect("Failed to build transaction"); } @@ -917,7 +996,7 @@ mod tests { let _tx = ReadWriteTransactionBuilder::new(db_client.clone()) .with_exclude_txn_from_change_streams(true) - .begin_transaction() + .begin_transaction(None) .await .expect("Failed to build transaction"); } @@ -977,7 +1056,7 @@ mod tests { let (db_client, _server) = setup_db_client(mock).await; let tx = 
ReadWriteTransactionBuilder::new(db_client.clone()) - .begin_transaction() + .begin_transaction(None) .await .expect("Failed to build transaction"); @@ -1051,7 +1130,7 @@ mod tests { let (db_client, _server) = setup_db_client(mock).await; let tx = ReadWriteTransactionBuilder::new(db_client.clone()) - .begin_transaction() + .begin_transaction(None) .await .expect("Failed to build transaction"); @@ -1092,7 +1171,7 @@ mod tests { let tx = ReadWriteTransactionBuilder::new(db_client.clone()) .with_max_commit_delay(Duration::new(0, 200_000_000).unwrap()) - .begin_transaction() + .begin_transaction(None) .await .expect("Failed to build transaction"); diff --git a/src/spanner/src/statement.rs b/src/spanner/src/statement.rs index 93dc02b518..02bc26356b 100644 --- a/src/spanner/src/statement.rs +++ b/src/spanner/src/statement.rs @@ -251,6 +251,12 @@ impl Statement { &self.gax_options } + /// Returns a new `Statement` with the given `GaxRequestOptions`. + pub(crate) fn with_gax_options(mut self, options: GaxRequestOptions) -> Self { + self.gax_options = options; + self + } + /// Sets the query mode to use for this statement. /// /// # Example diff --git a/src/spanner/src/transaction_runner.rs b/src/spanner/src/transaction_runner.rs index 72cc36b22f..505678e318 100644 --- a/src/spanner/src/transaction_runner.rs +++ b/src/spanner/src/transaction_runner.rs @@ -20,6 +20,7 @@ use crate::read_write_transaction::{ReadWriteTransaction, ReadWriteTransactionBu use crate::transaction_retry_policy::{ BasicTransactionRetryPolicy, TransactionRetryPolicy, backoff_if_aborted, is_aborted, }; +use std::time::Duration as StdDuration; use wkt::Duration; /// A builder for a [TransactionRunner] for a read/write transaction. 
@@ -46,6 +47,7 @@ use wkt::Duration; pub struct TransactionRunnerBuilder { builder: ReadWriteTransactionBuilder, retry_policy: Box, + timeout: Option, } impl TransactionRunnerBuilder { @@ -53,9 +55,35 @@ impl TransactionRunnerBuilder { Self { builder: ReadWriteTransactionBuilder::new(client), retry_policy: Box::new(BasicTransactionRetryPolicy::default()), + timeout: None, } } + /// Sets the timeout for the entire transaction. + /// + /// # Example + /// ``` + /// # use google_cloud_spanner::client::Spanner; + /// # use std::time::Duration; + /// # async fn run(client: Spanner) -> Result<(), google_cloud_spanner::Error> { + /// # let db_client = client.database_client("projects/p/instances/i/databases/d").build().await?; + /// let runner = db_client.read_write_transaction() + /// .with_transaction_timeout(Duration::from_secs(5)) + /// .build() + /// .await?; + /// # Ok(()) + /// # } + /// ``` + /// + /// This timeout applies to the total time spent executing the transaction, including + /// all statements and automatic retries. Each individual RPC within the transaction + /// is automatically assigned a deadline derived from the remaining time of this + /// overall timeout. + pub fn with_transaction_timeout(mut self, timeout: StdDuration) -> Self { + self.timeout = Some(timeout); + self + } + /// Sets the isolation level for the transaction. 
/// /// # Example @@ -250,6 +278,7 @@ impl TransactionRunnerBuilder { Ok(TransactionRunner { builder: self.builder, retry_policy: self.retry_policy, + timeout: self.timeout, }) } } @@ -258,6 +287,7 @@ impl TransactionRunnerBuilder { pub struct TransactionRunner { builder: ReadWriteTransactionBuilder, retry_policy: Box, + timeout: Option, } impl TransactionRunner { @@ -293,13 +323,15 @@ impl TransactionRunner { let start_time = tokio::time::Instant::now(); let mut attempts: u32 = 0; let backoff = crate::transaction_retry_policy::default_retry_backoff(); + // A timeout too large to add to `start_time` means "no deadline" rather than a panic. + let deadline = self.timeout.and_then(|t| start_time.checked_add(t)); loop { attempts += 1; let mut current_tx_id = None; let attempt_result = async { - let transaction = self.builder.begin_transaction().await?; + let transaction = self.builder.begin_transaction(deadline).await?; current_tx_id = transaction.transaction_id().ok(); let result = match work(transaction.clone()).await {