Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/spanner/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1093,7 +1093,7 @@ mod tests {
.await?;

// 5. Verify success after retry
assert_eq!(result, 1);
assert_eq!(result.result, 1);

Ok(())
}
Expand Down
58 changes: 46 additions & 12 deletions src/spanner/src/read_write_transaction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use crate::database_client::DatabaseClient;
use crate::error::internal_error;
use crate::model::BeginTransactionRequest;
use crate::model::CommitRequest;
use crate::model::CommitResponse;
use crate::model::ExecuteBatchDmlRequest;
use crate::model::RollbackRequest;
use crate::model::TransactionOptions;
Expand Down Expand Up @@ -53,6 +54,7 @@ pub(crate) struct ReadWriteTransactionBuilder {
transaction_tag: Option<String>,
max_commit_delay: Option<Duration>,
pub(crate) session_name: String,
return_commit_stats: bool,
commit_priority: Priority,
}

Expand All @@ -65,6 +67,7 @@ impl ReadWriteTransactionBuilder {
transaction_tag: None,
max_commit_delay: None,
session_name,
return_commit_stats: false,
commit_priority: Priority::Unspecified,
}
}
Expand Down Expand Up @@ -114,6 +117,11 @@ impl ReadWriteTransactionBuilder {
self
}

pub(crate) fn with_return_commit_stats(mut self, return_stats: bool) -> Self {
self.return_commit_stats = return_stats;
self
}

pub(crate) async fn begin_transaction(
&self,
deadline: Option<Instant>,
Expand Down Expand Up @@ -156,6 +164,7 @@ impl ReadWriteTransactionBuilder {
},
seqno: Arc::new(AtomicI64::new(1)),
max_commit_delay: self.max_commit_delay,
return_commit_stats: self.return_commit_stats,
deadline,
commit_priority: self.commit_priority.clone(),
})
Expand All @@ -169,6 +178,7 @@ pub struct ReadWriteTransaction {
pub(crate) deadline: Option<Instant>,
seqno: Arc<AtomicI64>,
max_commit_delay: Option<Duration>,
return_commit_stats: bool,
commit_priority: Priority,
}

Expand Down Expand Up @@ -360,15 +370,16 @@ impl ReadWriteTransaction {
}

/// Commits the transaction.
pub(crate) async fn commit(self) -> crate::Result<wkt::Timestamp> {
pub(crate) async fn commit(self) -> crate::Result<CommitResponse> {
let transaction_id = self.transaction_id()?;
let precommit_token = self.context.precommit_token_tracker.get();
let request = CommitRequest::default()
.set_session(self.context.session_name.clone())
.set_transaction_id(transaction_id.clone())
.set_or_clear_precommit_token(precommit_token)
.set_or_clear_request_options(self.commit_request_options())
.set_or_clear_max_commit_delay(self.max_commit_delay);
.set_or_clear_request_options(self.context.amend_request_options(None))
.set_or_clear_max_commit_delay(self.max_commit_delay)
.set_return_commit_stats(self.return_commit_stats);

// TODO(#4972): make request options configurable
let mut gax_options = GaxRequestOptions::default();
Expand Down Expand Up @@ -402,10 +413,7 @@ impl ReadWriteTransaction {
response
};

let timestamp = response
.commit_timestamp
.ok_or_else(|| internal_error("No commit timestamp returned"))?;
Ok(timestamp)
Ok(response)
}

/// Rolls back the transaction.
Expand Down Expand Up @@ -570,7 +578,13 @@ mod tests {
assert_eq!(count, 1);

let timestamp = tx.commit().await.unwrap();
assert_eq!(timestamp.seconds(), 1001);
assert_eq!(
timestamp
.commit_timestamp
.expect("Commit timestamp should be present")
.seconds(),
1001
);
}

#[tokio::test]
Expand Down Expand Up @@ -636,7 +650,12 @@ mod tests {
assert_eq!(count, 1);

let ts = tx.commit().await.expect("Failed to commit");
assert_eq!(ts.seconds(), 123456789);
assert_eq!(
ts.commit_timestamp
.expect("Commit timestamp should be present")
.seconds(),
123456789
);
}

#[tokio::test]
Expand Down Expand Up @@ -1066,7 +1085,12 @@ mod tests {
.expect("Failed to execute update");
}
let ts = tx.commit().await.expect("Failed to commit transaction");
assert_eq!(ts.seconds(), 12345);
assert_eq!(
ts.commit_timestamp
.expect("Commit timestamp should be present")
.seconds(),
12345
);
}

