diff --git a/src/dist/download.rs b/src/dist/download.rs index f1d25f10e8..04456692e5 100644 --- a/src/dist/download.rs +++ b/src/dist/download.rs @@ -82,7 +82,7 @@ impl<'a> DownloadCfg<'a> { let partial_file_existed = partial_file_path.exists(); let mut hasher = Sha256::new(); - let download = DownloadOptions::try_from(self.process)? + let mut download = DownloadOptions::try_from(self.process)? .start(url, &partial_file_path) .with_hasher(&mut hasher) .with_status(status) @@ -268,7 +268,7 @@ impl<'a> DownloadCfg<'a> { .start(&url, &file) .with_hasher(&mut hasher); - let download = match status { + let mut download = match status { Some(status) => download.with_status(status), None => download, }; diff --git a/src/download/mod.rs b/src/download/mod.rs index ce3c69f16c..8871a936f5 100644 --- a/src/download/mod.rs +++ b/src/download/mod.rs @@ -1,6 +1,5 @@ //! Easy file downloading -use std::cell::RefCell; use std::fs::{self, OpenOptions, remove_file}; use std::io::{self, Read, Seek, SeekFrom, Write}; use std::num::NonZero; @@ -106,7 +105,7 @@ impl TryFrom<&Process> for DownloadOptions { pub struct Download<'a> { url: &'a Url, path: &'a Path, - hasher: Option>, + hasher: Option<&'a mut Sha256>, status: Option<&'a DownloadStatus>, resume: bool, options: DownloadOptions, @@ -114,7 +113,7 @@ pub struct Download<'a> { impl<'a> Download<'a> { pub(crate) fn with_hasher(mut self, hasher: &'a mut Sha256) -> Self { - self.hasher = Some(RefCell::new(hasher)); + self.hasher = Some(hasher); self } @@ -128,110 +127,66 @@ impl<'a> Download<'a> { self } - pub(crate) async fn download(&self) -> anyhow::Result<()> { - match self.download_file_().await { - Ok(_) => Ok(()), - Err(e) => { - if e.downcast_ref::().is_some() { - return Err(e); - } - let is_client_error = match e.downcast_ref::() { - // Specifically treat the bad partial range error as not our - // fault in case it was something odd which happened. - Some(DownloadError::HttpStatus(416)) => false, - Some(DownloadError::HttpStatus(400..=499)) - | Some(DownloadError::FileNotFound) => true, - _ => false, - }; - Err(e).with_context(|| { - if is_client_error { - RustupError::DownloadNotExists { - url: self.url.clone(), - path: self.path.to_path_buf(), - } - } else { - RustupError::DownloadingFile { - url: self.url.clone(), - path: self.path.to_path_buf(), - } - } - }) - } - } - } - - async fn download_file_(&self) -> anyhow::Result<()> { + pub(crate) async fn download(&mut self) -> anyhow::Result<()> { debug!(url = %self.url, "downloading file"); - // This callback will write the download to disk and optionally - // hash the contents, then forward the notification up the stack - let callback: &dyn Fn(Event<'_>) -> anyhow::Result<()> = &|msg| { - if let Event::DownloadDataReceived(data) = msg - && let Some(h) = self.hasher.as_ref() - { - h.borrow_mut().update(data); - } - - match msg { - Event::DownloadContentLengthReceived(len) => { - if let Some(status) = self.status { - status.received_length(len) - } - } - Event::DownloadDataReceived(data) => { - if let Some(status) = self.status { - status.received_data(data.len()) - } - } - Event::ResumingPartialDownload => debug!("resuming partial download"), + let Err(err) = self.download_impl().await else { + if let Some(status) = self.status { + status.finished(); } - - Ok(()) + return Ok(()); }; - // Download the file - - let res = self.download_to_path(Some(callback)).await; - - // The notification should only be sent if the download was successful (i.e. didn't timeout) if let Some(status) = self.status { - match &res { - Ok(_) => status.finished(), - Err(_) => status.failed(), - }; + status.failed(); } - res - } - - async fn download_to_path(&self, callback: Option>) -> anyhow::Result<()> { - let Err(err) = self.download_impl(callback).await else { - return Ok(()); - }; - // TODO: Currently, we only refrain from removing the cached download // if there was a network failure from the client side. // It may be worth looking for other cases where removal is also not desired. - Err( - if !(self.resume && is_network_failure(&err)) - && let Err(file_err) = - remove_file(self.path).context("cleaning up cached downloads") - { - file_err.context(err) + let e = if !(self.resume && is_network_failure(&err)) + && let Err(file_err) = remove_file(self.path).context("cleaning up cached downloads") + { + file_err.context(err) + } else { + err + }; + + if e.downcast_ref::().is_some() { + return Err(e); + } + + let is_client_error = match e.downcast_ref::() { + // Specifically treat the bad partial range error as not our + // fault in case it was something odd which happened. + Some(DownloadError::HttpStatus(416)) => false, + Some(DownloadError::HttpStatus(400..=499)) | Some(DownloadError::FileNotFound) => true, + _ => false, + }; + + Err(e).with_context(|| { + if is_client_error { + RustupError::DownloadNotExists { + url: self.url.clone(), + path: self.path.to_path_buf(), + } } else { - err - }, - ) + RustupError::DownloadingFile { + url: self.url.clone(), + path: self.path.to_path_buf(), + } + } + }) } - async fn download_impl(&self, callback: Option>) -> anyhow::Result<()> { - let (file, resume_from) = if self.resume { + async fn download_impl(&mut self) -> anyhow::Result<()> { + let (mut file, resume_from) = if self.resume { // TODO: blocking call let possible_partial = OpenOptions::new().read(true).open(self.path); let downloaded_so_far = if let Ok(mut partial) = possible_partial { - if let Some(cb) = callback { - cb(Event::ResumingPartialDownload)?; + if self.status.is_some() || self.hasher.is_some() { + debug!("resuming partial download"); let mut buf = vec![0; 32768]; let mut downloaded_so_far = 0; @@ -241,7 +196,7 @@ impl<'a> Download<'a> { if n == 0 { break; } - cb(Event::DownloadDataReceived(&buf[..n]))?; + self.data_received(&buf[..n]); } downloaded_so_far @@ -276,7 +231,6 @@ impl<'a> Download<'a> { ) }; - let file = RefCell::new(file); let client = match self.options.tls { #[cfg(feature = "reqwest-rustls-tls")] Tls::Rustls => rustls_client(self.options.timeout)?, @@ -285,34 +239,18 @@ impl<'a> Download<'a> { }; // TODO: the sync callback will stall the async runtime if IO calls block, which is OS dependent. Rearrange. - self.execute( - resume_from, - &|event| { - if let Event::DownloadDataReceived(data) = event { - file.borrow_mut() - .write_all(data) - .context("unable to write download to disk")?; - } - match callback { - Some(cb) => cb(event), - None => Ok(()), - } - }, - client, - ) - .await?; + self.execute(&mut file, resume_from, client).await?; - file.borrow_mut() - .sync_data() + file.sync_data() .context("unable to sync download to disk")?; Ok::<(), anyhow::Error>(()) } async fn execute( - &self, + &mut self, + file: &mut fs::File, resume_from: u64, - callback: &dyn Fn(Event<'_>) -> anyhow::Result<()>, client: &Client, ) -> anyhow::Result<()> { // Short-circuit reqwest for the "file:" URL scheme @@ -339,7 +277,10 @@ impl<'a> Download<'a> { if bytes_read == 0 { break; } - callback(Event::DownloadDataReceived(&buffer[0..bytes_read]))?; + + file.write_all(&buffer[0..bytes_read]) + .context("unable to write download to disk")?; + self.data_received(&buffer[0..bytes_read]); } return Ok(()); @@ -366,16 +307,29 @@ impl<'a> Download<'a> { if let Some(len) = res.content_length() { let len = len + resume_from; - callback(Event::DownloadContentLengthReceived(len))?; + if let Some(status) = self.status { + status.received_length(len); + } } let mut stream = res.bytes_stream(); while let Some(item) = stream.next().await { let bytes = item.map_err(DownloadError::Reqwest)?; - callback(Event::DownloadDataReceived(&bytes))?; + file.write_all(&bytes) + .context("unable to write download to disk")?; + self.data_received(&bytes); } Ok(()) } + + fn data_received(&mut self, data: &[u8]) { + if let Some(hasher) = &mut self.hasher { + hasher.update(data); + } + if let Some(status) = self.status { + status.received_data(data.len()); + } + } } pub(crate) fn is_network_failure(err: &anyhow::Error) -> bool { @@ -407,17 +361,6 @@ enum Tls { NativeTls, } -#[derive(Debug, Copy, Clone)] -enum Event<'a> { - ResumingPartialDownload, - /// Received the Content-Length of the to-be downloaded data. - DownloadContentLengthReceived(u64), - /// Received some data. - DownloadDataReceived(&'a [u8]), -} - -type DownloadCallback<'a> = &'a dyn Fn(Event<'_>) -> anyhow::Result<()>; - fn client_generic() -> ClientBuilder { Client::builder() // HACK: set `pool_max_idle_per_host` to `0` to avoid an issue in the underlying diff --git a/src/download/tests.rs b/src/download/tests.rs index 0375cf9218..a49ec7cee6 100644 --- a/src/download/tests.rs +++ b/src/download/tests.rs @@ -20,8 +20,7 @@ mod reqwest { use std::env::set_var; use std::error::Error; use std::net::TcpListener; - use std::sync::Mutex; - use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; + use std::sync::atomic::{AtomicUsize, Ordering}; use std::thread; use std::time::Duration; @@ -30,7 +29,7 @@ mod reqwest { use url::Url; use super::{scrub_env, serve_file, tmp_dir, write_file}; - use crate::download::{DownloadOptions, Event, Tls}; + use crate::download::{DownloadOptions, Tls}; const OPTIONS: DownloadOptions = DownloadOptions { tls: DOWNLOAD_BACKEND, @@ -118,61 +117,13 @@ mod reqwest { OPTIONS .start(&from_url, &target_path) .with_resume() - .download_to_path(None) + .download() .await .expect("Test download failed"); assert_eq!(std::fs::read_to_string(&target_path).unwrap(), "12345"); } - #[tokio::test] - async fn callback_gets_all_data_as_if_the_download_happened_all_at_once() { - let _guard = scrub_env().await; - let tmpdir = tmp_dir(); - let target_path = tmpdir.path().join("downloaded"); - write_file(&target_path, "123"); - - let addr = serve_file(b"xxx45".to_vec(), true); - - let from_url = format!("http://{addr}").parse().unwrap(); - - let callback_partial = AtomicBool::new(false); - let callback_len = Mutex::new(None); - let received_in_callback = Mutex::new(Vec::new()); - - OPTIONS - .start(&from_url, &target_path) - .with_resume() - .download_to_path(Some(&|msg| { - match msg { - Event::ResumingPartialDownload => { - assert!(!callback_partial.load(Ordering::SeqCst)); - callback_partial.store(true, Ordering::SeqCst); - } - Event::DownloadContentLengthReceived(len) => { - let mut flag = callback_len.lock().unwrap(); - assert!(flag.is_none()); - *flag = Some(len); - } - Event::DownloadDataReceived(data) => { - for b in data.iter() { - received_in_callback.lock().unwrap().push(*b); - } - } - } - - Ok(()) - })) - .await - .expect("Test download failed"); - - assert!(callback_partial.into_inner()); - assert_eq!(*callback_len.lock().unwrap(), Some(5)); - let observed_bytes = received_in_callback.into_inner().unwrap(); - assert_eq!(observed_bytes, vec![b'1', b'2', b'3', b'4', b'5']); - assert_eq!(std::fs::read_to_string(&target_path).unwrap(), "12345"); - } - #[tokio::test] async fn resume_partial_fails_if_server_ignores_range() { let _guard = scrub_env().await; @@ -186,7 +137,7 @@ mod reqwest { OPTIONS .start(&from_url, &target_path) .with_resume() - .download_to_path(None) + .download() .await .expect_err("download should fail if server ignores range"); @@ -210,7 +161,7 @@ mod reqwest { } .start(&from_url, &target_path) .with_resume() - .download_to_path(None) + .download() .await .expect_err("download should fail with a connect error");