@@ -3,6 +3,7 @@ use crate::{
33 handle:: Handle ,
44 service:: { MakeService , SendService } ,
55} ;
6+ use either:: Either ;
67use http:: Request ;
78use hyper:: body:: Incoming ;
89use hyper_util:: {
@@ -28,6 +29,7 @@ pub struct Server<A = DefaultAcceptor> {
2829 builder : Builder < TokioExecutor > ,
2930 listener : Listener ,
3031 handle : Handle ,
32+ http_version : Option < HttpVersion > ,
3133}
3234
3335// Builder doesn't implement Debug or Clone right now
@@ -72,6 +74,7 @@ impl Server {
7274 builder,
7375 listener : Listener :: Bind ( addr) ,
7476 handle,
77+ http_version : None ,
7578 }
7679 }
7780
@@ -86,10 +89,17 @@ impl Server {
8689 builder,
8790 listener : Listener :: Std ( listener) ,
8891 handle,
92+ http_version : None ,
8993 }
9094 }
9195}
9296
97+ #[ derive( Clone , Copy , Eq , PartialEq ) ]
98+ enum HttpVersion {
99+ Http1 ,
100+ Http2 ,
101+ }
102+
93103impl < A > Server < A > {
94104 /// Overwrite acceptor.
95105 pub fn acceptor < Acceptor > ( self , acceptor : Acceptor ) -> Server < Acceptor > {
@@ -98,6 +108,7 @@ impl<A> Server<A> {
98108 builder : self . builder ,
99109 listener : self . listener ,
100110 handle : self . handle ,
111+ http_version : None ,
101112 }
102113 }
103114
@@ -111,6 +122,7 @@ impl<A> Server<A> {
111122 builder : self . builder ,
112123 listener : self . listener ,
113124 handle : self . handle ,
125+ http_version : None ,
114126 }
115127 }
116128
@@ -129,6 +141,20 @@ impl<A> Server<A> {
129141 & mut self . builder
130142 }
131143
144+ /// Only accepts HTTP/1
145+ pub fn http1_only ( mut self ) -> Self {
146+ self . http_version = Some ( HttpVersion :: Http1 ) ;
147+ self . builder = self . builder . http1_only ( ) ;
148+ self
149+ }
150+
151+ /// Only accepts HTTP/2
152+ pub fn http2_only ( mut self ) -> Self {
153+ self . http_version = Some ( HttpVersion :: Http2 ) ;
154+ self . builder = self . builder . http2_only ( ) ;
155+ self
156+ }
157+
132158 /// Provide a handle for additional utilities.
133159 pub fn handle ( mut self , handle : Handle ) -> Self {
134160 self . handle = handle;
@@ -192,20 +218,26 @@ impl<A> Server<A> {
192218 let acceptor = acceptor. clone ( ) ;
193219 let watcher = handle. watcher ( ) ;
194220 let builder = builder. clone ( ) ;
221+ let http_version = self . http_version ;
195222
196223 tokio:: spawn ( async move {
197224 if let Ok ( ( stream, send_service) ) = acceptor. accept ( tcp_stream, service) . await {
198225 let io = TokioIo :: new ( stream) ;
199226 let service = send_service. into_service ( ) ;
200227 let service = TowerToHyperService :: new ( service) ;
201-
202- let serve_future = builder. serve_connection_with_upgrades ( io, service) ;
228+ let serve_future = match http_version {
229+ Some ( _) => Either :: Left ( builder. serve_connection ( io, service) ) ,
230+ _ => Either :: Right ( builder. serve_connection_with_upgrades ( io, service) ) ,
231+ } ;
203232 tokio:: pin!( serve_future) ;
204-
233+ let mut serve_future = serve_future . as_pin_mut ( ) ;
205234 tokio:: select! {
206235 biased;
207236 _ = watcher. wait_graceful_shutdown( ) => {
208- serve_future. as_mut( ) . graceful_shutdown( ) ;
237+ match & mut serve_future {
238+ Either :: Left ( serve_future) => serve_future. as_mut( ) . graceful_shutdown( ) ,
239+ Either :: Right ( serve_future) => serve_future. as_mut( ) . graceful_shutdown( ) ,
240+ }
209241 tokio:: select! {
210242 biased;
211243 _ = watcher. wait_shutdown( ) => ( ) ,
@@ -270,7 +302,10 @@ pub(crate) fn io_other<E: Into<BoxError>>(error: E) -> io::Error {
270302
271303#[ cfg( test) ]
272304mod tests {
273- use crate :: { handle:: Handle , server:: Server } ;
305+ use crate :: {
306+ handle:: Handle ,
307+ server:: { HttpVersion , Server } ,
308+ } ;
274309 use axum:: body:: Body ;
275310 use axum:: response:: Response ;
276311 use axum:: routing:: post;
@@ -280,9 +315,10 @@ mod tests {
280315 use http:: { Method , Request , Uri } ;
281316 use http_body:: Frame ;
282317 use http_body_util:: { BodyExt , StreamBody } ;
283- use hyper:: client:: conn:: http1:: handshake;
284- use hyper:: client:: conn:: http1:: SendRequest ;
285- use hyper_util:: rt:: TokioIo ;
318+ use hyper:: client;
319+ use hyper:: client:: conn:: http1;
320+ use hyper:: client:: conn:: http2;
321+ use hyper_util:: rt:: { TokioExecutor , TokioIo } ;
286322 use std:: { io, net:: SocketAddr , time:: Duration } ;
287323 use tokio:: sync:: oneshot;
288324 use tokio:: { net:: TcpStream , task:: JoinHandle , time:: timeout} ;
@@ -295,7 +331,7 @@ mod tests {
295331
296332 // Client can send requests
297333
298- do_empty_request ( & mut client) . await . unwrap ( ) ;
334+ do_empty_request_h1 ( & mut client) . await . unwrap ( ) ;
299335
300336 do_slow_request ( & mut client, Duration :: from_millis ( 50 ) )
301337 . await
@@ -309,12 +345,12 @@ mod tests {
309345 let ( mut client, conn) = connect ( addr) . await ;
310346
311347 // Client can send request before shutdown.
312- do_empty_request ( & mut client) . await . unwrap ( ) ;
348+ do_empty_request_h1 ( & mut client) . await . unwrap ( ) ;
313349
314350 handle. shutdown ( ) ;
315351
316352 // After shutdown, all client requests should fail.
317- do_empty_request ( & mut client) . await . unwrap_err ( ) ;
353+ do_empty_request_h1 ( & mut client) . await . unwrap_err ( ) ;
318354
319355 // Connection should finish soon.
320356 let _ = timeout ( Duration :: from_secs ( 1 ) , conn) . await . unwrap ( ) ;
@@ -329,8 +365,8 @@ mod tests {
329365 let ( mut client2, _conn2) = connect ( addr) . await ;
330366
331367 // Clients can send request before graceful shutdown.
332- do_empty_request ( & mut client1) . await . unwrap ( ) ;
333- do_empty_request ( & mut client2) . await . unwrap ( ) ;
368+ do_empty_request_h1 ( & mut client1) . await . unwrap ( ) ;
369+ do_empty_request_h1 ( & mut client2) . await . unwrap ( ) ;
334370
335371 let start = tokio:: time:: Instant :: now ( ) ;
336372
@@ -355,9 +391,9 @@ mod tests {
355391 handle. graceful_shutdown ( None ) ;
356392
357393 // Any new requests after graceful shutdown begins will fail
358- do_empty_request ( & mut client2) . await . unwrap_err ( ) ;
359- do_empty_request ( & mut client2) . await . unwrap_err ( ) ;
360- do_empty_request ( & mut client2) . await . unwrap_err ( ) ;
394+ do_empty_request_h1 ( & mut client2) . await . unwrap_err ( ) ;
395+ do_empty_request_h1 ( & mut client2) . await . unwrap_err ( ) ;
396+ do_empty_request_h1 ( & mut client2) . await . unwrap_err ( ) ;
361397 } ;
362398
363399 tokio:: join!( fut1, fut2) ;
@@ -384,8 +420,8 @@ mod tests {
384420 let ( mut client2, _conn2) = connect ( addr) . await ;
385421
386422 // Clients can send request before graceful shutdown.
387- do_empty_request ( & mut client1) . await . unwrap ( ) ;
388- do_empty_request ( & mut client2) . await . unwrap ( ) ;
423+ do_empty_request_h1 ( & mut client1) . await . unwrap ( ) ;
424+ do_empty_request_h1 ( & mut client2) . await . unwrap ( ) ;
389425
390426 let start = tokio:: time:: Instant :: now ( ) ;
391427
@@ -430,7 +466,43 @@ mod tests {
430466 tokio:: join!( task1, task2, task3) ;
431467 }
432468
433- async fn start_server ( ) -> ( Handle , JoinHandle < io:: Result < ( ) > > , SocketAddr ) {
469+ #[ tokio:: test]
470+ async fn test_http1_only ( ) {
471+ let ( _handle, _server_task, addr) =
472+ start_server_with_http_version ( Some ( HttpVersion :: Http1 ) ) . await ;
473+
474+ let ( mut client, _conn) = connect_h1 ( addr) . await ;
475+
476+ do_empty_request_h1 ( & mut client) . await . unwrap ( ) ;
477+
478+ do_slow_request ( & mut client, Duration :: from_millis ( 50 ) )
479+ . await
480+ . unwrap ( ) ;
481+
482+ let ( mut client, _conn) = connect_h2 ( addr) . await ;
483+ do_empty_request_h2 ( & mut client) . await . unwrap_err ( ) ;
484+ }
485+
486+ #[ tokio:: test]
487+ async fn test_http2_only ( ) {
488+ let ( _handle, _server_task, addr) =
489+ start_server_with_http_version ( Some ( HttpVersion :: Http2 ) ) . await ;
490+
491+ let ( mut client, _conn) = connect_h2 ( addr) . await ;
492+
493+ do_empty_request_h2 ( & mut client) . await . unwrap ( ) ;
494+
495+ do_slow_request_h2 ( & mut client, Duration :: from_millis ( 50 ) )
496+ . await
497+ . unwrap ( ) ;
498+
499+ let ( mut client, _conn) = connect_h1 ( addr) . await ;
500+ do_empty_request_h1 ( & mut client) . await . unwrap_err ( ) ;
501+ }
502+
503+ async fn start_server_with_http_version (
504+ http_version : Option < HttpVersion > ,
505+ ) -> ( Handle , JoinHandle < io:: Result < ( ) > > , SocketAddr ) {
434506 let handle = Handle :: new ( ) ;
435507
436508 let server_handle = handle. clone ( ) ;
@@ -446,8 +518,14 @@ mod tests {
446518 ) ;
447519
448520 let addr = SocketAddr :: from ( ( [ 127 , 0 , 0 , 1 ] , 0 ) ) ;
449-
450- Server :: bind ( addr)
521+ let server = Server :: bind ( addr) ;
522+ let server = match http_version {
523+ Some ( HttpVersion :: Http1 ) => server. http1_only ( ) ,
524+ Some ( HttpVersion :: Http2 ) => server. http2_only ( ) ,
525+ None => server,
526+ } ;
527+
528+ server
451529 . handle ( server_handle)
452530 . serve ( app. into_make_service ( ) )
453531 . await
@@ -458,9 +536,17 @@ mod tests {
458536 ( handle, server_task, addr)
459537 }
460538
461- async fn connect ( addr : SocketAddr ) -> ( SendRequest < Body > , JoinHandle < ( ) > ) {
539+ async fn start_server ( ) -> ( Handle , JoinHandle < io:: Result < ( ) > > , SocketAddr ) {
540+ start_server_with_http_version ( None ) . await
541+ }
542+
543+ async fn connect ( addr : SocketAddr ) -> ( http1:: SendRequest < Body > , JoinHandle < ( ) > ) {
544+ connect_h1 ( addr) . await
545+ }
546+
547+ async fn connect_h1 ( addr : SocketAddr ) -> ( http1:: SendRequest < Body > , JoinHandle < ( ) > ) {
462548 let stream = TokioIo :: new ( TcpStream :: connect ( addr) . await . unwrap ( ) ) ;
463- let ( send_request, connection) = handshake ( stream) . await . unwrap ( ) ;
549+ let ( send_request, connection) = client :: conn :: http1 :: handshake ( stream) . await . unwrap ( ) ;
464550
465551 let task = tokio:: spawn ( async move {
466552 let _ = connection. await ;
@@ -469,8 +555,36 @@ mod tests {
469555 ( send_request, task)
470556 }
471557
558+ async fn connect_h2 ( addr : SocketAddr ) -> ( http2:: SendRequest < Body > , JoinHandle < ( ) > ) {
559+ let stream = TokioIo :: new ( TcpStream :: connect ( addr) . await . unwrap ( ) ) ;
560+ let ( send_request, connection) =
561+ client:: conn:: http2:: handshake ( TokioExecutor :: new ( ) , stream)
562+ . await
563+ . unwrap ( ) ;
564+
565+ let task = tokio:: spawn ( async move {
566+ let _ = connection. await ;
567+ } ) ;
568+
569+ ( send_request, task)
570+ }
571+
572+ // Send a basic `GET /` request.
573+ async fn do_empty_request_h1 ( client : & mut http1:: SendRequest < Body > ) -> hyper:: Result < ( ) > {
574+ client. ready ( ) . await ?;
575+
576+ let body = client
577+ . send_request ( Request :: new ( Body :: empty ( ) ) )
578+ . await ?
579+ . into_body ( ) ;
580+
581+ let body = body. collect ( ) . await ?. to_bytes ( ) ;
582+ assert_eq ! ( body. as_ref( ) , b"Hello, world!" ) ;
583+ Ok ( ( ) )
584+ }
585+
472586 // Send a basic `GET /` request.
473- async fn do_empty_request ( client : & mut SendRequest < Body > ) -> hyper:: Result < ( ) > {
587+ async fn do_empty_request_h2 ( client : & mut http2 :: SendRequest < Body > ) -> hyper:: Result < ( ) > {
474588 client. ready ( ) . await ?;
475589
476590 let body = client
@@ -486,15 +600,36 @@ mod tests {
486600 // Send a request with a body streamed byte-by-byte, over a given duration,
487601 // then wait for the full response.
488602 async fn do_slow_request (
489- client : & mut SendRequest < Body > ,
603+ client : & mut http1 :: SendRequest < Body > ,
490604 duration : Duration ,
491605 ) -> hyper:: Result < ( ) > {
492606 let response = send_slow_request ( client, duration) . await ?;
493607 recv_slow_response_body ( response) . await
494608 }
495609
610+ async fn do_slow_request_h2 (
611+ client : & mut http2:: SendRequest < Body > ,
612+ duration : Duration ,
613+ ) -> hyper:: Result < ( ) > {
614+ let response = send_slow_request_h2 ( client, duration) . await ?;
615+ recv_slow_response_body ( response) . await
616+ }
617+
496618 async fn send_slow_request (
497- client : & mut SendRequest < Body > ,
619+ client : & mut http1:: SendRequest < Body > ,
620+ duration : Duration ,
621+ ) -> hyper:: Result < http:: Response < hyper:: body:: Incoming > > {
622+ let req_body_len: usize = 10 ;
623+ let mut req = Request :: new ( slow_body ( req_body_len, duration) ) ;
624+ * req. method_mut ( ) = Method :: POST ;
625+ * req. uri_mut ( ) = Uri :: from_static ( "/echo_slowly" ) ;
626+
627+ client. ready ( ) . await ?;
628+ client. send_request ( req) . await
629+ }
630+
631+ async fn send_slow_request_h2 (
632+ client : & mut http2:: SendRequest < Body > ,
498633 duration : Duration ,
499634 ) -> hyper:: Result < http:: Response < hyper:: body:: Incoming > > {
500635 let req_body_len: usize = 10 ;
0 commit comments