#[tokio::test]
Expand Down Expand Up @@ -1135,7 +1159,12 @@ mod tests {
.expect("Failed to build transaction");

let ts = tx.commit().await.expect("Failed to commit transaction");
assert_eq!(ts.seconds(), 9999);
assert_eq!(
ts.commit_timestamp
.expect("Commit timestamp should be present")
.seconds(),
9999
);
}

#[tokio::test]
Expand Down Expand Up @@ -1176,6 +1205,11 @@ mod tests {
.expect("Failed to build transaction");

let ts = tx.commit().await.expect("Failed to commit");
assert_eq!(ts.seconds(), 123456789);
assert_eq!(
ts.commit_timestamp
.expect("Commit timestamp should be present")
.seconds(),
123456789
);
}
}
115 changes: 107 additions & 8 deletions src/spanner/src/transaction_runner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// limitations under the License.

use crate::database_client::DatabaseClient;
use crate::model::CommitResponse;
use crate::model::request_options::Priority;
use crate::model::transaction_options::IsolationLevel;
use crate::model::transaction_options::read_write::ReadLockMode;
Expand Down Expand Up @@ -228,6 +229,37 @@ impl TransactionRunnerBuilder {
self
}

/// Sets whether to return commit stats for the transaction.
///
/// # Example
/// ```
/// # use google_cloud_spanner::client::{Spanner, Statement};
/// # async fn run_tx(client: Spanner) -> Result<(), google_cloud_spanner::Error> {
/// # let db_client = client.database_client("projects/p/instances/i/databases/d").build().await?;
/// let runner = db_client.read_write_transaction()
/// .with_return_commit_stats(true)
/// .build()
/// .await?;
///
/// let result = runner.run(async |transaction| {
/// let statement = Statement::builder("UPDATE MyTable SET MyColumn = 'MyValue' WHERE Id = 1").build();
/// transaction.execute_update(statement).await?;
/// Ok(42)
/// }).await?;
///
/// if let Some(stats) = result.commit_response.commit_stats {
/// println!("Mutation count: {}", stats.mutation_count);
/// }
/// # Ok(())
/// # }
/// ```
///
/// See also: <https://docs.cloud.google.com/spanner/docs/commit-statistics>
pub fn with_return_commit_stats(mut self, return_stats: bool) -> Self {
self.builder = self.builder.with_return_commit_stats(return_stats);
self
}

/// Sets the retry policy for the transaction.
///
/// # Example
Expand Down Expand Up @@ -283,6 +315,16 @@ impl TransactionRunnerBuilder {
}
}

/// Result of a read/write transaction executed by a [TransactionRunner].
#[derive(Debug)]
#[non_exhaustive]
pub struct TransactionResult<T> {
/// The result returned by the closure executed within the transaction.
pub result: T,
/// The response from the commit RPC.
pub commit_response: CommitResponse,
}

