Skip to content

Commit 89ffe34

Browse files
feat: Integrate TLS Certificate to the server and create a generic trait for tcp stream
1 parent df49752 commit 89ffe34

11 files changed

Lines changed: 440 additions & 40 deletions

File tree

Cargo.lock

Lines changed: 316 additions & 8 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

client/src/handlers/client.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,9 @@ use crate::{
2929
types::{LogLevel, LogMessage},
3030
utils::perform_handshake,
3131
};
32+
use common::net::AsyncStream;
3233
use std::sync::Arc;
33-
use tokio::{net::TcpStream, sync::Mutex};
34+
use tokio::sync::Mutex;
3435

3536
/// Handles an incoming client connection.
3637
///
@@ -66,8 +67,8 @@ use tokio::{net::TcpStream, sync::Mutex};
6667
/// [`TcpStream`]: tokio::net::TcpStream
6768
/// [`perform_handshake`]: crate::utils::perform_handshake
6869
/// [`LogMessage`]: crate::types::LogMessage
69-
pub async fn handle_client(stream: TcpStream) {
70-
let (rd, wt) = stream.into_split();
70+
pub async fn handle_client(stream: Box<dyn AsyncStream>) {
71+
let (rd, wt) = tokio::io::split(stream);
7172
let rd = Arc::new(Mutex::new(rd));
7273
let wt = Arc::new(Mutex::new(wt));
7374

client/src/main.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ async fn main() {
3939
)
4040
.await;
4141

42-
handle_client(stream).await;
42+
handle_client(Box::new(stream)).await;
4343
}
4444
Err(_) => {
4545
LogMessage::log(LogLevel::ERROR, format!("Failed to connect to {}", addr), 0).await;

common/src/net.rs

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use std::sync::Arc;
22

33
use bincode::{Decode, Encode};
44
use tokio::{
5-
net::tcp::{OwnedReadHalf, OwnedWriteHalf},
5+
io::{AsyncRead, AsyncWrite, ReadHalf, WriteHalf},
66
sync::Mutex,
77
};
88

@@ -46,7 +46,14 @@ pub struct HandshakePacket {
4646
pub session_key: Option<Vec<u8>>,
4747
}
4848

49+
/// Custom trait that bundles AsyncRead + AsyncWrite
50+
pub trait AsyncStream: AsyncRead + AsyncWrite + Unpin + Send {}
51+
impl<T: AsyncRead + AsyncWrite + Unpin + Send> AsyncStream for T {}
52+
4953
/// Represents a stream reader for a client
50-
pub type StreamReader = Arc<Mutex<OwnedReadHalf>>;
54+
// pub type StreamReader<S> = Arc<Mutex<OwnedReadHalf<S>>>;
55+
pub type StreamReader = Arc<Mutex<ReadHalf<Box<dyn AsyncStream>>>>;
5156
/// Represents a stream writer for a client
52-
pub type StreamWriter = Arc<Mutex<OwnedWriteHalf>>;
57+
// pub type StreamWriter = Arc<Mutex<OwnedWriteHalf>>;
58+
pub type StreamWriter = Arc<Mutex<WriteHalf<Box<dyn AsyncStream>>>>;
59+

server/Cargo.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@ common = { path = "../common" }
1212
config.workspace = true
1313
hex.workspace = true
1414
rsa.workspace = true
15+
rustls = "0.23.31"
16+
rustls-pemfile = "2.2.0"
1517
serde = { workspace = true, features = ["derive"] }
1618
tokio = { workspace = true, features = ["full"] }
19+
tokio-rustls = "0.26.2"
1720
uuid = { version = "1.18.0", features = ["v4"] }

server/Config.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,7 @@
11
host = "0.0.0.0"
22
port = 8080
3+
4+
[tls]
5+
enabled = false
6+
cert_path = "/etc/letsencrypt/live/viveksahani.com/fullchain.pem"
7+
key_path = "/etc/letsencrypt/live/viveksahani.com/privkey.pem"

server/src/config.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,13 @@ use config::{Config, File};
22
use serde::Deserialize;
33
use std::{env, error::Error, path::PathBuf};
44

5+
#[derive(Debug, Deserialize, Clone)]
6+
pub struct TLSConfig {
7+
pub enabled: bool,
8+
pub cert_path: Option<String>,
9+
pub key_path: Option<String>,
10+
}
11+
512
/// Configuration for the server
613
#[derive(Debug, Deserialize, Clone)]
714
pub struct ServerConfig {
@@ -10,6 +17,9 @@ pub struct ServerConfig {
1017

1118
/// Port number for the server
1219
pub port: u16,
20+
21+
/// TLS configuration
22+
pub tls: Option<TLSConfig>
1323
}
1424

1525
impl ServerConfig {

server/src/handlers/client.rs

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,23 @@
11
use std::sync::Arc;
22

3-
use tokio::{
4-
net::TcpStream,
5-
sync::{Mutex as AsyncMutex, mpsc::UnboundedSender},
6-
};
3+
use tokio::sync::{Mutex as AsyncMutex, mpsc::UnboundedSender};
74

85
use crate::{
9-
data::CLIENTS, handlers::task::start_reader_task, types::Client, net::perform_handshake,
6+
data::CLIENTS, handlers::task::start_reader_task, net::perform_handshake, types::Client,
7+
};
8+
use common::{
9+
net::{AsyncStream, Packet, StreamReader, StreamWriter},
10+
utils::enc::public_key_to_user_id,
1011
};
11-
use common::{net::Packet, utils::enc::public_key_to_user_id};
1212

1313
/// Handle a new client connection
14-
pub async fn handle_client(stream: TcpStream, tx: Arc<AsyncMutex<UnboundedSender<Packet>>>) {
15-
let (rd, wt) = stream.into_split();
16-
let rd = Arc::new(AsyncMutex::new(rd));
17-
let wt = Arc::new(AsyncMutex::new(wt));
14+
pub async fn handle_client(
15+
stream: Box<dyn AsyncStream>,
16+
tx: Arc<AsyncMutex<UnboundedSender<Packet>>>,
17+
) {
18+
let (rd, wt) = tokio::io::split(stream);
19+
let rd: StreamReader = Arc::new(AsyncMutex::new(rd));
20+
let wt: StreamWriter = Arc::new(AsyncMutex::new(wt));
1821

1922
let (name, session_key, public_key) = match perform_handshake(rd.clone(), wt.clone()).await {
2023
Ok(data) => data,

server/src/handlers/msg.rs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
1+
use crate::data;
12
use common::{net::Packet, types::Message, utils::net::write_packet};
23

3-
use crate::data::{self, CONVERSATIONS};
4-
54
/// Handle a group message
65
/// Send the message to every active member of the group
76
pub async fn handle_group_message(packet: Packet, group_id: &str) {
@@ -35,7 +34,7 @@ pub async fn handle_group_message(packet: Packet, group_id: &str) {
3534
// Handle a direct message
3635
pub async fn handle_direct_message(packet: Packet, session_id: &str) {
3736
// Find the conversation
38-
let dm = match CONVERSATIONS.lock().await.get(session_id) {
37+
let dm = match data::CONVERSATIONS.lock().await.get(session_id) {
3938
Some(dm) => dm.clone(),
4039
None => return,
4140
};

server/src/main.rs

Lines changed: 75 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
1-
use std::sync::Arc;
2-
31
use common::net::Packet;
42
use null_talk_server::{
53
ServerConfig,
64
handlers::{handle_client, task::start_writer_task},
75
};
6+
use rustls::pki_types::PrivateKeyDer;
7+
use std::{fs, io::BufReader, sync::Arc};
88
use tokio::{
99
net::TcpListener,
1010
sync::{Mutex as AsyncMutex, mpsc},
1111
};
12+
use tokio_rustls::{TlsAcceptor, rustls::ServerConfig as RustlsServerConfig};
1213

1314
/// Main entry point for the server
1415
#[tokio::main]
@@ -23,23 +24,86 @@ async fn main() {
2324
println!("🔧 Configuration Loaded");
2425

2526
let server_address = config.get_addr();
26-
let listener = TcpListener::bind(&server_address).await.unwrap();
27-
println!("🚀 Server listening on {}", &server_address);
2827

28+
// Shared channel for communication
2929
let (tx, rx) = mpsc::unbounded_channel::<Packet>();
3030
let _ = start_writer_task(rx).await;
31-
3231
let sender: Arc<AsyncMutex<mpsc::UnboundedSender<Packet>>> = Arc::new(AsyncMutex::new(tx));
32+
33+
// TLS check
34+
if let Some(tls_cfg) = &config.tls {
35+
if tls_cfg.enabled {
36+
// Load certs
37+
let certs = {
38+
let file = fs::File::open(tls_cfg.cert_path.as_ref().unwrap())
39+
.expect("Cannot open cert file");
40+
let mut reader = BufReader::new(file);
41+
rustls_pemfile::certs(&mut reader)
42+
.into_iter()
43+
.map(|v| v.expect("Failed to read certificate"))
44+
.collect::<Vec<_>>()
45+
};
46+
47+
let key = {
48+
let file = fs::File::open(tls_cfg.key_path.as_ref().unwrap())
49+
.expect("Cannot open key file");
50+
let mut reader = BufReader::new(file);
51+
let first_key = rustls_pemfile::pkcs8_private_keys(&mut reader)
52+
.next()
53+
.ok_or("No private keys found")
54+
.expect("Failed to read private key")
55+
.expect("Failed to read private key");
56+
57+
PrivateKeyDer::Pkcs8(first_key)
58+
};
59+
60+
let tls_config = RustlsServerConfig::builder()
61+
.with_no_client_auth()
62+
.with_single_cert(certs, key)
63+
.expect("Failed to create TLS config");
64+
65+
let acceptor = TlsAcceptor::from(Arc::new(tls_config));
66+
let listener = TcpListener::bind(&server_address).await.unwrap();
67+
println!("🔒 TLS Server listening on {}", &server_address);
68+
69+
loop {
70+
match listener.accept().await {
71+
Ok((stream, _)) => {
72+
let acceptor = acceptor.clone();
73+
let sd_clone = sender.clone();
74+
75+
tokio::spawn(async move {
76+
match acceptor.accept(stream).await {
77+
Ok(tls_stream) => {
78+
handle_client(Box::new(tls_stream), sd_clone).await
79+
}
80+
Err(e) => eprintln!("TLS handshake failed: {:?}", e),
81+
}
82+
});
83+
}
84+
Err(e) => eprintln!("Failed to accept connection: {:?}", e),
85+
}
86+
}
87+
} else {
88+
start_plain_server(&server_address, sender).await;
89+
}
90+
} else {
91+
start_plain_server(&server_address, sender).await;
92+
}
93+
}
94+
95+
/// Helper to start plain TCP server
96+
async fn start_plain_server(addr: &str, sender: Arc<AsyncMutex<mpsc::UnboundedSender<Packet>>>) {
97+
let listener = TcpListener::bind(addr).await.unwrap();
98+
println!("🚀 Plain TCP Server listening on {}", addr);
99+
33100
loop {
34101
match listener.accept().await {
35102
Ok((stream, _)) => {
36103
let sd_clone = sender.clone();
37-
tokio::spawn(async move { handle_client(stream, sd_clone).await });
104+
tokio::spawn(async move { handle_client(Box::new(stream), sd_clone).await });
38105
}
39-
Err(e) => {
40-
eprintln!("Failed to accept connection: {:?}", e);
41-
continue;
42-
}
43-
};
106+
Err(e) => eprintln!("Failed to accept connection: {:?}", e),
107+
}
44108
}
45109
}

0 commit comments

Comments
 (0)