Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/dist/download.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ impl<'a> DownloadCfg<'a> {
let partial_file_existed = partial_file_path.exists();

let mut hasher = Sha256::new();
let download = DownloadOptions::try_from(self.process)?
let mut download = DownloadOptions::try_from(self.process)?
.start(url, &partial_file_path)
.with_hasher(&mut hasher)
.with_status(status)
Expand Down Expand Up @@ -268,7 +268,7 @@ impl<'a> DownloadCfg<'a> {
.start(&url, &file)
.with_hasher(&mut hasher);

let download = match status {
let mut download = match status {
Some(status) => download.with_status(status),
None => download,
};
Expand Down
193 changes: 68 additions & 125 deletions src/download/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
//! Easy file downloading

use std::cell::RefCell;
use std::fs::{self, OpenOptions, remove_file};
use std::io::{self, Read, Seek, SeekFrom, Write};
use std::num::NonZero;
Expand Down Expand Up @@ -106,15 +105,15 @@ impl TryFrom<&Process> for DownloadOptions {
pub struct Download<'a> {
url: &'a Url,
path: &'a Path,
hasher: Option<RefCell<&'a mut Sha256>>,
hasher: Option<&'a mut Sha256>,
status: Option<&'a DownloadStatus>,
resume: bool,
options: DownloadOptions,
}

impl<'a> Download<'a> {
pub(crate) fn with_hasher(mut self, hasher: &'a mut Sha256) -> Self {
self.hasher = Some(RefCell::new(hasher));
self.hasher = Some(hasher);
self
}

Expand All @@ -128,110 +127,66 @@ impl<'a> Download<'a> {
self
}

pub(crate) async fn download(&self) -> anyhow::Result<()> {
match self.download_file_().await {
Ok(_) => Ok(()),
Err(e) => {
if e.downcast_ref::<io::Error>().is_some() {
return Err(e);
}
let is_client_error = match e.downcast_ref::<DownloadError>() {
// Specifically treat the bad partial range error as not our
// fault in case it was something odd which happened.
Some(DownloadError::HttpStatus(416)) => false,
Some(DownloadError::HttpStatus(400..=499))
| Some(DownloadError::FileNotFound) => true,
_ => false,
};
Err(e).with_context(|| {
if is_client_error {
RustupError::DownloadNotExists {
url: self.url.clone(),
path: self.path.to_path_buf(),
}
} else {
RustupError::DownloadingFile {
url: self.url.clone(),
path: self.path.to_path_buf(),
}
}
})
}
}
}

async fn download_file_(&self) -> anyhow::Result<()> {
pub(crate) async fn download(&mut self) -> anyhow::Result<()> {
debug!(url = %self.url, "downloading file");

// This callback will write the download to disk and optionally
// hash the contents, then forward the notification up the stack
let callback: &dyn Fn(Event<'_>) -> anyhow::Result<()> = &|msg| {
if let Event::DownloadDataReceived(data) = msg
&& let Some(h) = self.hasher.as_ref()
{
h.borrow_mut().update(data);
}

match msg {
Event::DownloadContentLengthReceived(len) => {
if let Some(status) = self.status {
status.received_length(len)
}
}
Event::DownloadDataReceived(data) => {
if let Some(status) = self.status {
status.received_data(data.len())
}
}
Event::ResumingPartialDownload => debug!("resuming partial download"),
let Err(err) = self.download_impl().await else {
if let Some(status) = self.status {
status.finished();
}

Ok(())
return Ok(());
};

// Download the file

let res = self.download_to_path(Some(callback)).await;

// The notification should only be sent if the download was successful (i.e. didn't timeout)
if let Some(status) = self.status {
match &res {
Ok(_) => status.finished(),
Err(_) => status.failed(),
};
status.failed();
}

res
}

async fn download_to_path(&self, callback: Option<DownloadCallback<'_>>) -> anyhow::Result<()> {
let Err(err) = self.download_impl(callback).await else {
return Ok(());
};

// TODO: Currently, we only refrain from removing the cached download
// if there was a network failure from the client side.
// It may be worth looking for other cases where removal is also not desired.
Err(
if !(self.resume && is_network_failure(&err))
&& let Err(file_err) =
remove_file(self.path).context("cleaning up cached downloads")
{
file_err.context(err)
let e = if !(self.resume && is_network_failure(&err))
&& let Err(file_err) = remove_file(self.path).context("cleaning up cached downloads")
{
file_err.context(err)
} else {
err
};

if e.downcast_ref::<io::Error>().is_some() {
return Err(e);
}

let is_client_error = match e.downcast_ref::<DownloadError>() {
// Specifically treat the bad partial range error as not our
// fault in case it was something odd which happened.
Some(DownloadError::HttpStatus(416)) => false,
Some(DownloadError::HttpStatus(400..=499)) | Some(DownloadError::FileNotFound) => true,
_ => false,
};

Err(e).with_context(|| {
if is_client_error {
RustupError::DownloadNotExists {
url: self.url.clone(),
path: self.path.to_path_buf(),
}
} else {
err
},
)
RustupError::DownloadingFile {
url: self.url.clone(),
path: self.path.to_path_buf(),
}
}
})
}

async fn download_impl(&self, callback: Option<DownloadCallback<'_>>) -> anyhow::Result<()> {
let (file, resume_from) = if self.resume {
async fn download_impl(&mut self) -> anyhow::Result<()> {
let (mut file, resume_from) = if self.resume {
// TODO: blocking call
let possible_partial = OpenOptions::new().read(true).open(self.path);

let downloaded_so_far = if let Ok(mut partial) = possible_partial {
if let Some(cb) = callback {
cb(Event::ResumingPartialDownload)?;
if self.status.is_some() || self.hasher.is_some() {
debug!("resuming partial download");

let mut buf = vec![0; 32768];
let mut downloaded_so_far = 0;
Expand All @@ -241,7 +196,7 @@ impl<'a> Download<'a> {
if n == 0 {
break;
}
cb(Event::DownloadDataReceived(&buf[..n]))?;
self.data_received(&buf[..n]);
}

downloaded_so_far
Expand Down Expand Up @@ -276,7 +231,6 @@ impl<'a> Download<'a> {
)
};

let file = RefCell::new(file);
let client = match self.options.tls {
#[cfg(feature = "reqwest-rustls-tls")]
Tls::Rustls => rustls_client(self.options.timeout)?,
Expand All @@ -285,34 +239,18 @@ impl<'a> Download<'a> {
};

// TODO: the sync callback will stall the async runtime if IO calls block, which is OS dependent. Rearrange.
self.execute(
resume_from,
&|event| {
if let Event::DownloadDataReceived(data) = event {
file.borrow_mut()
.write_all(data)
.context("unable to write download to disk")?;
}
match callback {
Some(cb) => cb(event),
None => Ok(()),
}
},
client,
)
.await?;
self.execute(&mut file, resume_from, client).await?;

file.borrow_mut()
.sync_data()
file.sync_data()
.context("unable to sync download to disk")?;

Ok::<(), anyhow::Error>(())
}

async fn execute(
&self,
&mut self,
file: &mut fs::File,
resume_from: u64,
callback: &dyn Fn(Event<'_>) -> anyhow::Result<()>,
client: &Client,
) -> anyhow::Result<()> {
// Short-circuit reqwest for the "file:" URL scheme
Expand All @@ -339,7 +277,10 @@ impl<'a> Download<'a> {
if bytes_read == 0 {
break;
}
callback(Event::DownloadDataReceived(&buffer[0..bytes_read]))?;

file.write_all(&buffer[0..bytes_read])
.context("unable to write download to disk")?;
self.data_received(&buffer[0..bytes_read]);
}

return Ok(());
Expand All @@ -366,16 +307,29 @@ impl<'a> Download<'a> {

if let Some(len) = res.content_length() {
let len = len + resume_from;
callback(Event::DownloadContentLengthReceived(len))?;
if let Some(status) = self.status {
status.received_length(len);
}
}

let mut stream = res.bytes_stream();
while let Some(item) = stream.next().await {
let bytes = item.map_err(DownloadError::Reqwest)?;
callback(Event::DownloadDataReceived(&bytes))?;
file.write_all(&bytes)
.context("unable to write download to disk")?;
self.data_received(&bytes);
}
Ok(())
}

fn data_received(&mut self, data: &[u8]) {
if let Some(hasher) = &mut self.hasher {
hasher.update(data);
}
if let Some(status) = self.status {
status.received_data(data.len());
}
}
}

pub(crate) fn is_network_failure(err: &anyhow::Error) -> bool {
Expand Down Expand Up @@ -407,17 +361,6 @@ enum Tls {
NativeTls,
}

#[derive(Debug, Copy, Clone)]
enum Event<'a> {
ResumingPartialDownload,
/// Received the Content-Length of the to-be downloaded data.
DownloadContentLengthReceived(u64),
/// Received some data.
DownloadDataReceived(&'a [u8]),
}

type DownloadCallback<'a> = &'a dyn Fn(Event<'_>) -> anyhow::Result<()>;

fn client_generic() -> ClientBuilder {
Client::builder()
// HACK: set `pool_max_idle_per_host` to `0` to avoid an issue in the underlying
Expand Down
Loading
Loading