/// A runner for read/write transactions. Aborted transactions are automatically retried.
pub struct TransactionRunner {
builder: ReadWriteTransactionBuilder,
Expand All @@ -304,7 +346,7 @@ impl TransactionRunner {
/// let result = runner.run(async |transaction| {
/// let statement = Statement::builder("UPDATE MyTable SET MyColumn = 'MyValue' WHERE Id = 1").build();
/// transaction.execute_update(statement).await?;
/// Ok(42) // This will be returned by runner.run()
/// Ok(42)
/// }).await?;
/// # Ok(())
/// # }
Expand All @@ -316,7 +358,7 @@ impl TransactionRunner {
/// The transaction is automatically committed if the closure returns `Ok`.
/// If the closure returns `Err`, the transaction will be rolled back and
/// the error will be propagated.
pub async fn run<T, F>(mut self, mut work: F) -> crate::Result<T>
pub async fn run<T, F>(mut self, mut work: F) -> crate::Result<TransactionResult<T>>
where
F: std::ops::AsyncFnMut(ReadWriteTransaction) -> crate::Result<T>,
{
Expand Down Expand Up @@ -344,8 +386,11 @@ impl TransactionRunner {
}
};

transaction.commit().await?;
Ok::<T, crate::Error>(result)
let commit_response = transaction.commit().await?;
Ok::<TransactionResult<T>, crate::Error>(TransactionResult {
result,
commit_response,
})
}
.await;

Expand Down Expand Up @@ -377,6 +422,8 @@ mod tests {
use crate::transaction_retry_policy::tests::create_aborted_status;
use gaxi::grpc::tonic;
use spanner_grpc_mock::google::spanner::v1;
use spanner_grpc_mock::google::spanner::v1::CommitResponse;
use spanner_grpc_mock::google::spanner::v1::commit_response::CommitStats;
use spanner_grpc_mock::google::spanner::v1::transaction_options::Mode;

fn expect_begin_transaction(
Expand Down Expand Up @@ -413,6 +460,7 @@ mod tests {
Ok(count)
})
.await
.map(|res| res.result)
}

fn commit_response() -> Result<tonic::Response<v1::CommitResponse>, tonic::Status> {
Expand Down Expand Up @@ -471,6 +519,57 @@ mod tests {
assert_eq!(res, 1);
}

#[tokio::test]
async fn run_success_with_commit_stats() {
let mut mock = create_session_mock();

expect_begin_transaction(&mut mock, 1, vec![1, 2, 3]);

mock.expect_execute_sql().once().returning(|req| {
let req = req.into_inner();
assert_eq!(req.sql, "UPDATE Users SET active = true");
row_count_exact_response(1)
});

mock.expect_commit().once().returning(|req| {
let req = req.into_inner();
assert!(req.return_commit_stats);
Ok(tonic::Response::new(CommitResponse {
commit_timestamp: Some(prost_types::Timestamp {
seconds: 123456789,
nanos: 0,
}),
commit_stats: Some(CommitStats { mutation_count: 5 }),
..Default::default()
}))
});

let (db_client, _server) = setup_db_client(mock).await;
let runner = TransactionRunnerBuilder::new(db_client)
.with_return_commit_stats(true)
.build()
.await
.unwrap();

let res = runner
.run(async |tx| {
let count = tx.execute_update("UPDATE Users SET active = true").await?;
Ok(count)
})
.await
.unwrap();

assert_eq!(res.result, 1);
assert!(res.commit_response.commit_stats.is_some());
assert_eq!(
res.commit_response
.commit_stats
.expect("Commit stats should be present")
.mutation_count,
5
);
}

#[tokio::test]
async fn run_with_aborted_retry() -> anyhow::Result<()> {
let mut mock = create_session_mock();
Expand Down Expand Up @@ -743,7 +842,7 @@ mod tests {
.await
.expect("transaction failed");

assert_eq!(res, vec![5]);
assert_eq!(res.result, vec![5]);
assert_eq!(attempt_counter, 2);
}

Expand Down Expand Up @@ -803,7 +902,7 @@ mod tests {
})
.await?;

assert_eq!(res, 5);
assert_eq!(res.result, 5);

Ok(())
}
Expand Down Expand Up @@ -844,7 +943,7 @@ mod tests {
})
.await?;

assert_eq!(res, 5);
assert_eq!(res.result, 5);

Ok(())
}
Expand Down Expand Up @@ -883,7 +982,7 @@ mod tests {
Ok(count)
})
.await?;
assert_eq!(res, 1);
assert_eq!(res.result, 1);
Ok(())
}
}
Loading
Loading