Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
276 changes: 276 additions & 0 deletions src/spanner/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,7 @@ mod tests {
use super::*;
use crate::model::CreateSessionRequest;
use crate::result_set::tests::adapt;
use gaxi::grpc::tonic::MetadataMap;
use gaxi::grpc::tonic::{Code as GrpcCode, Response, Status};
use google_cloud_auth::credentials::anonymous::Builder as Anonymous;
use google_cloud_gax::backoff_policy::BackoffPolicy;
Expand All @@ -258,9 +259,15 @@ mod tests {
use google_cloud_test_macros::tokio_test_no_panics;
use spanner_grpc_mock::google::rpc as mock_rpc;
use spanner_grpc_mock::google::spanner::v1 as mock_v1;
use spanner_grpc_mock::google::spanner::v1::CommitResponse;
use spanner_grpc_mock::google::spanner::v1::ResultSet;
use spanner_grpc_mock::google::spanner::v1::ResultSetStats;
use spanner_grpc_mock::google::spanner::v1::Session;
use spanner_grpc_mock::google::spanner::v1::result_set_stats::RowCount;
use spanner_grpc_mock::{MockSpanner, start};
use static_assertions::{assert_impl_all, assert_not_impl_any};
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Duration;

mockall::mock! {
Expand Down Expand Up @@ -1091,6 +1098,275 @@ mod tests {
Ok(())
}

fn parse_timeout(metadata: &MetadataMap) -> u64 {
let timeout = metadata
.get("grpc-timeout")
.expect("grpc-timeout header should be present");
let timeout_str = timeout
.to_str()
.expect("grpc-timeout should be a valid string");
if timeout_str.ends_with('u') {
timeout_str
.trim_end_matches('u')
.parse()
.expect("valid u64")
} else if timeout_str.ends_with('m') {
timeout_str
.trim_end_matches('m')
.parse::<u64>()
.expect("valid u64")
* 1000
} else if timeout_str.ends_with('n') {
timeout_str
.trim_end_matches('n')
.parse::<u64>()
.expect("valid u64")
/ 1000
} else {
panic!("Unknown timeout unit in {}", timeout_str);
}
}

#[tokio_test_no_panics]
async fn transaction_timeout_respected() -> anyhow::Result<()> {
use google_cloud_gax::retry_policy::{Aip194Strict, RetryPolicyExt};
use spanner_grpc_mock::google::spanner::v1::Transaction;

// 1. Setup Mock Server
let mut mock = MockSpanner::new();

mock.expect_create_session().returning(|_| {
Ok(Response::new(Session {
name: "projects/p/instances/i/databases/d/sessions/123".to_string(),
..Default::default()
}))
});

mock.expect_begin_transaction().returning(|_| {
Ok(Response::new(Transaction {
id: vec![1, 2, 3],
..Default::default()
}))
});

mock.expect_commit().once().returning(|_| {
Ok(Response::new(CommitResponse {
commit_timestamp: Some(prost_types::Timestamp {
seconds: 12345,
nanos: 0,
}),
..Default::default()
}))
});

// Mock execute_sql to first fail and then succeed, checking timeout header on both
let mut seq = mockall::Sequence::new();

mock.expect_execute_sql()
.once()
.in_sequence(&mut seq)
.returning(|req| {
let timeout_val = parse_timeout(req.metadata());
assert!(
timeout_val <= 100000,
"Expected timeout to be <= 100ms, got {}",
timeout_val
);
Err(Status::new(GrpcCode::ResourceExhausted, "quota exceeded"))
});

mock.expect_execute_sql()
.once()
.in_sequence(&mut seq)
.returning(|req| {
let timeout_val = parse_timeout(req.metadata());
assert!(
timeout_val <= 100000,
"Expected timeout to be <= 100ms, got {}",
timeout_val
);

let res = ResultSet {
stats: Some(ResultSetStats {
row_count: Some(RowCount::RowCountExact(1)),
..Default::default()
}),
..Default::default()
};
Ok(Response::new(res))
});

// 2. Initialize Client
let (address, _server) = start("127.0.0.1:0", mock).await?;
let client = Spanner::builder()
.with_endpoint(address)
.with_credentials(Anonymous::new().build())
.build()
.await?;
let db = client
.database_client("projects/p/instances/i/databases/d")
.build()
.await?;

// 3. Setup Transaction Runner with 100ms timeout
let runner = db
.read_write_transaction()
.with_transaction_timeout(Duration::from_millis(100))
.build()
.await?;

// 4. Run transaction and expect success after retry
let result = runner
.run(async |tx| {
let mut mock_backoff = MockBackoffPolicy::new();
mock_backoff
.expect_on_failure()
.times(1)
.returning(|_| Duration::from_nanos(1));

let retry_policy = Aip194Strict.continue_on_too_many_requests();

let stmt = Statement::builder("SELECT 1")
.with_retry_policy(retry_policy)
.with_backoff_policy(mock_backoff)
.build();
tx.execute_update(stmt).await?;
Ok(())
})
.await;

result.expect("Transaction should have succeeded");

Ok(())
}

#[tokio::test]
async fn transaction_timeout_ticks_down() -> anyhow::Result<()> {
use spanner_grpc_mock::google::spanner::v1::Transaction;

let mut mock = MockSpanner::new();

mock.expect_create_session().returning(|_| {
Ok(Response::new(Session {
name: "projects/p/instances/i/databases/d/sessions/123".to_string(),
..Default::default()
}))
});

let mut seq = mockall::Sequence::new();

mock.expect_begin_transaction()
.once()
.in_sequence(&mut seq)
.returning(|_| {
Ok(Response::new(Transaction {
id: vec![1],
..Default::default()
}))
});

let previous_timeout = Arc::new(AtomicU64::new(0));
let prev_clone1 = previous_timeout.clone();
mock.expect_execute_sql()
.once()
.in_sequence(&mut seq)
.returning(move |req| {
let timeout_val = parse_timeout(req.metadata());
assert!(
timeout_val <= 500000,
"Expected timeout to be <= 500ms, got {}",
timeout_val
);
prev_clone1.store(timeout_val, Ordering::SeqCst);
Err(Status::new(GrpcCode::Aborted, "Aborted"))
});

// Second attempt: Checks that timeout is <= previous
mock.expect_begin_transaction()
.once()
.in_sequence(&mut seq)
.returning(|_| {
Ok(Response::new(Transaction {
id: vec![2],
..Default::default()
}))
});

let prev_clone2 = previous_timeout.clone();
mock.expect_execute_sql()
.once()
.in_sequence(&mut seq)
.returning(move |req| {
let timeout_val = parse_timeout(req.metadata());
let prev = prev_clone2.load(Ordering::SeqCst);
assert!(
timeout_val <= prev,
"Timeout should tick down between attempts or be equal, got {} and {}",
timeout_val,
prev
);
prev_clone2.store(timeout_val, Ordering::SeqCst); // store for next check

let res = ResultSet {
stats: Some(ResultSetStats {
row_count: Some(RowCount::RowCountExact(1)),
..Default::default()
}),
..Default::default()
};
Ok(Response::new(res))
});

let prev_clone3 = previous_timeout.clone();
mock.expect_commit().once().returning(move |req| {
let timeout_val = parse_timeout(req.metadata());
let prev = prev_clone3.load(Ordering::SeqCst);
assert!(
timeout_val < prev,
"Timeout should be smaller for commit, got {} and {}",
timeout_val,
prev
);

Ok(Response::new(CommitResponse {
commit_timestamp: Some(prost_types::Timestamp {
seconds: 12345,
nanos: 0,
}),
..Default::default()
}))
});

let (address, _server) = start("127.0.0.1:0", mock).await?;
let client = Spanner::builder()
.with_endpoint(address)
.with_credentials(Anonymous::new().build())
.build()
.await?;
let db = client
.database_client("projects/p/instances/i/databases/d")
.build()
.await?;

let runner = db
.read_write_transaction()
.with_transaction_timeout(Duration::from_millis(500))
.build()
.await?;

let result = runner
.run(async |tx| {
let stmt = Statement::builder("SELECT 1").build();
tx.execute_update(stmt).await?;
Ok(())
})
.await;

result.expect("Transaction should have succeeded");

Ok(())
}

#[test]
fn test_parse_emulator_endpoint() {
assert_eq!(
Expand Down
Loading
Loading