Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 106 additions & 17 deletions src/common/security.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
// Copyright 2018 TiKV Project Authors. Licensed under Apache-2.0.

use std::collections::hash_map::DefaultHasher;
use std::fs;
use std::fs::File;
use std::hash::Hash;
use std::hash::Hasher;
use std::io::Read;
use std::path::Path;
use std::path::PathBuf;
use std::time::Duration;
use std::time::SystemTime;

use log::info;
use regex::Regex;
Expand Down Expand Up @@ -43,12 +48,12 @@ fn load_pem_file(tag: &str, path: &Path) -> Result<Vec<u8>> {
/// Manages the TLS protocol
#[derive(Default)]
pub struct SecurityManager {
/// The PEM encoding of the server’s CA certificates.
ca: Vec<u8>,
/// The PEM encoding of the server’s certificate chain.
cert: Vec<u8>,
/// The path to the PEM encoding of the server’s CA certificates.
ca_path: Option<PathBuf>,
/// The path to the PEM encoding of the server’s certificate chain.
cert_path: Option<PathBuf>,
/// The path to the file that contains the PEM encoding of the server’s private key.
key: PathBuf,
key_path: Option<PathBuf>,
}

impl SecurityManager {
Expand All @@ -58,15 +63,35 @@ impl SecurityManager {
cert_path: impl AsRef<Path>,
key_path: impl Into<PathBuf>,
) -> Result<SecurityManager> {
let ca_path = ca_path.as_ref().to_path_buf();
let cert_path = cert_path.as_ref().to_path_buf();
let key_path = key_path.into();
check_pem_file("ca", &ca_path)?;
check_pem_file("certificate", &cert_path)?;
check_pem_file("private key", &key_path)?;
Ok(SecurityManager {
ca: load_pem_file("ca", ca_path.as_ref())?,
cert: load_pem_file("certificate", cert_path.as_ref())?,
key: key_path,
ca_path: Some(ca_path),
cert_path: Some(cert_path),
key_path: Some(key_path),
})
}

pub(crate) fn tls_configured(&self) -> bool {
self.ca_path.is_some()
}

pub(crate) fn connection_cache_key(&self) -> Result<Option<u64>> {
if !self.tls_configured() {
return Ok(None);
}

let mut hasher = DefaultHasher::new();
file_signature(self.ca_path.as_ref().expect("tls_configured checked"))?.hash(&mut hasher);
file_signature(self.cert_path.as_ref().expect("tls_configured checked"))?.hash(&mut hasher);
file_signature(self.key_path.as_ref().expect("tls_configured checked"))?.hash(&mut hasher);
Ok(Some(hasher.finish()))
}

/// Connect to gRPC server using TLS connection. If TLS is not configured, use normal connection.
pub async fn connect<Factory, Client>(
&self,
Expand All @@ -78,7 +103,7 @@ impl SecurityManager {
Factory: FnOnce(Channel) -> Client,
{
info!("connect to rpc server at endpoint: {:?}", addr);
let channel = if !self.ca.is_empty() {
let channel = if self.tls_configured() {
self.tls_channel(addr).await?
} else {
self.default_channel(addr).await?
Expand All @@ -89,18 +114,37 @@ impl SecurityManager {
}

async fn tls_channel(&self, addr: &str) -> Result<Endpoint> {
let (ca, cert, key) = self.load_tls_materials()?;
let addr = "https://".to_string() + &SCHEME_REG.replace(addr, "");
let builder = self.endpoint(addr.to_string())?;
let tls = ClientTlsConfig::new()
.ca_certificate(Certificate::from_pem(&self.ca))
.identity(Identity::from_pem(
&self.cert,
load_pem_file("private key", &self.key)?,
));
.ca_certificate(Certificate::from_pem(ca))
.identity(Identity::from_pem(cert, key));
let builder = builder.tls_config(tls)?;
Ok(builder)
}

fn load_tls_materials(&self) -> Result<(Vec<u8>, Vec<u8>, Vec<u8>)> {
let ca_path = self
.ca_path
.as_ref()
.ok_or_else(|| internal_err!("TLS is not configured"))?;
let cert_path = self
.cert_path
.as_ref()
.ok_or_else(|| internal_err!("TLS is not configured"))?;
let key_path = self
.key_path
.as_ref()
.ok_or_else(|| internal_err!("TLS is not configured"))?;

Ok((
load_pem_file("ca", ca_path)?,
load_pem_file("certificate", cert_path)?,
load_pem_file("private key", key_path)?,
))
}

async fn default_channel(&self, addr: &str) -> Result<Endpoint> {
let addr = "http://".to_string() + &SCHEME_REG.replace(addr, "");
self.endpoint(addr)
Expand All @@ -114,6 +158,17 @@ impl SecurityManager {
}
}

fn file_signature(path: &Path) -> Result<(u64, Option<u128>)> {
let metadata = fs::metadata(path)
.map_err(|e| internal_err!("failed to stat {}: {:?}", path.display(), e))?;
let modified = metadata.modified().ok().and_then(|t: SystemTime| {
t.duration_since(SystemTime::UNIX_EPOCH)
.ok()
.map(|d| d.as_nanos())
});
Ok((metadata.len(), modified))
}

#[cfg(test)]
mod tests {
use std::fs::File;
Expand All @@ -140,9 +195,43 @@ mod tests {
let key_path: PathBuf = format!("{}", example_pem.display()).into();
let ca_path: PathBuf = format!("{}", example_ca.display()).into();
let mgr = SecurityManager::load(ca_path, cert_path, &key_path).unwrap();
assert_eq!(mgr.ca, vec![0]);
assert_eq!(mgr.cert, vec![1]);
let key = load_pem_file("private key", &key_path).unwrap();
assert!(mgr.tls_configured());
let (ca, cert, key) = mgr.load_tls_materials().unwrap();
assert_eq!(ca, vec![0]);
assert_eq!(cert, vec![1]);
assert_eq!(key, vec![2]);
}

#[test]
fn test_security_reload() {
let temp = tempfile::tempdir().unwrap();
let example_ca = temp.path().join("ca");
let example_cert = temp.path().join("cert");
let example_pem = temp.path().join("key");
for (id, f) in [&example_ca, &example_cert, &example_pem]
.iter()
.enumerate()
{
File::create(f).unwrap().write_all(&[id as u8]).unwrap();
}

let mgr = SecurityManager::load(&example_ca, &example_cert, &example_pem).unwrap();
let first = mgr.load_tls_materials().unwrap();
let key1 = mgr.connection_cache_key().unwrap();

File::create(&example_ca).unwrap().write_all(&[9]).unwrap();
File::create(&example_cert)
.unwrap()
.write_all(&[8])
.unwrap();
File::create(&example_pem).unwrap().write_all(&[7]).unwrap();

let second = mgr.load_tls_materials().unwrap();
let key2 = mgr.connection_cache_key().unwrap();
assert_ne!(first, second);
assert_eq!(second.0, vec![9]);
assert_eq!(second.1, vec![8]);
assert_eq!(second.2, vec![7]);
assert_ne!(key1, key2);
}
}
136 changes: 129 additions & 7 deletions src/pd/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -214,11 +214,17 @@ pub trait PdClient: Send + Sync + 'static {
pub struct PdRpcClient<KvC: KvConnect + Send + Sync + 'static = TikvConnect, Cl = Cluster> {
pd: Arc<RetryClient<Cl>>,
kv_connect: KvC,
kv_client_cache: Arc<RwLock<HashMap<String, KvC::KvClient>>>,
kv_client_cache: Arc<RwLock<HashMap<String, CachedKvClient<KvC::KvClient>>>>,
enable_codec: bool,
region_cache: RegionCache<RetryClient<Cl>>,
}

#[derive(Clone)]
struct CachedKvClient<Client> {
cache_key: Option<u64>,
client: Client,
}

#[async_trait]
impl<KvC: KvConnect + Send + Sync + 'static> PdClient for PdRpcClient<KvC> {
type KvClient = KvC::KvClient;
Expand Down Expand Up @@ -338,16 +344,22 @@ impl<KvC: KvConnect + Send + Sync + 'static, Cl> PdRpcClient<KvC, Cl> {
}

async fn kv_client(&self, address: &str) -> Result<KvC::KvClient> {
if let Some(client) = self.kv_client_cache.read().await.get(address) {
return Ok(client.clone());
let cache_key = self.kv_connect.connection_cache_key()?;
if let Some(cached) = self.kv_client_cache.read().await.get(address) {
if cached.cache_key == cache_key {
return Ok(cached.client.clone());
}
};
info!("connect to tikv endpoint: {:?}", address);
match self.kv_connect.connect(address).await {
Ok(client) => {
self.kv_client_cache
.write()
.await
.insert(address.to_owned(), client.clone());
self.kv_client_cache.write().await.insert(
address.to_owned(),
CachedKvClient {
cache_key,
client: client.clone(),
},
);
Ok(client)
}
Err(e) => Err(e),
Expand All @@ -364,11 +376,18 @@ fn make_key_range(start_key: Vec<u8>, end_key: Vec<u8>) -> kvrpcpb::KeyRange {

#[cfg(test)]
pub mod test {
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

use async_trait::async_trait;
use futures::executor;
use futures::executor::block_on;

use super::*;
use crate::mock::*;
use crate::pd::RetryClient;
use crate::store::KvConnect;
use crate::Config;

#[tokio::test]
async fn test_kv_client_caching() {
Expand All @@ -384,6 +403,109 @@ pub mod test {
assert_eq!(kv2.addr, kv3.addr);
}

#[tokio::test]
async fn test_kv_client_cache_hits_when_key_is_stable() {
#[derive(Clone)]
struct CountingConnect {
connects: Arc<AtomicUsize>,
}

#[async_trait]
impl KvConnect for CountingConnect {
type KvClient = MockKvClient;

async fn connect(&self, address: &str) -> Result<Self::KvClient> {
self.connects.fetch_add(1, Ordering::SeqCst);
let mut client = MockKvClient::default();
client.addr = address.to_owned();
Ok(client)
}

fn connection_cache_key(&self) -> Result<Option<u64>> {
Ok(Some(0))
}
}

let connects = Arc::new(AtomicUsize::new(0));
let connects_clone = connects.clone();
let client = PdRpcClient::new(
Config::default(),
move |_| CountingConnect {
connects: connects_clone.clone(),
},
|sm| async move {
Ok(RetryClient::new_with_cluster(
sm,
Config::default().timeout,
MockCluster,
))
},
false,
)
.await
.unwrap();

let kv1 = client.kv_client("foo").await.unwrap();
let kv2 = client.kv_client("foo").await.unwrap();
assert_eq!(kv1.addr, "foo");
assert_eq!(kv2.addr, "foo");
assert_eq!(connects.load(Ordering::SeqCst), 1);
}

#[tokio::test]
async fn test_kv_client_cache_invalidate_on_key_change() {
#[derive(Clone)]
struct CountingConnect {
connects: Arc<AtomicUsize>,
cache_key: Arc<AtomicUsize>,
}

#[async_trait]
impl KvConnect for CountingConnect {
type KvClient = MockKvClient;

async fn connect(&self, address: &str) -> Result<Self::KvClient> {
self.connects.fetch_add(1, Ordering::SeqCst);
let mut client = MockKvClient::default();
client.addr = address.to_owned();
Ok(client)
}

fn connection_cache_key(&self) -> Result<Option<u64>> {
Ok(Some(self.cache_key.load(Ordering::SeqCst) as u64))
}
}

let connects = Arc::new(AtomicUsize::new(0));
let cache_key = Arc::new(AtomicUsize::new(1));
let connects_clone = connects.clone();
let cache_key_clone = cache_key.clone();
let client = PdRpcClient::new(
Config::default(),
move |_| CountingConnect {
connects: connects_clone.clone(),
cache_key: cache_key_clone.clone(),
},
|sm| async move {
Ok(RetryClient::new_with_cluster(
sm,
Config::default().timeout,
MockCluster,
))
},
false,
)
.await
.unwrap();

let kv1 = client.kv_client("foo").await.unwrap();
cache_key.store(2, Ordering::SeqCst);
let kv2 = client.kv_client("foo").await.unwrap();
assert_eq!(kv1.addr, "foo");
assert_eq!(kv2.addr, "foo");
assert_eq!(connects.load(Ordering::SeqCst), 2);
}

#[test]
fn test_group_keys_by_region() {
let client = MockPdClient::default();
Expand Down
Loading
Loading