diff --git a/src/spanner/src/batch_read_only_transaction.rs b/src/spanner/src/batch_read_only_transaction.rs index 12713f9507..55e2c5911b 100644 --- a/src/spanner/src/batch_read_only_transaction.rs +++ b/src/spanner/src/batch_read_only_transaction.rs @@ -18,7 +18,7 @@ use crate::precommit::PrecommitTokenTracker; use crate::read_only_transaction::{ MultiUseReadOnlyTransaction, MultiUseReadOnlyTransactionBuilder, ReadContextTransactionSelector, }; -use crate::result_set::{ResultSet, StreamOperation}; +use crate::result_set::{ResultSet, ResultSetParams, StreamOperation}; use crate::statement::Statement; use crate::timestamp_bound::TimestampBound; use google_cloud_gax::backoff_policy::BackoffPolicyArg; @@ -162,7 +162,11 @@ impl BatchReadOnlyTransaction { .context .client .spanner - .partition_query(request, crate::RequestOptions::default()) + .partition_query( + request, + crate::RequestOptions::default(), + self.inner.context.channel_hint, + ) .await?; Ok(response @@ -220,7 +224,11 @@ impl BatchReadOnlyTransaction { .context .client .spanner - .partition_read(request, crate::RequestOptions::default()) + .partition_read( + request, + crate::RequestOptions::default(), + self.inner.context.channel_hint, + ) .await?; Ok(response @@ -366,24 +374,26 @@ impl Partition { req: &crate::model::ExecuteSqlRequest, gax_options: GaxRequestOptions, ) -> crate::Result { + let channel_hint = client.spanner.next_channel_hint(); let stream = client .spanner - .execute_streaming_sql(req.clone(), gax_options.clone()) + .execute_streaming_sql(req.clone(), gax_options.clone(), channel_hint) .send() .await?; - Ok(ResultSet::new( + Ok(ResultSet::new(ResultSetParams { stream, - Some(ReadContextTransactionSelector::Fixed( + transaction_selector: Some(ReadContextTransactionSelector::Fixed( req.transaction.clone().unwrap_or_default(), None, )), - PrecommitTokenTracker::new_noop(), - client.clone(), - req.session.clone(), - StreamOperation::Query(req.clone()), + precommit_token_tracker: 
PrecommitTokenTracker::new_noop(), + client: client.clone(), + session_name: req.session.clone(), + operation: StreamOperation::Query(req.clone()), + channel_hint, gax_options, - )) + })) } async fn execute_read( @@ -391,24 +401,26 @@ impl Partition { req: &crate::model::ReadRequest, gax_options: GaxRequestOptions, ) -> crate::Result { + let channel_hint = client.spanner.next_channel_hint(); let stream = client .spanner - .streaming_read(req.clone(), gax_options.clone()) + .streaming_read(req.clone(), gax_options.clone(), channel_hint) .send() .await?; - Ok(ResultSet::new( + Ok(ResultSet::new(ResultSetParams { stream, - Some(ReadContextTransactionSelector::Fixed( + transaction_selector: Some(ReadContextTransactionSelector::Fixed( req.transaction.clone().unwrap_or_default(), None, )), - PrecommitTokenTracker::new_noop(), - client.clone(), - req.session.clone(), - StreamOperation::Read(req.clone()), + precommit_token_tracker: PrecommitTokenTracker::new_noop(), + client: client.clone(), + session_name: req.session.clone(), + operation: StreamOperation::Read(req.clone()), + channel_hint, gax_options, - )) + })) } } diff --git a/src/spanner/src/client.rs b/src/spanner/src/client.rs index ac65953e7e..da22272d77 100644 --- a/src/spanner/src/client.rs +++ b/src/spanner/src/client.rs @@ -20,6 +20,7 @@ use crate::model::{ }; use crate::server_streaming::builder; use gaxi::options::{ClientConfig, Credentials}; +use std::sync::atomic::{AtomicUsize, Ordering}; pub use crate::database_client::DatabaseClient; pub use crate::error::SpannerInternalError; @@ -55,8 +56,8 @@ pub use wkt::{DurationError, TimestampError}; /// [Spanner]: https://docs.cloud.google.com/spanner/docs #[derive(Clone, Debug)] pub struct Spanner { - inner: GapicSpanner, - grpc_client: Option, + pub(crate) channels: Vec, + pub(crate) counter: std::sync::Arc, } pub struct Factory; @@ -66,20 +67,19 @@ impl google_cloud_gax::client_builder::internal::ClientFactory for Factory { type Credentials = Credentials; async 
fn build(self, config: ClientConfig) -> crate::ClientBuilderResult { - let transport = - crate::generated::gapic_dataplane::transport::Spanner::new(config.clone()).await?; - let grpc_client = transport.inner.clone(); + let num_channels = std::env::var("SPANNER_NUM_CHANNELS") + .ok() + .and_then(|s| s.parse::<usize>().ok().filter(|&n| n > 0)) + .unwrap_or(4); + + let mut channels = Vec::with_capacity(num_channels); + for _ in 0..num_channels { + channels.push(Channel::create(&config).await?); + } - let inner = if gaxi::options::tracing_enabled(&config) { - GapicSpanner::from_stub(crate::generated::gapic_dataplane::tracing::Spanner::new( - transport, - )) - } else { - GapicSpanner::from_stub(transport) - }; Ok(Spanner { - inner, - grpc_client: Some(grpc_client), + channels, + counter: std::sync::Arc::new(AtomicUsize::new(0)), }) } } @@ -100,8 +100,10 @@ macro_rules! define_idempotent_rpc { &self, request: $request_type, options: crate::RequestOptions, + channel_hint: usize, ) -> crate::Result<$response_type> { - self.inner + self.get_channel(channel_hint) + .inner .$method() .with_request(request) .with_options(with_default_idempotency(options)) @@ -175,11 +177,23 @@ impl Spanner { // This method is primarily for testing and doesn't fully initialize grpc_client. // For production use, prefer `Spanner::builder().build()`.
Self { - inner: GapicSpanner::from_stub(stub), - grpc_client: None, + channels: vec![Channel { + inner: GapicSpanner::from_stub(stub), + grpc_client: None, + }], + counter: std::sync::Arc::new(AtomicUsize::new(0)), } } + pub(crate) fn get_channel(&self, hint: usize) -> &Channel { + let idx = hint % self.channels.len(); + &self.channels[idx] + } + + pub(crate) fn next_channel_hint(&self) -> usize { + self.counter.fetch_add(1, Ordering::Relaxed) + } + define_idempotent_rpc!(create_session, CreateSessionRequest, Session); define_idempotent_rpc!(execute_sql, ExecuteSqlRequest, crate::model::ResultSet); define_idempotent_rpc!( @@ -202,8 +216,10 @@ impl Spanner { &self, request: crate::model::ExecuteSqlRequest, options: crate::RequestOptions, + channel_hint: usize, ) -> builder::ExecuteStreamingSql { - let grpc = self + let channel = self.get_channel(channel_hint); + let grpc = channel .grpc_client .as_ref() .expect("Streaming RPCs are not supported when using a stub client"); @@ -220,8 +236,10 @@ impl Spanner { &self, request: crate::model::ReadRequest, options: crate::RequestOptions, + channel_hint: usize, ) -> builder::StreamingRead { - let grpc = self + let channel = self.get_channel(channel_hint); + let grpc = channel .grpc_client .as_ref() .expect("Streaming RPCs are not supported when using a stub client"); @@ -234,8 +252,10 @@ impl Spanner { &self, request: crate::model::BatchWriteRequest, options: crate::RequestOptions, + channel_hint: usize, ) -> builder::BatchWrite { - let grpc = self + let channel = self.get_channel(channel_hint); + let grpc = channel .grpc_client .as_ref() .expect("Streaming RPCs are not supported when using a stub client"); @@ -245,6 +265,32 @@ impl Spanner { } } +#[derive(Clone, Debug)] +pub(crate) struct Channel { + pub(crate) inner: GapicSpanner, + pub(crate) grpc_client: Option, +} + +impl Channel { + pub(crate) async fn create(config: &ClientConfig) -> crate::ClientBuilderResult { + let transport = + 
crate::generated::gapic_dataplane::transport::Spanner::new(config.clone()).await?; + let grpc_client = transport.inner.clone(); + + let inner = if gaxi::options::tracing_enabled(config) { + GapicSpanner::from_stub(crate::generated::gapic_dataplane::tracing::Spanner::new( + transport, + )) + } else { + GapicSpanner::from_stub(transport) + }; + Ok(Self { + inner, + grpc_client: Some(grpc_client), + }) + } +} + #[cfg(test)] mod tests { use super::*; @@ -277,6 +323,50 @@ mod tests { assert_not_impl_any!(Spanner: std::panic::RefUnwindSafe, std::panic::UnwindSafe); } + #[tokio::test] + async fn channel_pool_default_size() { + let mock = MockSpanner::new(); + let (address, _server) = start("0.0.0.0:0", mock) + .await + .expect("Failed to start mock server"); + + let client = Spanner::builder() + .with_endpoint(address) + .with_credentials(Anonymous::new().build()) + .build() + .await + .expect("Failed to build client"); + + assert_eq!(client.channels.len(), 4); + } + + #[tokio::test] + async fn channel_selection() { + let mock = MockSpanner::new(); + let (address, _server) = start("0.0.0.0:0", mock) + .await + .expect("Failed to start mock server"); + + let client = Spanner::builder() + .with_endpoint(address) + .with_credentials(Anonymous::new().build()) + .build() + .await + .expect("Failed to build client"); + + let hint0 = client.next_channel_hint(); + let hint1 = client.next_channel_hint(); + let hint2 = client.next_channel_hint(); + let hint3 = client.next_channel_hint(); + let hint4 = client.next_channel_hint(); + + assert_eq!(hint0 % 4, 0); + assert_eq!(hint1 % 4, 1); + assert_eq!(hint2 % 4, 2); + assert_eq!(hint3 % 4, 3); + assert_eq!(hint4 % 4, 0); + } + #[tokio::test] async fn test_create_session() { // 1. 
Setup Mock Server @@ -309,7 +399,11 @@ mod tests { "projects/test-project/instances/test-instance/databases/test-db".to_string(); let session = client - .create_session(req, crate::RequestOptions::default()) + .create_session( + req, + crate::RequestOptions::default(), + client.next_channel_hint(), + ) .await .expect("Failed to call create_session"); @@ -363,6 +457,7 @@ mod tests { "projects/test-project/instances/test-instance/databases/test-db".to_string(); let session = client + .get_channel(client.next_channel_hint()) .inner .create_session() .with_request(req) @@ -412,7 +507,11 @@ mod tests { req.sql = "SELECT 1".to_string(); let result_set = client - .execute_sql(req, crate::RequestOptions::default()) + .execute_sql( + req, + crate::RequestOptions::default(), + client.next_channel_hint(), + ) .await .expect("Failed to call execute_sql"); assert!(result_set.metadata.is_some()); @@ -451,7 +550,11 @@ mod tests { req.session = "test_session".to_string(); let response = client - .execute_batch_dml(req, crate::RequestOptions::default()) + .execute_batch_dml( + req, + crate::RequestOptions::default(), + client.next_channel_hint(), + ) .await .expect("Failed to call execute_batch_dml"); assert!(response.status.is_some()); @@ -486,7 +589,11 @@ mod tests { req.table = "test_table".to_string(); let result_set = client - .read(req, crate::RequestOptions::default()) + .read( + req, + crate::RequestOptions::default(), + client.next_channel_hint(), + ) .await .expect("Failed to call read"); assert!(result_set.metadata.is_none()); @@ -520,7 +627,11 @@ mod tests { req.session = "test_session".to_string(); let tx = client - .begin_transaction(req, crate::RequestOptions::default()) + .begin_transaction( + req, + crate::RequestOptions::default(), + client.next_channel_hint(), + ) .await .expect("Failed to call begin_transaction"); assert_eq!(tx.id, vec![1, 2, 3]); @@ -558,7 +669,11 @@ mod tests { req.session = "test_session".to_string(); let response = client - .commit(req, 
crate::RequestOptions::default()) + .commit( + req, + crate::RequestOptions::default(), + client.next_channel_hint(), + ) .await .expect("Failed to call commit"); assert!(response.commit_timestamp.is_some()); @@ -587,7 +702,11 @@ mod tests { req.session = "test_session".to_string(); client - .rollback(req, crate::RequestOptions::default()) + .rollback( + req, + crate::RequestOptions::default(), + client.next_channel_hint(), + ) .await .expect("Failed to call rollback"); } @@ -629,7 +748,11 @@ mod tests { req.sql = "SELECT 1".to_string(); let mut stream = client - .execute_streaming_sql(req, crate::RequestOptions::default()) + .execute_streaming_sql( + req, + crate::RequestOptions::default(), + client.next_channel_hint(), + ) .send() .await .expect("Failed to call execute_streaming_sql"); @@ -677,7 +800,11 @@ mod tests { req.columns = vec!["col1".to_string()]; let mut stream = client - .streaming_read(req, crate::RequestOptions::default()) + .streaming_read( + req, + crate::RequestOptions::default(), + client.next_channel_hint(), + ) .send() .await .expect("Failed to call streaming_read"); @@ -715,7 +842,11 @@ mod tests { req.session = "test_session".to_string(); let mut stream = client - .batch_write(req, crate::RequestOptions::default()) + .batch_write( + req, + crate::RequestOptions::default(), + client.next_channel_hint(), + ) .send() .await .expect("Failed to call batch_write"); @@ -751,7 +882,11 @@ mod tests { req.sql = "SELECT 1".to_string(); let mut stream = client - .execute_streaming_sql(req, crate::RequestOptions::default()) + .execute_streaming_sql( + req, + crate::RequestOptions::default(), + client.next_channel_hint(), + ) .send() .await .expect("Failed to call execute_streaming_sql"); @@ -799,7 +934,11 @@ mod tests { "projects/test-project/instances/test-instance/databases/test-db".to_string(); let session = client - .create_session(req, crate::RequestOptions::default()) + .create_session( + req, + crate::RequestOptions::default(), + 
client.next_channel_hint(), + ) .await .expect("Failed to call create_session"); @@ -840,7 +979,9 @@ mod tests { let mut options = crate::RequestOptions::default(); options.set_idempotency(false); - let result = client.create_session(req, options).await; + let result = client + .create_session(req, options, client.next_channel_hint()) + .await; // 5. Verify that it failed and did not retry assert!(result.is_err(), "Expected error, got {:?}", result); diff --git a/src/spanner/src/partitioned_dml_transaction.rs b/src/spanner/src/partitioned_dml_transaction.rs index c28a271b17..42b7d083d7 100644 --- a/src/spanner/src/partitioned_dml_transaction.rs +++ b/src/spanner/src/partitioned_dml_transaction.rs @@ -173,13 +173,18 @@ impl PartitionedDmlTransaction { ..Default::default() }; let base_request = statement.into_request(); + let channel_hint = self.client.spanner.next_channel_hint(); // Execute the statement and retry if the transaction is aborted by Spanner. retry_aborted(&*self.retry_policy, || async { let transaction = self .client .spanner - .begin_transaction(begin_request.clone(), crate::RequestOptions::default()) + .begin_transaction( + begin_request.clone(), + crate::RequestOptions::default(), + channel_hint, + ) .await?; let execute_request = base_request @@ -190,10 +195,11 @@ impl PartitionedDmlTransaction { ..Default::default() }); - let stream_builder = self - .client - .spanner - .execute_streaming_sql(execute_request.clone(), crate::RequestOptions::default()); + let stream_builder = self.client.spanner.execute_streaming_sql( + execute_request.clone(), + crate::RequestOptions::default(), + channel_hint, + ); let stream = stream_builder.send().await?; extract_lower_bound_update_count_from_stream(stream).await diff --git a/src/spanner/src/read_only_transaction.rs b/src/spanner/src/read_only_transaction.rs index 04e877e90f..c7fa5bae1d 100644 --- a/src/spanner/src/read_only_transaction.rs +++ b/src/spanner/src/read_only_transaction.rs @@ -16,7 +16,7 @@ use 
crate::database_client::DatabaseClient; use crate::model::TransactionOptions; use crate::model::transaction_options::ReadOnly; use crate::precommit::PrecommitTokenTracker; -use crate::result_set::{ResultSet, StreamOperation}; +use crate::result_set::{ResultSet, ResultSetParams, StreamOperation}; use crate::statement::Statement; use crate::timestamp_bound::TimestampBound; use std::sync::{Arc, Mutex}; @@ -90,6 +90,7 @@ impl SingleUseReadOnlyTransactionBuilder { .set_single_use(TransactionOptions::default().set_read_only(read_only)); let session_name = self.client.session_name(); + let channel_hint = self.client.spanner.next_channel_hint(); SingleUseReadOnlyTransaction { context: ReadContext { session_name, @@ -100,6 +101,7 @@ impl SingleUseReadOnlyTransactionBuilder { ), precommit_token_tracker: PrecommitTokenTracker::new_noop(), transaction_tag: None, + channel_hint, }, } } @@ -277,8 +279,10 @@ impl MultiUseReadOnlyTransactionBuilder { &self, session_name: String, options: TransactionOptions, + channel_hint: usize, ) -> crate::Result { - let response = execute_begin_transaction(&self.client, session_name, options).await?; + let response = + execute_begin_transaction(&self.client, session_name, options, channel_hint).await?; let transaction_selector = crate::model::TransactionSelector::default().set_id(response.id); @@ -309,8 +313,10 @@ impl MultiUseReadOnlyTransactionBuilder { let options = TransactionOptions::default().set_read_only(read_only); let session_name = self.client.session_name(); + let channel_hint = self.client.spanner.next_channel_hint(); let selector = if self.explicit_begin { - self.begin(session_name.clone(), options).await? + self.begin(session_name.clone(), options, channel_hint) + .await? 
} else { ReadContextTransactionSelector::Lazy(Arc::new(Mutex::new( TransactionState::NotStarted(options), @@ -324,6 +330,7 @@ impl MultiUseReadOnlyTransactionBuilder { transaction_selector: selector, precommit_token_tracker: PrecommitTokenTracker::new_noop(), transaction_tag: None, + channel_hint, }, }) } @@ -430,6 +437,7 @@ async fn execute_begin_transaction( client: &crate::database_client::DatabaseClient, session_name: String, options: crate::model::TransactionOptions, + channel_hint: usize, ) -> crate::Result { let request = crate::model::BeginTransactionRequest::default() .set_session(session_name) @@ -438,7 +446,7 @@ async fn execute_begin_transaction( // TODO(#4972): make request options configurable client .spanner - .begin_transaction(request, crate::RequestOptions::default()) + .begin_transaction(request, crate::RequestOptions::default(), channel_hint) .await } @@ -484,6 +492,7 @@ impl ReadContextTransactionSelector { &self, client: &crate::database_client::DatabaseClient, session_name: String, + channel_hint: usize, ) -> crate::Result<()> { let Self::Lazy(lazy) = self else { return Ok(()); @@ -497,7 +506,8 @@ impl ReadContextTransactionSelector { options.clone() }; - let response = execute_begin_transaction(client, session_name, options).await?; + let response = + execute_begin_transaction(client, session_name, options, channel_hint).await?; self.update(response.id, response.read_timestamp); Ok(()) @@ -537,6 +547,7 @@ pub(crate) struct ReadContext { pub(crate) transaction_selector: ReadContextTransactionSelector, pub(crate) precommit_token_tracker: PrecommitTokenTracker, pub(crate) transaction_tag: Option, + pub(crate) channel_hint: usize, } impl ReadContext { @@ -571,7 +582,7 @@ impl ReadContext { } self.transaction_selector - .begin_explicitly(&self.client, self.session_name.clone()) + .begin_explicitly(&self.client, self.session_name.clone(), self.channel_hint) .await?; Ok(true) } @@ -583,7 +594,7 @@ macro_rules! 
execute_stream_with_retry { let stream = match $self .client .spanner - .$rpc_method($request.clone(), $gax_options.clone()) + .$rpc_method($request.clone(), $gax_options.clone(), $self.channel_hint) .send() .await { @@ -594,7 +605,7 @@ macro_rules! execute_stream_with_retry { $self .client .spanner - .$rpc_method($request.clone(), $gax_options.clone()) + .$rpc_method($request.clone(), $gax_options.clone(), $self.channel_hint) .send() .await? } else { @@ -603,15 +614,16 @@ macro_rules! execute_stream_with_retry { } }; - Ok(ResultSet::new( + Ok(ResultSet::new(ResultSetParams { stream, - Some($self.transaction_selector.clone()), - $self.precommit_token_tracker.clone(), - $self.client.clone(), - $self.session_name.clone(), - $operation_variant($request), - $gax_options, - )) + transaction_selector: Some($self.transaction_selector.clone()), + precommit_token_tracker: $self.precommit_token_tracker.clone(), + client: $self.client.clone(), + session_name: $self.session_name.clone(), + operation: $operation_variant($request), + channel_hint: $self.channel_hint, + gax_options: $gax_options, + })) }}; } diff --git a/src/spanner/src/read_write_transaction.rs b/src/spanner/src/read_write_transaction.rs index 9c021a7b80..0fbb3d89ea 100644 --- a/src/spanner/src/read_write_transaction.rs +++ b/src/spanner/src/read_write_transaction.rs @@ -119,10 +119,11 @@ impl ReadWriteTransactionBuilder { } // TODO(#4972): make request options configurable + let channel_hint = self.client.spanner.next_channel_hint(); let response = self .client .spanner - .begin_transaction(request, RequestOptions::default()) + .begin_transaction(request, RequestOptions::default(), channel_hint) .await?; let transaction_selector = @@ -137,6 +138,7 @@ impl ReadWriteTransactionBuilder { transaction_selector, precommit_token_tracker: PrecommitTokenTracker::new(), transaction_tag: self.transaction_tag.clone(), + channel_hint, }, seqno: Arc::new(AtomicI64::new(1)), max_commit_delay: self.max_commit_delay, @@ -187,7 
+189,7 @@ impl ReadWriteTransaction { .context .client .spanner - .execute_sql(request, gax_options) + .execute_sql(request, gax_options, self.context.channel_hint) .await?; self.context .precommit_token_tracker @@ -291,7 +293,7 @@ impl ReadWriteTransaction { .context .client .spanner - .execute_batch_dml(request, batch.gax_options) + .execute_batch_dml(request, batch.gax_options, self.context.channel_hint) .await; match response_result { @@ -337,7 +339,11 @@ impl ReadWriteTransaction { .context .client .spanner - .commit(request, RequestOptions::default()) + .commit( + request, + RequestOptions::default(), + self.context.channel_hint, + ) .await?; let response = @@ -351,7 +357,11 @@ impl ReadWriteTransaction { self.context .client .spanner - .commit(retry_commit_req, RequestOptions::default()) + .commit( + retry_commit_req, + RequestOptions::default(), + self.context.channel_hint, + ) .await? } else { response @@ -374,7 +384,11 @@ impl ReadWriteTransaction { self.context .client .spanner - .rollback(request, RequestOptions::default()) + .rollback( + request, + RequestOptions::default(), + self.context.channel_hint, + ) .await?; Ok(()) @@ -389,6 +403,7 @@ mod tests { use gaxi::grpc::tonic; use spanner_grpc_mock::google::spanner::v1; use std::fmt::Debug; + use std::sync::Mutex; #[test] fn auto_traits() { @@ -397,23 +412,36 @@ mod tests { } #[tokio::test] - async fn read_write_transaction_commit_retry() { + async fn read_write_transaction_commit_retry() -> anyhow::Result<()> { let mut mock = create_session_mock(); - - mock.expect_begin_transaction().once().returning(|req| { - let req = req.into_inner(); - assert_eq!( - req.session, - "projects/p/instances/i/databases/d/sessions/123" - ); - Ok(tonic::Response::new(v1::Transaction { - id: vec![0, 0, 7], - ..Default::default() - })) - }); + let remotes = Arc::new(Mutex::new(Vec::new())); + + let remotes_clone = remotes.clone(); + mock.expect_begin_transaction() + .once() + .returning(move |req| { + remotes_clone + 
.lock() + .unwrap() + .push(req.remote_addr().expect("remote_addr should be available")); + let req = req.into_inner(); + assert_eq!( + req.session, + "projects/p/instances/i/databases/d/sessions/123" + ); + Ok(tonic::Response::new(v1::Transaction { + id: vec![0, 0, 7], + ..Default::default() + })) + }); // execute_update returns a precommit token. - mock.expect_execute_sql().once().returning(|req| { + let remotes_clone = remotes.clone(); + mock.expect_execute_sql().once().returning(move |req| { + remotes_clone + .lock() + .unwrap() + .push(req.remote_addr().expect("remote_addr should be available")); let req = req.into_inner(); assert_eq!(req.sql, "UPDATE Users SET Name = 'Bob' WHERE Id = 1"); Ok(tonic::Response::new(v1::ResultSet { @@ -432,7 +460,12 @@ mod tests { // Simulate that commit returns a precommit token in the response. // This would normally not happen, but we test it here to verify // that the commit is retried. - mock.expect_commit().once().returning(|req| { + let remotes_clone = remotes.clone(); + mock.expect_commit().once().returning(move |req| { + remotes_clone + .lock() + .unwrap() + .push(req.remote_addr().expect("remote_addr should be available")); let req = req.into_inner(); assert_eq!( req.precommit_token, @@ -459,7 +492,12 @@ mod tests { }); // Second commit retry is automatically issued with the new token - mock.expect_commit().once().returning(|req| { + let remotes_clone = remotes.clone(); + mock.expect_commit().once().returning(move |req| { + remotes_clone + .lock() + .unwrap() + .push(req.remote_addr().expect("remote_addr should be available")); let req = req.into_inner(); assert_eq!( req.precommit_token, @@ -481,17 +519,25 @@ mod tests { let tx = ReadWriteTransactionBuilder::new(db_client.clone()) .begin_transaction() - .await - .expect("Failed to build transaction"); + .await?; let count = tx .execute_update("UPDATE Users SET Name = 'Bob' WHERE Id = 1") - .await - .unwrap(); + .await?; assert_eq!(count, 1); - let timestamp = 
tx.commit().await.unwrap(); + let timestamp = tx.commit().await?; assert_eq!(timestamp.seconds(), 1001); + + // Verify that all RPCs used the same channel (same remote address) + let remotes = remotes.lock().unwrap(); + assert_eq!(remotes.len(), 4, "Expected exactly 4 RPCs"); + let first = remotes[0]; + for addr in remotes.iter() { + assert_eq!(*addr, first, "All RPCs should use the same gRPC channel"); + } + + Ok(()) } #[tokio::test] diff --git a/src/spanner/src/result_set.rs b/src/spanner/src/result_set.rs index e202908e2d..d6f655a1ed 100644 --- a/src/spanner/src/result_set.rs +++ b/src/spanner/src/result_set.rs @@ -71,6 +71,7 @@ pub struct ResultSet { max_buffered_partial_result_sets: usize, retry_count: usize, transaction_selector: Option, + channel_hint: usize, gax_options: GaxRequestOptions, } @@ -89,6 +90,17 @@ pub(crate) enum StreamOperation { Read(crate::model::ReadRequest), } +pub(crate) struct ResultSetParams { + pub stream: PartialResultSetStream, + pub transaction_selector: Option, + pub precommit_token_tracker: PrecommitTokenTracker, + pub client: DatabaseClient, + pub session_name: String, + pub operation: StreamOperation, + pub channel_hint: usize, + pub gax_options: GaxRequestOptions, +} + // The maximum number of PartialResultSets to buffer without a resume token. // Spanner will normally include a resume token with each PartialResultSet. // This maximum is therefore primarily for safety. @@ -96,33 +108,26 @@ const MAX_BUFFERED_PARTIAL_RESULT_SETS: usize = 10; impl ResultSet { /// Creates a new result set. 
- pub(crate) fn new( - stream: PartialResultSetStream, - transaction_selector: Option, - precommit_token_tracker: PrecommitTokenTracker, - client: DatabaseClient, - session_name: String, - operation: StreamOperation, - gax_options: GaxRequestOptions, - ) -> Self { - let gax_options = Self::apply_defaults(gax_options); + pub(crate) fn new(params: ResultSetParams) -> Self { + let gax_options = Self::apply_defaults(params.gax_options); Self { - stream: Some(stream), + stream: Some(params.stream), buffered_values: Vec::new(), chunked: false, seen_last: false, ready_rows: VecDeque::new(), metadata: None, - precommit_token_tracker, - client, - session_name, - operation, + precommit_token_tracker: params.precommit_token_tracker, + client: params.client, + session_name: params.session_name, + operation: params.operation, last_resume_token: Bytes::new(), partial_result_sets_buffer: VecDeque::new(), safe_to_retry: true, max_buffered_partial_result_sets: MAX_BUFFERED_PARTIAL_RESULT_SETS, retry_count: 0, - transaction_selector, + transaction_selector: params.transaction_selector, + channel_hint: params.channel_hint, stats: None, gax_options, } @@ -345,7 +350,7 @@ impl ResultSet { self.transaction_selector .as_ref() .unwrap() - .begin_explicitly(&self.client, self.session_name.clone()) + .begin_explicitly(&self.client, self.session_name.clone(), self.channel_hint) .await?; self.partial_result_sets_buffer.clear(); @@ -495,7 +500,7 @@ impl ResultSet { let stream = self .client .spanner - .execute_streaming_sql(req.clone(), self.gax_options.clone()) + .execute_streaming_sql(req.clone(), self.gax_options.clone(), self.channel_hint) .send() .await?; self.stream = Some(stream); @@ -508,7 +513,7 @@ impl ResultSet { let stream = self .client .spanner - .streaming_read(req.clone(), self.gax_options.clone()) + .streaming_read(req.clone(), self.gax_options.clone(), self.channel_hint) .send() .await?; self.stream = Some(stream); diff --git a/src/spanner/src/session_maintainer.rs 
b/src/spanner/src/session_maintainer.rs index ff32116662..d03ca001c2 100644 --- a/src/spanner/src/session_maintainer.rs +++ b/src/spanner/src/session_maintainer.rs @@ -134,7 +134,9 @@ impl ManagedSessionMaintainer { .set_creator_role(database_role), ); - spanner.create_session(request, options.clone()).await + spanner + .create_session(request, options.clone(), spanner.next_channel_hint()) + .await } async fn maintenance_loop( diff --git a/src/spanner/src/write_only_transaction.rs b/src/spanner/src/write_only_transaction.rs index df381770e0..33b6577f68 100644 --- a/src/spanner/src/write_only_transaction.rs +++ b/src/spanner/src/write_only_transaction.rs @@ -246,6 +246,7 @@ impl WriteOnlyTransaction { let client = self.client; let session_name = self.session_name.clone(); let previous_transaction_id = Arc::new(Mutex::new(Bytes::new())); + let channel_hint = client.spanner.next_channel_hint(); retry_aborted(&*self.retry_policy, || { let client = client.clone(); @@ -275,7 +276,7 @@ impl WriteOnlyTransaction { let tx = client .spanner - .begin_transaction(begin_req, crate::RequestOptions::default()) + .begin_transaction(begin_req, crate::RequestOptions::default(), channel_hint) .await?; *previous_transaction_id.lock().unwrap() = tx.id.clone(); @@ -289,7 +290,7 @@ impl WriteOnlyTransaction { let response = client .spanner - .commit(commit_req, crate::RequestOptions::default()) + .commit(commit_req, crate::RequestOptions::default(), channel_hint) .await?; // If a commit_response with a precommit_token is returned, then we need to @@ -302,7 +303,11 @@ impl WriteOnlyTransaction { .set_precommit_token(new_token); client .spanner - .commit(retry_commit_req, crate::RequestOptions::default()) + .commit( + retry_commit_req, + crate::RequestOptions::default(), + channel_hint, + ) .await } else { Ok(response) @@ -361,6 +366,7 @@ impl WriteOnlyTransaction { .set_request_options(req_options) .set_or_clear_max_commit_delay(self.max_commit_delay); let client = self.client; + let 
channel_hint = client.spanner.next_channel_hint(); retry_aborted(&*self.retry_policy, || { let client = client.clone(); @@ -369,7 +375,7 @@ impl WriteOnlyTransaction { async move { client .spanner - .commit(request, crate::RequestOptions::default()) + .commit(request, crate::RequestOptions::default(), channel_hint) .await } })