diff --git a/src/common/security.rs b/src/common/security.rs index 89e074b3..7fe41328 100644 --- a/src/common/security.rs +++ b/src/common/security.rs @@ -1,10 +1,15 @@ // Copyright 2018 TiKV Project Authors. Licensed under Apache-2.0. +use std::collections::hash_map::DefaultHasher; +use std::fs; use std::fs::File; +use std::hash::Hash; +use std::hash::Hasher; use std::io::Read; use std::path::Path; use std::path::PathBuf; use std::time::Duration; +use std::time::SystemTime; use log::info; use regex::Regex; @@ -43,12 +48,12 @@ fn load_pem_file(tag: &str, path: &Path) -> Result> { /// Manages the TLS protocol #[derive(Default)] pub struct SecurityManager { - /// The PEM encoding of the server’s CA certificates. - ca: Vec, - /// The PEM encoding of the server’s certificate chain. - cert: Vec, + /// The path to the PEM encoding of the server’s CA certificates. + ca_path: Option, + /// The path to the PEM encoding of the server’s certificate chain. + cert_path: Option, /// The path to the file that contains the PEM encoding of the server’s private key. - key: PathBuf, + key_path: Option, } impl SecurityManager { @@ -58,15 +63,35 @@ impl SecurityManager { cert_path: impl AsRef, key_path: impl Into, ) -> Result { + let ca_path = ca_path.as_ref().to_path_buf(); + let cert_path = cert_path.as_ref().to_path_buf(); let key_path = key_path.into(); + check_pem_file("ca", &ca_path)?; + check_pem_file("certificate", &cert_path)?; check_pem_file("private key", &key_path)?; Ok(SecurityManager { - ca: load_pem_file("ca", ca_path.as_ref())?, - cert: load_pem_file("certificate", cert_path.as_ref())?, - key: key_path, + ca_path: Some(ca_path), + cert_path: Some(cert_path), + key_path: Some(key_path), }) } + pub(crate) fn tls_configured(&self) -> bool { + self.ca_path.is_some() + } + + pub(crate) fn connection_cache_key(&self) -> Result> { + if !self.tls_configured() { + return Ok(None); + } + + let mut hasher = DefaultHasher::new(); + file_signature(self.ca_path.as_ref().expect("tls_configured checked"))?.hash(&mut hasher); + file_signature(self.cert_path.as_ref().expect("tls_configured checked"))?.hash(&mut hasher); + file_signature(self.key_path.as_ref().expect("tls_configured checked"))?.hash(&mut hasher); + Ok(Some(hasher.finish())) + } + /// Connect to gRPC server using TLS connection. If TLS is not configured, use normal connection. pub async fn connect( &self, @@ -78,7 +103,7 @@ impl SecurityManager { Factory: FnOnce(Channel) -> Client, { info!("connect to rpc server at endpoint: {:?}", addr); - let channel = if !self.ca.is_empty() { + let channel = if self.tls_configured() { self.tls_channel(addr).await? } else { self.default_channel(addr).await? @@ -89,18 +114,37 @@ impl SecurityManager { } async fn tls_channel(&self, addr: &str) -> Result { + let (ca, cert, key) = self.load_tls_materials()?; let addr = "https://".to_string() + &SCHEME_REG.replace(addr, ""); let builder = self.endpoint(addr.to_string())?; let tls = ClientTlsConfig::new() - .ca_certificate(Certificate::from_pem(&self.ca)) - .identity(Identity::from_pem( - &self.cert, - load_pem_file("private key", &self.key)?, - )); + .ca_certificate(Certificate::from_pem(ca)) + .identity(Identity::from_pem(cert, key)); let builder = builder.tls_config(tls)?; Ok(builder) } + fn load_tls_materials(&self) -> Result<(Vec, Vec, Vec)> { + let ca_path = self + .ca_path + .as_ref() + .ok_or_else(|| internal_err!("TLS is not configured"))?; + let cert_path = self + .cert_path + .as_ref() + .ok_or_else(|| internal_err!("TLS is not configured"))?; + let key_path = self + .key_path + .as_ref() + .ok_or_else(|| internal_err!("TLS is not configured"))?; + + Ok(( + load_pem_file("ca", ca_path)?, + load_pem_file("certificate", cert_path)?, + load_pem_file("private key", key_path)?, + )) + } + async fn default_channel(&self, addr: &str) -> Result { let addr = "http://".to_string() + &SCHEME_REG.replace(addr, ""); self.endpoint(addr) @@ -114,6 +158,17 @@ impl SecurityManager { } } +fn file_signature(path: &Path) -> Result<(u64, Option)> { + let metadata = fs::metadata(path) + .map_err(|e| internal_err!("failed to stat {}: {:?}", path.display(), e))?; + let modified = metadata.modified().ok().and_then(|t: SystemTime| { + t.duration_since(SystemTime::UNIX_EPOCH) + .ok() + .map(|d| d.as_nanos()) + }); + Ok((metadata.len(), modified)) +} + #[cfg(test)] mod tests { use std::fs::File; @@ -140,9 +195,43 @@ mod tests { let key_path: PathBuf = format!("{}", example_pem.display()).into(); let ca_path: PathBuf = format!("{}", example_ca.display()).into(); let mgr = SecurityManager::load(ca_path, cert_path, &key_path).unwrap(); - assert_eq!(mgr.ca, vec![0]); - assert_eq!(mgr.cert, vec![1]); - let key = load_pem_file("private key", &key_path).unwrap(); + assert!(mgr.tls_configured()); + let (ca, cert, key) = mgr.load_tls_materials().unwrap(); + assert_eq!(ca, vec![0]); + assert_eq!(cert, vec![1]); assert_eq!(key, vec![2]); } + + #[test] + fn test_security_reload() { + let temp = tempfile::tempdir().unwrap(); + let example_ca = temp.path().join("ca"); + let example_cert = temp.path().join("cert"); + let example_pem = temp.path().join("key"); + for (id, f) in [&example_ca, &example_cert, &example_pem] + .iter() + .enumerate() + { + File::create(f).unwrap().write_all(&[id as u8]).unwrap(); + } + + let mgr = SecurityManager::load(&example_ca, &example_cert, &example_pem).unwrap(); + let first = mgr.load_tls_materials().unwrap(); + let key1 = mgr.connection_cache_key().unwrap(); + + File::create(&example_ca).unwrap().write_all(&[9]).unwrap(); + File::create(&example_cert) + .unwrap() + .write_all(&[8]) + .unwrap(); + File::create(&example_pem).unwrap().write_all(&[7]).unwrap(); + + let second = mgr.load_tls_materials().unwrap(); + let key2 = mgr.connection_cache_key().unwrap(); + assert_ne!(first, second); + assert_eq!(second.0, vec![9]); + assert_eq!(second.1, vec![8]); + assert_eq!(second.2, vec![7]); + assert_ne!(key1, key2); + } } diff --git a/src/pd/client.rs b/src/pd/client.rs index 05b9c07c..10d6b3a8 100644 --- a/src/pd/client.rs +++ b/src/pd/client.rs @@ -214,11 +214,17 @@ pub trait PdClient: Send + Sync + 'static { pub struct PdRpcClient { pd: Arc>, kv_connect: KvC, - kv_client_cache: Arc>>, + kv_client_cache: Arc>>>, enable_codec: bool, region_cache: RegionCache>, } +#[derive(Clone)] +struct CachedKvClient { + cache_key: Option, + client: Client, +} + #[async_trait] impl PdClient for PdRpcClient { type KvClient = KvC::KvClient; @@ -338,16 +344,22 @@ impl PdRpcClient { } async fn kv_client(&self, address: &str) -> Result { - if let Some(client) = self.kv_client_cache.read().await.get(address) { - return Ok(client.clone()); + let cache_key = self.kv_connect.connection_cache_key()?; + if let Some(cached) = self.kv_client_cache.read().await.get(address) { + if cached.cache_key == cache_key { + return Ok(cached.client.clone()); + } }; info!("connect to tikv endpoint: {:?}", address); match self.kv_connect.connect(address).await { Ok(client) => { - self.kv_client_cache - .write() - .await - .insert(address.to_owned(), client.clone()); + self.kv_client_cache.write().await.insert( + address.to_owned(), + CachedKvClient { + cache_key, + client: client.clone(), + }, + ); Ok(client) } Err(e) => Err(e), @@ -364,11 +376,18 @@ fn make_key_range(start_key: Vec, end_key: Vec) -> kvrpcpb::KeyRange { #[cfg(test)] pub mod test { + use std::sync::atomic::{AtomicUsize, Ordering}; + use std::sync::Arc; + + use async_trait::async_trait; use futures::executor; use futures::executor::block_on; use super::*; use crate::mock::*; + use crate::pd::RetryClient; + use crate::store::KvConnect; + use crate::Config; #[tokio::test] async fn test_kv_client_caching() { @@ -384,6 +403,109 @@ pub mod test { assert_eq!(kv2.addr, kv3.addr); } + #[tokio::test] + async fn test_kv_client_cache_hits_when_key_is_stable() { + #[derive(Clone)] + struct CountingConnect { + connects: Arc, + } + + #[async_trait] + impl KvConnect for CountingConnect { + type KvClient = MockKvClient; + + async fn connect(&self, address: &str) -> Result { + self.connects.fetch_add(1, Ordering::SeqCst); + let mut client = MockKvClient::default(); + client.addr = address.to_owned(); + Ok(client) + } + + fn connection_cache_key(&self) -> Result> { + Ok(Some(0)) + } + } + + let connects = Arc::new(AtomicUsize::new(0)); + let connects_clone = connects.clone(); + let client = PdRpcClient::new( + Config::default(), + move |_| CountingConnect { + connects: connects_clone.clone(), + }, + |sm| async move { + Ok(RetryClient::new_with_cluster( + sm, + Config::default().timeout, + MockCluster, + )) + }, + false, + ) + .await + .unwrap(); + + let kv1 = client.kv_client("foo").await.unwrap(); + let kv2 = client.kv_client("foo").await.unwrap(); + assert_eq!(kv1.addr, "foo"); + assert_eq!(kv2.addr, "foo"); + assert_eq!(connects.load(Ordering::SeqCst), 1); + } + + #[tokio::test] + async fn test_kv_client_cache_invalidate_on_key_change() { + #[derive(Clone)] + struct CountingConnect { + connects: Arc, + cache_key: Arc, + } + + #[async_trait] + impl KvConnect for CountingConnect { + type KvClient = MockKvClient; + + async fn connect(&self, address: &str) -> Result { + self.connects.fetch_add(1, Ordering::SeqCst); + let mut client = MockKvClient::default(); + client.addr = address.to_owned(); + Ok(client) + } + + fn connection_cache_key(&self) -> Result> { + Ok(Some(self.cache_key.load(Ordering::SeqCst) as u64)) + } + } + + let connects = Arc::new(AtomicUsize::new(0)); + let cache_key = Arc::new(AtomicUsize::new(1)); + let connects_clone = connects.clone(); + let cache_key_clone = cache_key.clone(); + let client = PdRpcClient::new( + Config::default(), + move |_| CountingConnect { + connects: connects_clone.clone(), + cache_key: cache_key_clone.clone(), + }, + |sm| async move { + Ok(RetryClient::new_with_cluster( + sm, + Config::default().timeout, + MockCluster, + )) + }, + false, + ) + .await + .unwrap(); + + let kv1 = client.kv_client("foo").await.unwrap(); + cache_key.store(2, Ordering::SeqCst); + let kv2 = client.kv_client("foo").await.unwrap(); + assert_eq!(kv1.addr, "foo"); + assert_eq!(kv2.addr, "foo"); + assert_eq!(connects.load(Ordering::SeqCst), 2); + } + #[test] fn test_group_keys_by_region() { let client = MockPdClient::default(); diff --git a/src/request/mod.rs b/src/request/mod.rs index c8fd07be..0cde241b 100644 --- a/src/request/mod.rs +++ b/src/request/mod.rs @@ -90,21 +90,25 @@ impl RetryOptions { mod test { use std::any::Any; use std::iter; - use std::sync::atomic::AtomicUsize; + use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::sync::Arc; use std::time::Duration; + use async_trait::async_trait; use tonic::transport::Channel; use super::*; use crate::mock::MockKvClient; use crate::mock::MockPdClient; + use crate::proto::keyspacepb; use crate::proto::kvrpcpb; + use crate::proto::metapb::{self, RegionEpoch}; use crate::proto::pdpb::Timestamp; use crate::proto::tikvpb::tikv_client::TikvClient; - use crate::region::RegionWithLeader; + use crate::region::{RegionId, RegionVerId, RegionWithLeader, StoreId}; use crate::store::region_stream_for_keys; use crate::store::HasRegionError; + use crate::store::{RegionStore, Store}; use crate::transaction::lowering::new_commit_request; use crate::Error; use crate::Key; @@ -206,6 +210,196 @@ mod test { assert_eq!(invoking_count.load(std::sync::atomic::Ordering::SeqCst), 4); } + #[tokio::test] + async fn test_region_store_mapping_retry() { + #[derive(Debug, Clone)] + struct MockOkResponse; + + impl HasKeyErrors for MockOkResponse { + fn key_errors(&mut self) -> Option> { + None + } + } + + impl HasRegionError for MockOkResponse { + fn region_error(&mut self) -> Option { + None + } + } + + impl HasLocks for MockOkResponse {} + + struct FlakyStoreMappingPdClient { + client: MockKvClient, + invalidated: AtomicBool, + invalidation_count: AtomicUsize, + } + + impl FlakyStoreMappingPdClient { + fn region(store_id: StoreId) -> RegionWithLeader { + let mut region = RegionWithLeader::default(); + region.region.id = 1; + region.region.start_key = vec![]; + region.region.end_key = vec![]; + region.region.region_epoch = Some(RegionEpoch { + conf_ver: 0, + version: 0, + }); + region.leader = Some(metapb::Peer { + store_id, + ..Default::default() + }); + region + } + } + + #[async_trait] + impl crate::pd::PdClient for FlakyStoreMappingPdClient { + type KvClient = MockKvClient; + + async fn map_region_to_store( + self: Arc, + region: RegionWithLeader, + ) -> Result { + match region.get_store_id()? { + 41 => Err(Error::InternalError { + message: "invalid store ID 41, not found".to_owned(), + }), + _ => Ok(RegionStore::new(region, Arc::new(self.client.clone()))), + } + } + + async fn region_for_key(&self, _: &Key) -> Result { + let store_id = if self.invalidated.load(Ordering::SeqCst) { + 42 + } else { + 41 + }; + Ok(Self::region(store_id)) + } + + async fn region_for_id(&self, id: RegionId) -> Result { + match id { + 1 => self.region_for_key(&Key::EMPTY).await, + _ => Err(Error::RegionNotFoundInResponse { region_id: id }), + } + } + + async fn all_stores(&self) -> Result> { + Ok(vec![Store::new(Arc::new(self.client.clone()))]) + } + + async fn get_timestamp(self: Arc) -> Result { + Ok(Timestamp::default()) + } + + async fn update_safepoint(self: Arc, _safepoint: u64) -> Result { + unimplemented!() + } + + async fn load_keyspace(&self, _keyspace: &str) -> Result { + unimplemented!() + } + + async fn update_leader( + &self, + _ver_id: RegionVerId, + _leader: metapb::Peer, + ) -> Result<()> { + Ok(()) + } + + async fn invalidate_region_cache(&self, _ver_id: RegionVerId) { + self.invalidated.store(true, Ordering::SeqCst); + self.invalidation_count.fetch_add(1, Ordering::SeqCst); + } + + async fn invalidate_store_cache(&self, _store_id: StoreId) {} + } + + #[derive(Clone)] + struct MockKvRequest { + shard_invoking_count: Arc, + } + + #[async_trait] + impl Request for MockKvRequest { + async fn dispatch(&self, _: &TikvClient, _: Duration) -> Result> { + Ok(Box::new(MockOkResponse)) + } + + fn label(&self) -> &'static str { + "mock" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn set_leader(&mut self, _: &RegionWithLeader) -> Result<()> { + Ok(()) + } + + fn set_api_version(&mut self, _: kvrpcpb::ApiVersion) {} + } + + #[async_trait] + impl KvRequest for MockKvRequest { + type Response = MockOkResponse; + } + + impl Shardable for MockKvRequest { + type Shard = Vec>; + + fn shards( + &self, + pd_client: &Arc, + ) -> futures::stream::BoxStream< + 'static, + crate::Result<(Self::Shard, crate::region::RegionWithLeader)>, + > { + self.shard_invoking_count.fetch_add(1, Ordering::SeqCst); + region_stream_for_keys( + Some(Key::from("mock_key".to_owned())).into_iter(), + pd_client.clone(), + ) + } + + fn apply_shard(&mut self, _shard: Self::Shard) {} + + fn apply_store(&mut self, _store: &crate::store::RegionStore) -> crate::Result<()> { + Ok(()) + } + } + + let dispatch_count = Arc::new(AtomicUsize::new(0)); + let shard_invoking_count = Arc::new(AtomicUsize::new(0)); + let dispatch_count_clone = dispatch_count.clone(); + + let pd_client = Arc::new(FlakyStoreMappingPdClient { + client: MockKvClient::with_dispatch_hook(move |_: &dyn Any| { + dispatch_count_clone.fetch_add(1, Ordering::SeqCst); + Ok(Box::new(MockOkResponse) as Box) + }), + invalidated: AtomicBool::new(false), + invalidation_count: AtomicUsize::new(0), + }); + + let request = MockKvRequest { + shard_invoking_count: shard_invoking_count.clone(), + }; + + let plan = crate::request::PlanBuilder::new(pd_client.clone(), Keyspace::Disable, request) + .retry_multi_region(Backoff::no_jitter_backoff(1, 1, 3)) + .plan(); + + let response = plan.execute().await; + assert!(response.is_ok()); + assert_eq!(dispatch_count.load(Ordering::SeqCst), 1); + assert_eq!(shard_invoking_count.load(Ordering::SeqCst), 2); + assert_eq!(pd_client.invalidation_count.load(Ordering::SeqCst), 1); + } + #[tokio::test] async fn test_extract_error() { let pd_client = Arc::new(MockPdClient::new(MockKvClient::with_dispatch_hook( diff --git a/src/request/plan.rs b/src/request/plan.rs index 8bd15bb5..c8acb5da 100644 --- a/src/request/plan.rs +++ b/src/request/plan.rs @@ -163,6 +163,8 @@ where preserve_region_results: bool, ) -> Result<::Result> { debug!("single_shard_handler"); + let region_ver_id = region.ver_id(); + let store_id = region.get_store_id().ok(); let region_store = match pd_client .clone() .map_region_to_store(region) @@ -172,27 +174,20 @@ where Ok(region_store) }) { Ok(region_store) => region_store, - Err(Error::LeaderNotFound { region }) => { - debug!( - "single_shard_handler::sharding: leader not found: {:?}", - region - ); + Err(err) => { + debug!("single_shard_handler::sharding, error: {:?}", err); return Self::handle_other_error( pd_client, plan, - region.clone(), - None, + region_ver_id, + store_id, backoff, permits, preserve_region_results, - Error::LeaderNotFound { region }, + err, ) .await; } - Err(err) => { - debug!("single_shard_handler::sharding, error: {:?}", err); - return Err(err); - } }; // limit concurrent requests diff --git a/src/store/client.rs b/src/store/client.rs index 1c873285..5a3163f8 100644 --- a/src/store/client.rs +++ b/src/store/client.rs @@ -20,6 +20,10 @@ pub trait KvConnect: Sized + Send + Sync + 'static { type KvClient: KvClient + Clone + Send + Sync + 'static; async fn connect(&self, address: &str) -> Result; + + fn connection_cache_key(&self) -> Result> { + Ok(None) + } } #[derive(new, Clone)] @@ -43,6 +47,10 @@ impl KvConnect for TikvConnect { .await .map(|c| KvRpcClient::new(c, self.timeout)) } + + fn connection_cache_key(&self) -> Result> { + self.security_mgr.connection_cache_key() + } } #[async_trait]