From 5dc2eab969dc8d1ecc39399843d7812a620c3603 Mon Sep 17 00:00:00 2001 From: Dirkjan Ochtman Date: Sun, 3 May 2026 21:10:59 +0200 Subject: [PATCH 01/11] download: inline reqwest_be module contents --- src/download/mod.rs | 336 ++++++++++++++++++++------------------------ 1 file changed, 155 insertions(+), 181 deletions(-) diff --git a/src/download/mod.rs b/src/download/mod.rs index d930d0f998..a632006dd4 100644 --- a/src/download/mod.rs +++ b/src/download/mod.rs @@ -1,23 +1,30 @@ //! Easy file downloading use std::fs::remove_file; +use std::io; use std::num::NonZero; use std::path::Path; use std::str::FromStr; +#[cfg(feature = "reqwest-rustls-tls")] +use std::sync::Arc; +#[cfg(any(feature = "reqwest-rustls-tls", feature = "reqwest-native-tls"))] +use std::sync::OnceLock; use std::time::Duration; -use anyhow::Context; -#[cfg(any( - not(feature = "reqwest-rustls-tls"), - not(feature = "reqwest-native-tls") -))] -use anyhow::anyhow; +use anyhow::{Context, anyhow}; +use reqwest::{Client, ClientBuilder, Proxy, Response, header}; +#[cfg(feature = "reqwest-rustls-tls")] +use rustls::crypto::aws_lc_rs; +#[cfg(feature = "reqwest-rustls-tls")] +use rustls_platform_verifier::Verifier; use sha2::Sha256; use thiserror::Error; -use tracing::debug; -use tracing::warn; +use tokio_stream::StreamExt; +use tracing::{debug, warn}; use url::Url; +#[cfg(all(feature = "reqwest-rustls-tls", not(target_os = "android")))] +use crate::anchors::RUSTUP_TRUST_ANCHORS; use crate::{dist::download::DownloadStatus, errors::RustupError, process::Process}; #[cfg(test)] @@ -44,7 +51,7 @@ pub(crate) async fn download_file_with_resume( match download_file_(url, path, hasher, resume_from_partial, status, process).await { Ok(_) => Ok(()), Err(e) => { - if e.downcast_ref::().is_some() { + if e.downcast_ref::().is_some() { return Err(e); } let is_client_error = match e.downcast_ref::() { @@ -324,13 +331,6 @@ impl Backend { Ok::<(), anyhow::Error>(()) } - #[cfg_attr( - all( - not(feature = "reqwest-rustls-tls"), - not(feature = "reqwest-native-tls") - ), - allow(unused_variables) - )] async fn download( self, url: &Url, @@ -340,12 +340,12 @@ impl Backend { ) -> anyhow::Result<()> { let client = match self { #[cfg(feature = "reqwest-rustls-tls")] - Self::Rustls => reqwest_be::rustls_client(timeout)?, + Self::Rustls => rustls_client(timeout)?, #[cfg(feature = "reqwest-native-tls")] - Self::NativeTls => reqwest_be::native_tls_client(timeout)?, + Self::NativeTls => native_tls_client(timeout)?, }; - reqwest_be::download(url, resume_from, callback, client).await + download(url, resume_from, callback, client).await } } @@ -360,191 +360,165 @@ enum Event<'a> { type DownloadCallback<'a> = &'a dyn Fn(Event<'_>) -> anyhow::Result<()>; -#[cfg(any(feature = "reqwest-rustls-tls", feature = "reqwest-native-tls"))] -mod reqwest_be { - #[cfg(feature = "reqwest-rustls-tls")] - use std::sync::Arc; - #[cfg(any(feature = "reqwest-rustls-tls", feature = "reqwest-native-tls"))] - use std::sync::OnceLock; - use std::{io, time::Duration}; - - #[cfg(all(feature = "reqwest-rustls-tls", not(target_os = "android")))] - use crate::anchors::RUSTUP_TRUST_ANCHORS; - use anyhow::{Context, anyhow}; - use reqwest::{Client, ClientBuilder, Proxy, Response, header}; - #[cfg(feature = "reqwest-rustls-tls")] - use rustls::crypto::aws_lc_rs; - #[cfg(feature = "reqwest-rustls-tls")] - use rustls_platform_verifier::Verifier; - use tokio_stream::StreamExt; - use url::Url; +async fn download( + url: &Url, + resume_from: u64, + callback: &dyn Fn(Event<'_>) -> anyhow::Result<()>, + client: &Client, +) -> anyhow::Result<()> { + // Short-circuit reqwest for the "file:" URL scheme + if download_from_file_url(url, resume_from, callback)? { + return Ok(()); + } - use super::{DownloadError, Event}; + let res = request(url, resume_from, client) + .await + .context("error downloading file")?; + + // If a download is being resumed, we expect a 206 response; + // otherwise, if the server ignored the range header, + // an error is thrown preemptively to avoid corruption. + let status = res.status().into(); + match (resume_from > 0, status) { + (true, 206) | (false, 200..=299) => {} + _ => return Err(DownloadError::HttpStatus(u32::from(status)).into()), + } - pub(super) async fn download( - url: &Url, - resume_from: u64, - callback: &dyn Fn(Event<'_>) -> anyhow::Result<()>, - client: &Client, - ) -> anyhow::Result<()> { - // Short-circuit reqwest for the "file:" URL scheme - if download_from_file_url(url, resume_from, callback)? { - return Ok(()); - } + if let Some(len) = res.content_length() { + let len = len + resume_from; + callback(Event::DownloadContentLengthReceived(len))?; + } - let res = request(url, resume_from, client) - .await - .context("error downloading file")?; - - // If a download is being resumed, we expect a 206 response; - // otherwise, if the server ignored the range header, - // an error is thrown preemptively to avoid corruption. - let status = res.status().into(); - match (resume_from > 0, status) { - (true, 206) | (false, 200..=299) => {} - _ => return Err(DownloadError::HttpStatus(u32::from(status)).into()), - } + let mut stream = res.bytes_stream(); + while let Some(item) = stream.next().await { + let bytes = item.map_err(DownloadError::Reqwest)?; + callback(Event::DownloadDataReceived(&bytes))?; + } + Ok(()) +} - if let Some(len) = res.content_length() { - let len = len + resume_from; - callback(Event::DownloadContentLengthReceived(len))?; - } +fn client_generic() -> ClientBuilder { + Client::builder() + // HACK: set `pool_max_idle_per_host` to `0` to avoid an issue in the underlying + // `hyper` library that causes the `reqwest` client to hang in some cases. + // See for more details. + .pool_max_idle_per_host(0) + .gzip(false) + .proxy(Proxy::custom(env_proxy)) +} - let mut stream = res.bytes_stream(); - while let Some(item) = stream.next().await { - let bytes = item.map_err(DownloadError::Reqwest)?; - callback(Event::DownloadDataReceived(&bytes))?; - } - Ok(()) +#[cfg(feature = "reqwest-rustls-tls")] +fn rustls_client(timeout: Duration) -> Result<&'static Client, DownloadError> { + // If the client is already initialized, the passed timeout is ignored. + if let Some(client) = CLIENT_RUSTLS_TLS.get() { + return Ok(client); } - fn client_generic() -> ClientBuilder { - Client::builder() - // HACK: set `pool_max_idle_per_host` to `0` to avoid an issue in the underlying - // `hyper` library that causes the `reqwest` client to hang in some cases. - // See for more details. - .pool_max_idle_per_host(0) - .gzip(false) - .proxy(Proxy::custom(env_proxy)) - } + let provider = Arc::new(aws_lc_rs::default_provider()); + #[cfg(not(target_os = "android"))] + let result = + Verifier::new_with_extra_roots(RUSTUP_TRUST_ANCHORS.iter().cloned(), provider.clone()); + #[cfg(target_os = "android")] + let result = Verifier::new(provider.clone()); + let verifier = result.map_err(|err| { + DownloadError::Message(format!("failed to initialize platform verifier: {err}")) + })?; + + let mut tls_config = rustls::ClientConfig::builder_with_provider(provider) + .with_safe_default_protocol_versions() + .unwrap() + .dangerous() // We're using a rustls verifier, so it's okay + .with_custom_certificate_verifier(Arc::new(verifier)) + .with_no_client_auth(); + tls_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; + + let client = client_generic() + .read_timeout(timeout) + .use_preconfigured_tls(tls_config) + .user_agent(REQWEST_RUSTLS_TLS_USER_AGENT) + .build() + .map_err(DownloadError::Reqwest)?; + + let _ = CLIENT_RUSTLS_TLS.set(client); + // "The cell is guaranteed to contain a value when `set` returns, though not necessarily + // the one provided." + Ok(CLIENT_RUSTLS_TLS.get().unwrap()) +} - #[cfg(feature = "reqwest-rustls-tls")] - pub(super) fn rustls_client(timeout: Duration) -> Result<&'static Client, DownloadError> { - // If the client is already initialized, the passed timeout is ignored. - if let Some(client) = CLIENT_RUSTLS_TLS.get() { - return Ok(client); - } +#[cfg(feature = "reqwest-rustls-tls")] +static CLIENT_RUSTLS_TLS: OnceLock = OnceLock::new(); - let provider = Arc::new(aws_lc_rs::default_provider()); - #[cfg(not(target_os = "android"))] - let result = - Verifier::new_with_extra_roots(RUSTUP_TRUST_ANCHORS.iter().cloned(), provider.clone()); - #[cfg(target_os = "android")] - let result = Verifier::new(provider.clone()); - let verifier = result.map_err(|err| { - DownloadError::Message(format!("failed to initialize platform verifier: {err}")) - })?; - - let mut tls_config = rustls::ClientConfig::builder_with_provider(provider) - .with_safe_default_protocol_versions() - .unwrap() - .dangerous() // We're using a rustls verifier, so it's okay - .with_custom_certificate_verifier(Arc::new(verifier)) - .with_no_client_auth(); - tls_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; - - let client = client_generic() - .read_timeout(timeout) - .use_preconfigured_tls(tls_config) - .user_agent(super::REQWEST_RUSTLS_TLS_USER_AGENT) - .build() - .map_err(DownloadError::Reqwest)?; - - let _ = CLIENT_RUSTLS_TLS.set(client); - // "The cell is guaranteed to contain a value when `set` returns, though not necessarily - // the one provided." - Ok(CLIENT_RUSTLS_TLS.get().unwrap()) +#[cfg(feature = "reqwest-native-tls")] +fn native_tls_client(timeout: Duration) -> Result<&'static Client, DownloadError> { + // If the client is already initialized, the passed timeout is ignored. + if let Some(client) = CLIENT_NATIVE_TLS.get() { + return Ok(client); } - #[cfg(feature = "reqwest-rustls-tls")] - static CLIENT_RUSTLS_TLS: OnceLock = OnceLock::new(); + let client = client_generic() + .read_timeout(timeout) + .user_agent(REQWEST_DEFAULT_TLS_USER_AGENT) + .build() + .map_err(DownloadError::Reqwest)?; - #[cfg(feature = "reqwest-native-tls")] - pub(super) fn native_tls_client(timeout: Duration) -> Result<&'static Client, DownloadError> { - // If the client is already initialized, the passed timeout is ignored. - if let Some(client) = CLIENT_NATIVE_TLS.get() { - return Ok(client); - } + let _ = CLIENT_NATIVE_TLS.set(client); + // "The cell is guaranteed to contain a value when `set` returns, though not necessarily + // the one provided." + Ok(CLIENT_NATIVE_TLS.get().unwrap()) +} - let client = client_generic() - .read_timeout(timeout) - .user_agent(super::REQWEST_DEFAULT_TLS_USER_AGENT) - .build() - .map_err(DownloadError::Reqwest)?; +#[cfg(feature = "reqwest-native-tls")] +static CLIENT_NATIVE_TLS: OnceLock = OnceLock::new(); - let _ = CLIENT_NATIVE_TLS.set(client); - // "The cell is guaranteed to contain a value when `set` returns, though not necessarily - // the one provided." - Ok(CLIENT_NATIVE_TLS.get().unwrap()) - } +fn env_proxy(url: &Url) -> Option { + env_proxy::for_url(url).to_url() +} - #[cfg(feature = "reqwest-native-tls")] - static CLIENT_NATIVE_TLS: OnceLock = OnceLock::new(); +async fn request(url: &Url, resume_from: u64, client: &Client) -> Result { + let mut req = client.get(url.as_str()); - fn env_proxy(url: &Url) -> Option { - env_proxy::for_url(url).to_url() + if resume_from != 0 { + req = req.header(header::RANGE, format!("bytes={resume_from}-")); } - async fn request( - url: &Url, - resume_from: u64, - client: &Client, - ) -> Result { - let mut req = client.get(url.as_str()); + Ok(req.send().await?) +} - if resume_from != 0 { - req = req.header(header::RANGE, format!("bytes={resume_from}-")); +fn download_from_file_url( + url: &Url, + resume_from: u64, + callback: &dyn Fn(Event<'_>) -> anyhow::Result<()>, +) -> anyhow::Result { + use std::fs; + + // The file scheme is mostly for use by tests to mock the dist server + if url.scheme() == "file" { + let src = url + .to_file_path() + .map_err(|_| DownloadError::Message(format!("bogus file url: '{url}'")))?; + if !src.is_file() { + // Because some of rustup's logic depends on checking + // the error when a downloaded file doesn't exist, make + // the file case return the same error value as the + // network case. + return Err(anyhow!(DownloadError::FileNotFound)); } - Ok(req.send().await?) - } - - fn download_from_file_url( - url: &Url, - resume_from: u64, - callback: &dyn Fn(Event<'_>) -> anyhow::Result<()>, - ) -> anyhow::Result { - use std::fs; - - // The file scheme is mostly for use by tests to mock the dist server - if url.scheme() == "file" { - let src = url - .to_file_path() - .map_err(|_| DownloadError::Message(format!("bogus file url: '{url}'")))?; - if !src.is_file() { - // Because some of rustup's logic depends on checking - // the error when a downloaded file doesn't exist, make - // the file case return the same error value as the - // network case. - return Err(anyhow!(DownloadError::FileNotFound)); - } - - let mut f = fs::File::open(src).context("unable to open downloaded file")?; - io::Seek::seek(&mut f, io::SeekFrom::Start(resume_from))?; + let mut f = fs::File::open(src).context("unable to open downloaded file")?; + io::Seek::seek(&mut f, io::SeekFrom::Start(resume_from))?; - let mut buffer = vec![0u8; 0x10000]; - loop { - let bytes_read = io::Read::read(&mut f, &mut buffer)?; - if bytes_read == 0 { - break; - } - callback(Event::DownloadDataReceived(&buffer[0..bytes_read]))?; + let mut buffer = vec![0u8; 0x10000]; + loop { + let bytes_read = io::Read::read(&mut f, &mut buffer)?; + if bytes_read == 0 { + break; } - - Ok(true) - } else { - Ok(false) + callback(Event::DownloadDataReceived(&buffer[0..bytes_read]))?; } + + Ok(true) + } else { + Ok(false) } } @@ -557,7 +531,7 @@ enum DownloadError { #[error("{0}")] Message(String), #[error(transparent)] - IoError(#[from] std::io::Error), + IoError(#[from] io::Error), #[cfg(any(feature = "reqwest-rustls-tls", feature = "reqwest-native-tls"))] #[error(transparent)] Reqwest(#[from] ::reqwest::Error), From 1e1913e780befc4d303e2f68b722f7fe5b6c90ee Mon Sep 17 00:00:00 2001 From: Dirkjan Ochtman Date: Sun, 3 May 2026 21:13:57 +0200 Subject: [PATCH 02/11] download: move all imports to the top --- src/download/mod.rs | 22 ++++++---------------- 1 file changed, 6 insertions(+), 16 deletions(-) diff --git a/src/download/mod.rs b/src/download/mod.rs index a632006dd4..63017abee8 100644 --- a/src/download/mod.rs +++ b/src/download/mod.rs @@ -1,7 +1,8 @@ //! Easy file downloading -use std::fs::remove_file; -use std::io; +use std::cell::RefCell; +use std::fs::{self, OpenOptions, remove_file}; +use std::io::{self, Read, Seek, SeekFrom, Write}; use std::num::NonZero; use std::path::Path; use std::str::FromStr; @@ -17,7 +18,7 @@ use reqwest::{Client, ClientBuilder, Proxy, Response, header}; use rustls::crypto::aws_lc_rs; #[cfg(feature = "reqwest-rustls-tls")] use rustls_platform_verifier::Verifier; -use sha2::Sha256; +use sha2::{Digest, Sha256}; use thiserror::Error; use tokio_stream::StreamExt; use tracing::{debug, warn}; @@ -96,11 +97,6 @@ async fn download_file_( status: Option<&DownloadStatus>, process: &Process, ) -> anyhow::Result<()> { - #[cfg(any(feature = "reqwest-rustls-tls", feature = "reqwest-native-tls"))] - use crate::download::{Backend, Event}; - use sha2::Digest; - use std::cell::RefCell; - debug!(url = %url, "downloading file"); let hasher = RefCell::new(hasher); @@ -253,10 +249,6 @@ impl Backend { callback: Option>, timeout: Duration, ) -> anyhow::Result<()> { - use std::cell::RefCell; - use std::fs::OpenOptions; - use std::io::{Read, Seek, SeekFrom, Write}; - let (file, resume_from) = if resume_from_partial { // TODO: blocking call let possible_partial = OpenOptions::new().read(true).open(path); @@ -489,8 +481,6 @@ fn download_from_file_url( resume_from: u64, callback: &dyn Fn(Event<'_>) -> anyhow::Result<()>, ) -> anyhow::Result { - use std::fs; - // The file scheme is mostly for use by tests to mock the dist server if url.scheme() == "file" { let src = url @@ -505,11 +495,11 @@ fn download_from_file_url( } let mut f = fs::File::open(src).context("unable to open downloaded file")?; - io::Seek::seek(&mut f, io::SeekFrom::Start(resume_from))?; + Seek::seek(&mut f, SeekFrom::Start(resume_from))?; let mut buffer = vec![0u8; 0x10000]; loop { - let bytes_read = io::Read::read(&mut f, &mut buffer)?; + let bytes_read = Read::read(&mut f, &mut buffer)?; if bytes_read == 0 { break; } From 5ab9e2c60d89ad5bb8b9ea567bb65fa7de847498 Mon Sep 17 00:00:00 2001 From: Dirkjan Ochtman Date: Sun, 3 May 2026 21:15:22 +0200 Subject: [PATCH 03/11] download: inline single-caller download_from_file_url() helper --- src/download/mod.rs | 62 ++++++++++++++++++--------------------------- 1 file changed, 25 insertions(+), 37 deletions(-) diff --git a/src/download/mod.rs b/src/download/mod.rs index 63017abee8..5b2d181a80 100644 --- a/src/download/mod.rs +++ b/src/download/mod.rs @@ -359,7 +359,31 @@ async fn download( client: &Client, ) -> anyhow::Result<()> { // Short-circuit reqwest for the "file:" URL scheme - if download_from_file_url(url, resume_from, callback)? { + // The file scheme is mostly for use by tests to mock the dist server + if url.scheme() == "file" { + let src = url + .to_file_path() + .map_err(|_| DownloadError::Message(format!("bogus file url: '{url}'")))?; + if !src.is_file() { + // Because some of rustup's logic depends on checking + // the error when a downloaded file doesn't exist, make + // the file case return the same error value as the + // network case. + return Err(anyhow!(DownloadError::FileNotFound)); + } + + let mut f = fs::File::open(src).context("unable to open downloaded file")?; + Seek::seek(&mut f, SeekFrom::Start(resume_from))?; + + let mut buffer = vec![0u8; 0x10000]; + loop { + let bytes_read = Read::read(&mut f, &mut buffer)?; + if bytes_read == 0 { + break; + } + callback(Event::DownloadDataReceived(&buffer[0..bytes_read]))?; + } + return Ok(()); } @@ -476,42 +500,6 @@ async fn request(url: &Url, resume_from: u64, client: &Client) -> Result) -> anyhow::Result<()>, -) -> anyhow::Result { - // The file scheme is mostly for use by tests to mock the dist server - if url.scheme() == "file" { - let src = url - .to_file_path() - .map_err(|_| DownloadError::Message(format!("bogus file url: '{url}'")))?; - if !src.is_file() { - // Because some of rustup's logic depends on checking - // the error when a downloaded file doesn't exist, make - // the file case return the same error value as the - // network case. - return Err(anyhow!(DownloadError::FileNotFound)); - } - - let mut f = fs::File::open(src).context("unable to open downloaded file")?; - Seek::seek(&mut f, SeekFrom::Start(resume_from))?; - - let mut buffer = vec![0u8; 0x10000]; - loop { - let bytes_read = Read::read(&mut f, &mut buffer)?; - if bytes_read == 0 { - break; - } - callback(Event::DownloadDataReceived(&buffer[0..bytes_read]))?; - } - - Ok(true) - } else { - Ok(false) - } -} - #[derive(Debug, Error)] enum DownloadError { #[error("http request returned an unsuccessful status code: {0}")] From df74cf997a6f3d761194168531fa8fb496bac2d1 Mon Sep 17 00:00:00 2001 From: Dirkjan Ochtman Date: Sun, 3 May 2026 21:17:04 +0200 Subject: [PATCH 04/11] download: inline single-caller request() helper --- src/download/mod.rs | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/src/download/mod.rs b/src/download/mod.rs index 5b2d181a80..0185dc6dd2 100644 --- a/src/download/mod.rs +++ b/src/download/mod.rs @@ -13,7 +13,7 @@ use std::sync::OnceLock; use std::time::Duration; use anyhow::{Context, anyhow}; -use reqwest::{Client, ClientBuilder, Proxy, Response, header}; +use reqwest::{Client, ClientBuilder, Proxy, header}; #[cfg(feature = "reqwest-rustls-tls")] use rustls::crypto::aws_lc_rs; #[cfg(feature = "reqwest-rustls-tls")] @@ -387,8 +387,14 @@ async fn download( return Ok(()); } - let res = request(url, resume_from, client) + let mut req = client.get(url.as_str()); + if resume_from != 0 { + req = req.header(header::RANGE, format!("bytes={resume_from}-")); + } + let res = req + .send() .await + .map_err(DownloadError::Reqwest) .context("error downloading file")?; // If a download is being resumed, we expect a 206 response; @@ -490,16 +496,6 @@ fn env_proxy(url: &Url) -> Option { env_proxy::for_url(url).to_url() } -async fn request(url: &Url, resume_from: u64, client: &Client) -> Result { - let mut req = client.get(url.as_str()); - - if resume_from != 0 { - req = req.header(header::RANGE, format!("bytes={resume_from}-")); - } - - Ok(req.send().await?) -} - #[derive(Debug, Error)] enum DownloadError { #[error("http request returned an unsuccessful status code: {0}")] From 01facab799f33d7de51bd961909290d93523ab74 Mon Sep 17 00:00:00 2001 From: Dirkjan Ochtman Date: Sun, 3 May 2026 21:18:20 +0200 Subject: [PATCH 05/11] download: inline single-caller env_proxy() helper --- src/download/mod.rs | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/download/mod.rs b/src/download/mod.rs index 0185dc6dd2..b4b5bf1e8a 100644 --- a/src/download/mod.rs +++ b/src/download/mod.rs @@ -426,7 +426,7 @@ fn client_generic() -> ClientBuilder { // See for more details. .pool_max_idle_per_host(0) .gzip(false) - .proxy(Proxy::custom(env_proxy)) + .proxy(Proxy::custom(|url| env_proxy::for_url(url).to_url())) } #[cfg(feature = "reqwest-rustls-tls")] @@ -492,10 +492,6 @@ fn native_tls_client(timeout: Duration) -> Result<&'static Client, DownloadError #[cfg(feature = "reqwest-native-tls")] static CLIENT_NATIVE_TLS: OnceLock = OnceLock::new(); -fn env_proxy(url: &Url) -> Option { - env_proxy::for_url(url).to_url() -} - #[derive(Debug, Error)] enum DownloadError { #[error("http request returned an unsuccessful status code: {0}")] From 9f6e91d037f92a1b46895db57f56d9e29bb2116c Mon Sep 17 00:00:00 2001 From: Dirkjan Ochtman Date: Sun, 3 May 2026 21:56:17 +0200 Subject: [PATCH 06/11] download: use builder pattern to bundle up download input --- src/cli/self_update.rs | 11 +- src/cli/self_update/windows.rs | 13 +- src/dist/download.rs | 32 ++-- src/dist/manifestation/tests.rs | 6 +- src/download/mod.rs | 283 +++++++++++++++++--------------- 5 files changed, 180 insertions(+), 165 deletions(-) diff --git a/src/cli/self_update.rs b/src/cli/self_update.rs index d23d82bd62..91644fd8a6 100644 --- a/src/cli/self_update.rs +++ b/src/cli/self_update.rs @@ -63,7 +63,7 @@ use crate::{ DistOptions, PartialToolchainDesc, Profile, TargetTuple, ToolchainDesc, download::DownloadCfg, }, - download::download_file, + download::Download, errors::RustupError, install::{InstallMethod, UpdateStatus}, process::Process, @@ -1340,7 +1340,9 @@ pub(crate) async fn prepare_update(dl_cfg: &DownloadCfg<'_>) -> Result) -> Result(&release_toml_str) .context("unable to parse rustup release file")?; diff --git a/src/cli/self_update/windows.rs b/src/cli/self_update/windows.rs index 5f3eb5e0bd..9076b85e22 100644 --- a/src/cli/self_update/windows.rs +++ b/src/cli/self_update/windows.rs @@ -23,7 +23,7 @@ use crate::cli::markdown::md; use crate::config::Cfg; use crate::dist::TargetTuple; use crate::dist::download::DownloadCfg; -use crate::download::download_file; +use crate::download::Download; use crate::process::{ColorableTerminal, Process}; use crate::utils; @@ -273,14 +273,9 @@ pub(crate) async fn try_install_msvc( let visual_studio = tempdir.path().join("vs_setup.exe"); let dl_cfg = DownloadCfg::new(cfg); info!("downloading Visual Studio installer"); - download_file( - &visual_studio_url, - &visual_studio, - None, - None, - dl_cfg.process, - ) - .await?; + Download::new(&visual_studio_url, &visual_studio, dl_cfg.process) + .download() + .await?; // Run the installer. Arguments are documented at: // https://docs.microsoft.com/en-us/visualstudio/install/use-command-line-parameters-to-install-visual-studio diff --git a/src/dist/download.rs b/src/dist/download.rs index 39283f7f38..d7910031e3 100644 --- a/src/dist/download.rs +++ b/src/dist/download.rs @@ -15,7 +15,7 @@ use url::Url; use crate::config::Cfg; use crate::dist::manifest::{Manifest, ManifestWithHash}; use crate::dist::{Channel, DEFAULT_DIST_SERVER, ToolchainDesc, temp}; -use crate::download::{download_file, download_file_with_resume, is_network_failure}; +use crate::download::{Download, is_network_failure}; use crate::errors::RustupError; use crate::process::Process; use crate::utils; @@ -82,17 +82,12 @@ impl<'a> DownloadCfg<'a> { let partial_file_existed = partial_file_path.exists(); let mut hasher = Sha256::new(); + let download = Download::new(url, &partial_file_path, self.process) + .with_hasher(&mut hasher) + .with_status(status) + .with_resume(); - if let Err(e) = download_file_with_resume( - url, - &partial_file_path, - Some(&mut hasher), - true, - Some(status), - self.process, - ) - .await - { + if let Err(e) = download.download().await { let is_network_failure = is_network_failure(&e); let err = Err(e); return match (partial_file_existed, is_network_failure) { @@ -142,9 +137,9 @@ impl<'a> DownloadCfg<'a> { async fn download_hash(&self, url: &str) -> Result { let hash_url = utils::parse_url(&(url.to_owned() + ".sha256"))?; let hash_file = self.tmp_cx.new_file()?; - - download_file(&hash_url, &hash_file, None, None, self.process).await?; - + Download::new(&hash_url, &hash_file, self.process) + .download() + .await?; utils::read_file("hash", &hash_file).map(|s| s[0..64].to_owned()) } @@ -267,7 +262,14 @@ impl<'a> DownloadCfg<'a> { let file = self.tmp_cx.new_file_with_ext("", ext)?; let mut hasher = Sha256::new(); - download_file(&url, &file, Some(&mut hasher), status, self.process).await?; + let download = Download::new(&url, &file, self.process).with_hasher(&mut hasher); + + let download = match status { + Some(status) => download.with_status(status), + None => download, + }; + + download.download().await?; let actual_hash = faster_hex::hex_string(&hasher.finalize()); if hash != actual_hash { diff --git a/src/dist/manifestation/tests.rs b/src/dist/manifestation/tests.rs index 00691c29c1..de01b49d8e 100644 --- a/src/dist/manifestation/tests.rs +++ b/src/dist/manifestation/tests.rs @@ -22,7 +22,7 @@ use crate::{ prefix::InstallPrefix, temp, }, - download::download_file, + download::Download, errors::RustupError, process::TestProcess, test::{ @@ -490,7 +490,9 @@ impl TestContext { // Download the dist manifest and place it into the installation prefix let manifest_url = make_manifest_url(&self.url, &self.toolchain)?; let manifest_file = self.tmp_cx.new_file()?; - download_file(&manifest_url, &manifest_file, None, None, dl_cfg.process).await?; + Download::new(&manifest_url, &manifest_file, dl_cfg.process) + .download() + .await?; let manifest_str = utils::read_file("manifest", &manifest_file)?; let manifest = Manifest::parse(&manifest_str)?; diff --git a/src/download/mod.rs b/src/download/mod.rs index b4b5bf1e8a..b23ef39bdc 100644 --- a/src/download/mod.rs +++ b/src/download/mod.rs @@ -31,163 +31,174 @@ use crate::{dist::download::DownloadStatus, errors::RustupError, process::Proces #[cfg(test)] mod tests; -pub(crate) async fn download_file( - url: &Url, - path: &Path, - hasher: Option<&mut Sha256>, - status: Option<&DownloadStatus>, - process: &Process, -) -> anyhow::Result<()> { - download_file_with_resume(url, path, hasher, false, status, process).await +pub struct Download<'a> { + url: &'a Url, + path: &'a Path, + hasher: Option>, + status: Option<&'a DownloadStatus>, + resume: bool, + process: &'a Process, } -pub(crate) async fn download_file_with_resume( - url: &Url, - path: &Path, - hasher: Option<&mut Sha256>, - resume_from_partial: bool, - status: Option<&DownloadStatus>, - process: &Process, -) -> anyhow::Result<()> { - match download_file_(url, path, hasher, resume_from_partial, status, process).await { - Ok(_) => Ok(()), - Err(e) => { - if e.downcast_ref::().is_some() { - return Err(e); - } - let is_client_error = match e.downcast_ref::() { - // Specifically treat the bad partial range error as not our - // fault in case it was something odd which happened. - Some(DownloadError::HttpStatus(416)) => false, - Some(DownloadError::HttpStatus(400..=499)) | Some(DownloadError::FileNotFound) => { - true - } - _ => false, - }; - Err(e).with_context(|| { - if is_client_error { - RustupError::DownloadNotExists { - url: url.clone(), - path: path.to_path_buf(), - } - } else { - RustupError::DownloadingFile { - url: url.clone(), - path: path.to_path_buf(), - } - } - }) +impl<'a> Download<'a> { + pub(crate) fn new(url: &'a Url, path: &'a Path, process: &'a Process) -> Self { + Self { + url, + path, + hasher: None, + status: None, + resume: false, + process, } } -} -pub(crate) fn is_network_failure(err: &anyhow::Error) -> bool { - match err.downcast_ref::() { - #[cfg(any(feature = "reqwest-rustls-tls", feature = "reqwest-native-tls"))] - Some(DownloadError::Reqwest(e)) => e.is_timeout() || e.is_connect(), - _ => false, + pub(crate) fn with_hasher(mut self, hasher: &'a mut Sha256) -> Self { + self.hasher = Some(RefCell::new(hasher)); + self } -} -async fn download_file_( - url: &Url, - path: &Path, - hasher: Option<&mut Sha256>, - resume_from_partial: bool, - status: Option<&DownloadStatus>, - process: &Process, -) -> anyhow::Result<()> { - debug!(url = %url, "downloading file"); - let hasher = RefCell::new(hasher); - - // This callback will write the download to disk and optionally - // hash the contents, then forward the notification up the stack - let callback: &dyn Fn(Event<'_>) -> anyhow::Result<()> = &|msg| { - if let Event::DownloadDataReceived(data) = msg - && let Some(h) = hasher.borrow_mut().as_mut() - { - h.update(data); - } + pub(crate) fn with_status(mut self, status: &'a DownloadStatus) -> Self { + self.status = Some(status); + self + } - match msg { - Event::DownloadContentLengthReceived(len) => { - if let Some(status) = status { - status.received_length(len) + pub(crate) fn with_resume(mut self) -> Self { + self.resume = true; + self + } + + pub(crate) async fn download(&self) -> anyhow::Result<()> { + match self.download_file_().await { + Ok(_) => Ok(()), + Err(e) => { + if e.downcast_ref::().is_some() { + return Err(e); } + let is_client_error = match e.downcast_ref::() { + // Specifically treat the bad partial range error as not our + // fault in case it was something odd which happened. + Some(DownloadError::HttpStatus(416)) => false, + Some(DownloadError::HttpStatus(400..=499)) + | Some(DownloadError::FileNotFound) => true, + _ => false, + }; + Err(e).with_context(|| { + if is_client_error { + RustupError::DownloadNotExists { + url: self.url.clone(), + path: self.path.to_path_buf(), + } + } else { + RustupError::DownloadingFile { + url: self.url.clone(), + path: self.path.to_path_buf(), + } + } + }) + } + } + } + + async fn download_file_(&self) -> anyhow::Result<()> { + debug!(url = %self.url, "downloading file"); + + // This callback will write the download to disk and optionally + // hash the contents, then forward the notification up the stack + let callback: &dyn Fn(Event<'_>) -> anyhow::Result<()> = &|msg| { + if let Event::DownloadDataReceived(data) = msg + && let Some(h) = self.hasher.as_ref() + { + h.borrow_mut().update(data); } - Event::DownloadDataReceived(data) => { - if let Some(status) = status { - status.received_data(data.len()) + + match msg { + Event::DownloadContentLengthReceived(len) => { + if let Some(status) = self.status { + status.received_length(len) + } + } + Event::DownloadDataReceived(data) => { + if let Some(status) = self.status { + status.received_data(data.len()) + } } + Event::ResumingPartialDownload => debug!("resuming partial download"), } - Event::ResumingPartialDownload => debug!("resuming partial download"), - } - Ok(()) - }; + Ok(()) + }; - // Download the file + // Download the file - let use_rustls = process.var_os("RUSTUP_USE_RUSTLS").map(|it| it != "0"); - if use_rustls == Some(false) { - warn!( - "RUSTUP_USE_RUSTLS is set to `0`; the native-tls backend is deprecated, + let use_rustls = self.process.var_os("RUSTUP_USE_RUSTLS").map(|it| it != "0"); + if use_rustls == Some(false) { + warn!( + "RUSTUP_USE_RUSTLS is set to `0`; the native-tls backend is deprecated, please file an issue if the default download backend does not work for your use case" - ); - } - - let backend = match use_rustls { - // If the environment explicitly selects a TLS backend that's unavailable, error out. - #[cfg(not(feature = "reqwest-rustls-tls"))] - Some(true) => { - return Err(anyhow!( - "RUSTUP_USE_RUSTLS is set, but this rustup distribution was not built with the reqwest-rustls-tls feature" - )); - } - #[cfg(not(feature = "reqwest-native-tls"))] - Some(false) => { - return Err(anyhow!( - "RUSTUP_USE_RUSTLS is set to false, but this rustup distribution was not built with the reqwest-native-tls feature" - )); + ); } - // Prefer explicit selections before falling back to the default TLS stack. - #[cfg(feature = "reqwest-native-tls")] - Some(false) => Backend::NativeTls, - - // The default fallback is `rustls`, which should be used whenever available. - #[cfg(feature = "reqwest-rustls-tls")] - _ => Backend::Rustls, - - // The `rustls` feature is disabled, fall back to `native-tls` instead. - #[cfg(all(not(feature = "reqwest-rustls-tls"), feature = "reqwest-native-tls"))] - _ => Backend::NativeTls, - }; - - let timeout = Duration::from_secs(match process.var("RUSTUP_DOWNLOAD_TIMEOUT") { - Ok(s) => NonZero::from_str(&s) - .context( - "invalid value in RUSTUP_DOWNLOAD_TIMEOUT -- must be a natural number greater than zero", - )? - .get(), - Err(_) => 180, - }); - - debug!("downloading with reqwest"); - - let res = backend - .download_to_path(url, path, resume_from_partial, Some(callback), timeout) - .await; - - // The notification should only be sent if the download was successful (i.e. didn't timeout) - if let Some(status) = status { - match &res { - Ok(_) => status.finished(), - Err(_) => status.failed(), + let backend = match use_rustls { + // If the environment explicitly selects a TLS backend that's unavailable, error out. + #[cfg(not(feature = "reqwest-rustls-tls"))] + Some(true) => { + return Err(anyhow!( + "RUSTUP_USE_RUSTLS is set, but this rustup distribution was not built with the reqwest-rustls-tls feature" + )); + } + #[cfg(not(feature = "reqwest-native-tls"))] + Some(false) => { + return Err(anyhow!( + "RUSTUP_USE_RUSTLS is set to false, but this rustup distribution was not built with the reqwest-native-tls feature" + )); + } + + // Prefer explicit selections before falling back to the default TLS stack. + #[cfg(feature = "reqwest-native-tls")] + Some(false) => Backend::NativeTls, + + // The default fallback is `rustls`, which should be used whenever available. + #[cfg(feature = "reqwest-rustls-tls")] + _ => Backend::Rustls, + + // The `rustls` feature is disabled, fall back to `native-tls` instead. + #[cfg(all(not(feature = "reqwest-rustls-tls"), feature = "reqwest-native-tls"))] + _ => Backend::NativeTls, }; + + let timeout = Duration::from_secs(match self.process.var("RUSTUP_DOWNLOAD_TIMEOUT") { + Ok(s) => NonZero::from_str(&s) + .context( + "invalid value in RUSTUP_DOWNLOAD_TIMEOUT -- must be a natural number greater than zero", + )? + .get(), + Err(_) => 180, + }); + + debug!("downloading with reqwest"); + + let res = backend + .download_to_path(self.url, self.path, self.resume, Some(callback), timeout) + .await; + + // The notification should only be sent if the download was successful (i.e. didn't timeout) + if let Some(status) = self.status { + match &res { + Ok(_) => status.finished(), + Err(_) => status.failed(), + }; + } + + res } +} - res +pub(crate) fn is_network_failure(err: &anyhow::Error) -> bool { + match err.downcast_ref::() { + #[cfg(any(feature = "reqwest-rustls-tls", feature = "reqwest-native-tls"))] + Some(DownloadError::Reqwest(e)) => e.is_timeout() || e.is_connect(), + _ => false, + } } /// User agent header value for HTTP request. From 569c432ad3801dd97c89feb59448db2a7b834da6 Mon Sep 17 00:00:00 2001 From: Dirkjan Ochtman Date: Sun, 3 May 2026 22:06:40 +0200 Subject: [PATCH 07/11] download: rename Backend to Tls --- src/download/mod.rs | 14 +++++++------- src/download/tests.rs | 6 +++--- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/download/mod.rs b/src/download/mod.rs index b23ef39bdc..c4fc3070a8 100644 --- a/src/download/mod.rs +++ b/src/download/mod.rs @@ -138,7 +138,7 @@ impl<'a> Download<'a> { ); } - let backend = match use_rustls { + let tls = match use_rustls { // If the environment explicitly selects a TLS backend that's unavailable, error out. #[cfg(not(feature = "reqwest-rustls-tls"))] Some(true) => { @@ -155,15 +155,15 @@ impl<'a> Download<'a> { // Prefer explicit selections before falling back to the default TLS stack. #[cfg(feature = "reqwest-native-tls")] - Some(false) => Backend::NativeTls, + Some(false) => Tls::NativeTls, // The default fallback is `rustls`, which should be used whenever available. #[cfg(feature = "reqwest-rustls-tls")] - _ => Backend::Rustls, + _ => Tls::Rustls, // The `rustls` feature is disabled, fall back to `native-tls` instead. #[cfg(all(not(feature = "reqwest-rustls-tls"), feature = "reqwest-native-tls"))] - _ => Backend::NativeTls, + _ => Tls::NativeTls, }; let timeout = Duration::from_secs(match self.process.var("RUSTUP_DOWNLOAD_TIMEOUT") { @@ -177,7 +177,7 @@ impl<'a> Download<'a> { debug!("downloading with reqwest"); - let res = backend + let res = tls .download_to_path(self.url, self.path, self.resume, Some(callback), timeout) .await; @@ -215,14 +215,14 @@ const REQWEST_RUSTLS_TLS_USER_AGENT: &str = concat!("rustup/", env!("CARGO_PKG_VERSION"), " (reqwest; rustls)"); #[derive(Debug, Copy, Clone)] -enum Backend { +enum Tls { #[cfg(feature = "reqwest-rustls-tls")] Rustls, #[cfg(feature = "reqwest-native-tls")] NativeTls, } -impl Backend { +impl Tls { async fn download_to_path( self, url: &Url, diff --git a/src/download/tests.rs b/src/download/tests.rs index 12d4760b9f..2c9c7aef45 100644 --- a/src/download/tests.rs +++ b/src/download/tests.rs @@ -30,13 +30,13 @@ mod reqwest { use url::Url; use super::{scrub_env, serve_file, tmp_dir, write_file}; - use crate::download::{Backend, Event}; + use crate::download::{Event, Tls}; #[cfg(feature = "reqwest-rustls-tls")] - const DOWNLOAD_BACKEND: Backend = Backend::Rustls; + const DOWNLOAD_BACKEND: Tls = Tls::Rustls; #[cfg(all(not(feature = "reqwest-rustls-tls"), feature = "reqwest-native-tls"))] - const DOWNLOAD_BACKEND: Backend = Backend::NativeTls; + const DOWNLOAD_BACKEND: Tls = Tls::NativeTls; // Tests for correctly retrieving the proxy (host, port) tuple from $https_proxy #[tokio::test] From 6ff87673ac8d00de2f97b9b9a4d7585e5832d56c Mon Sep 17 00:00:00 2001 From: Dirkjan Ochtman Date: Sun, 3 May 2026 22:54:00 +0200 Subject: [PATCH 08/11] self_update: avoid unnecessary DownloadCfg setup --- src/cli/self_update/windows.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/cli/self_update/windows.rs b/src/cli/self_update/windows.rs index 9076b85e22..aca1b44aaa 100644 --- a/src/cli/self_update/windows.rs +++ b/src/cli/self_update/windows.rs @@ -22,7 +22,6 @@ use super::{InstallOpts, install_bins, report_error}; use crate::cli::markdown::md; use crate::config::Cfg; use crate::dist::TargetTuple; -use crate::dist::download::DownloadCfg; use crate::download::Download; use crate::process::{ColorableTerminal, Process}; use crate::utils; @@ -271,9 +270,8 @@ pub(crate) async fn try_install_msvc( .context("error creating temp directory")?; let visual_studio = tempdir.path().join("vs_setup.exe"); - let dl_cfg = DownloadCfg::new(cfg); info!("downloading Visual Studio installer"); - Download::new(&visual_studio_url, &visual_studio, dl_cfg.process) + Download::new(&visual_studio_url, &visual_studio, cfg.process) .download() .await?; From f51ceacbb063ac3b2cc3c44136e761318e1caf56 Mon Sep 17 00:00:00 2001 From: Dirkjan Ochtman Date: Sun, 3 May 2026 22:17:15 +0200 Subject: [PATCH 09/11] download: extract DownloadOptions type --- src/cli/self_update.rs | 8 +- src/cli/self_update/windows.rs | 5 +- src/dist/download.rs | 12 ++- src/dist/manifestation/tests.rs | 5 +- src/download/mod.rs | 142 ++++++++++++++++++-------------- 5 files changed, 101 insertions(+), 71 deletions(-) diff --git a/src/cli/self_update.rs b/src/cli/self_update.rs index 91644fd8a6..137b5e71e2 100644 --- a/src/cli/self_update.rs +++ b/src/cli/self_update.rs @@ -63,7 +63,7 @@ use crate::{ DistOptions, PartialToolchainDesc, Profile, TargetTuple, ToolchainDesc, download::DownloadCfg, }, - download::Download, + download::DownloadOptions, errors::RustupError, install::{InstallMethod, UpdateStatus}, process::Process, @@ -1340,7 +1340,8 @@ pub(crate) async fn prepare_update(dl_cfg: &DownloadCfg<'_>) -> Result) -> Result DownloadCfg<'a> { let partial_file_existed = partial_file_path.exists(); let mut hasher = Sha256::new(); - let download = Download::new(url, &partial_file_path, self.process) + let download = DownloadOptions::try_from(self.process)? + .start(url, &partial_file_path) .with_hasher(&mut hasher) .with_status(status) .with_resume(); @@ -137,7 +138,8 @@ impl<'a> DownloadCfg<'a> { async fn download_hash(&self, url: &str) -> Result { let hash_url = utils::parse_url(&(url.to_owned() + ".sha256"))?; let hash_file = self.tmp_cx.new_file()?; - Download::new(&hash_url, &hash_file, self.process) + DownloadOptions::try_from(self.process)? + .start(&hash_url, &hash_file) .download() .await?; utils::read_file("hash", &hash_file).map(|s| s[0..64].to_owned()) @@ -262,7 +264,9 @@ impl<'a> DownloadCfg<'a> { let file = self.tmp_cx.new_file_with_ext("", ext)?; let mut hasher = Sha256::new(); - let download = Download::new(&url, &file, self.process).with_hasher(&mut hasher); + let download = DownloadOptions::try_from(self.process)? + .start(&url, &file) + .with_hasher(&mut hasher); let download = match status { Some(status) => download.with_status(status), diff --git a/src/dist/manifestation/tests.rs b/src/dist/manifestation/tests.rs index de01b49d8e..232e8eb466 100644 --- a/src/dist/manifestation/tests.rs +++ b/src/dist/manifestation/tests.rs @@ -22,7 +22,7 @@ use crate::{ prefix::InstallPrefix, temp, }, - download::Download, + download::DownloadOptions, errors::RustupError, process::TestProcess, test::{ @@ -490,7 +490,8 @@ impl TestContext { // Download the dist manifest and place it into the installation prefix let manifest_url = make_manifest_url(&self.url, &self.toolchain)?; let manifest_file = self.tmp_cx.new_file()?; - Download::new(&manifest_url, &manifest_file, dl_cfg.process) + DownloadOptions::try_from(dl_cfg.process)? + .start(&manifest_url, &manifest_file) .download() .await?; let manifest_str = utils::read_file("manifest", &manifest_file)?; diff --git a/src/download/mod.rs b/src/download/mod.rs index c4fc3070a8..6beab9e717 100644 --- a/src/download/mod.rs +++ b/src/download/mod.rs @@ -31,27 +31,88 @@ use crate::{dist::download::DownloadStatus, errors::RustupError, process::Proces #[cfg(test)] mod tests; -pub struct Download<'a> { - url: &'a Url, - path: &'a Path, - hasher: Option>, - status: Option<&'a DownloadStatus>, - resume: bool, - process: &'a Process, +#[derive(Debug, Clone, Copy)] +pub struct DownloadOptions { + tls: Tls, + timeout: Duration, } -impl<'a> Download<'a> { - pub(crate) fn new(url: &'a Url, path: &'a Path, process: &'a Process) -> Self { - Self { +impl DownloadOptions { + pub fn start<'a>(&self, url: &'a Url, path: &'a Path) -> Download<'a> { + Download { url, path, hasher: None, status: None, resume: false, - process, + options: *self, } } +} + +impl TryFrom<&Process> for DownloadOptions { + type Error = anyhow::Error; + + fn try_from(process: &Process) -> Result { + let use_rustls = process.var_os("RUSTUP_USE_RUSTLS").map(|it| it != "0"); + if use_rustls == Some(false) { + warn!( + "RUSTUP_USE_RUSTLS is set to `0`; the native-tls backend is deprecated, + please file an issue if the default download backend does not work for your use case" + ); + } + + let tls = match use_rustls { + // If the environment explicitly selects a TLS backend that's unavailable, error out. + #[cfg(not(feature = "reqwest-rustls-tls"))] + Some(true) => { + return Err(anyhow!( + "RUSTUP_USE_RUSTLS is set, but this rustup distribution was not built with the reqwest-rustls-tls feature" + )); + } + #[cfg(not(feature = "reqwest-native-tls"))] + Some(false) => { + return Err(anyhow!( + "RUSTUP_USE_RUSTLS is set to false, but this rustup distribution was not built with the reqwest-native-tls feature" + )); + } + + // Prefer explicit selections before falling back to the default TLS stack. + #[cfg(feature = "reqwest-native-tls")] + Some(false) => Tls::NativeTls, + + // The default fallback is `rustls`, which should be used whenever available. + #[cfg(feature = "reqwest-rustls-tls")] + _ => Tls::Rustls, + + // The `rustls` feature is disabled, fall back to `native-tls` instead. + #[cfg(all(not(feature = "reqwest-rustls-tls"), feature = "reqwest-native-tls"))] + _ => Tls::NativeTls, + }; + + let timeout = Duration::from_secs(match process.var("RUSTUP_DOWNLOAD_TIMEOUT") { + Ok(s) => NonZero::from_str(&s) + .context( + "invalid value in RUSTUP_DOWNLOAD_TIMEOUT -- must be a natural number greater than zero", + )? + .get(), + Err(_) => 180, + }); + + Ok(Self { tls, timeout }) + } +} +pub struct Download<'a> { + url: &'a Url, + path: &'a Path, + hasher: Option>, + status: Option<&'a DownloadStatus>, + resume: bool, + options: DownloadOptions, +} + +impl<'a> Download<'a> { pub(crate) fn with_hasher(mut self, hasher: &'a mut Sha256) -> Self { self.hasher = Some(RefCell::new(hasher)); self @@ -130,55 +191,16 @@ impl<'a> Download<'a> { // Download the file - let use_rustls = self.process.var_os("RUSTUP_USE_RUSTLS").map(|it| it != "0"); - if use_rustls == Some(false) { - warn!( - "RUSTUP_USE_RUSTLS is set to `0`; the native-tls backend is deprecated, - please file an issue if the default download backend does not work for your use case" - ); - } - - let tls = match use_rustls { - // If the environment explicitly selects a TLS backend that's unavailable, error out. - #[cfg(not(feature = "reqwest-rustls-tls"))] - Some(true) => { - return Err(anyhow!( - "RUSTUP_USE_RUSTLS is set, but this rustup distribution was not built with the reqwest-rustls-tls feature" - )); - } - #[cfg(not(feature = "reqwest-native-tls"))] - Some(false) => { - return Err(anyhow!( - "RUSTUP_USE_RUSTLS is set to false, but this rustup distribution was not built with the reqwest-native-tls feature" - )); - } - - // Prefer explicit selections before falling back to the default TLS stack. - #[cfg(feature = "reqwest-native-tls")] - Some(false) => Tls::NativeTls, - - // The default fallback is `rustls`, which should be used whenever available. - #[cfg(feature = "reqwest-rustls-tls")] - _ => Tls::Rustls, - - // The `rustls` feature is disabled, fall back to `native-tls` instead. - #[cfg(all(not(feature = "reqwest-rustls-tls"), feature = "reqwest-native-tls"))] - _ => Tls::NativeTls, - }; - - let timeout = Duration::from_secs(match self.process.var("RUSTUP_DOWNLOAD_TIMEOUT") { - Ok(s) => NonZero::from_str(&s) - .context( - "invalid value in RUSTUP_DOWNLOAD_TIMEOUT -- must be a natural number greater than zero", - )? - .get(), - Err(_) => 180, - }); - - debug!("downloading with reqwest"); - - let res = tls - .download_to_path(self.url, self.path, self.resume, Some(callback), timeout) + let res = self + .options + .tls + .download_to_path( + self.url, + self.path, + self.resume, + Some(callback), + self.options.timeout, + ) .await; // The notification should only be sent if the download was successful (i.e. didn't timeout) From f4b61afb4ee75c1d8d66bbc7282a9caf0d8b2b8e Mon Sep 17 00:00:00 2001 From: Dirkjan Ochtman Date: Sun, 3 May 2026 22:26:34 +0200 Subject: [PATCH 10/11] download: attach more functions to Download type --- src/download/mod.rs | 154 ++++++++++++++++-------------------------- src/download/tests.rs | 88 ++++++++++++------------ 2 files changed, 103 insertions(+), 139 deletions(-) diff --git a/src/download/mod.rs b/src/download/mod.rs index 6beab9e717..80e6990cca 100644 --- a/src/download/mod.rs +++ b/src/download/mod.rs @@ -191,17 +191,7 @@ impl<'a> Download<'a> { // Download the file - let res = self - .options - .tls - .download_to_path( - self.url, - self.path, - self.resume, - Some(callback), - self.options.timeout, - ) - .await; + let res = self.download_to_path(Some(callback)).await; // The notification should only be sent if the download was successful (i.e. didn't timeout) if let Some(status) = self.status { @@ -213,50 +203,9 @@ impl<'a> Download<'a> { res } -} - -pub(crate) fn is_network_failure(err: &anyhow::Error) -> bool { - match err.downcast_ref::() { - #[cfg(any(feature = "reqwest-rustls-tls", feature = "reqwest-native-tls"))] - Some(DownloadError::Reqwest(e)) => e.is_timeout() || e.is_connect(), - _ => false, - } -} - -/// User agent header value for HTTP request. -/// See: https://github.com/rust-lang/rustup/issues/2860. -#[cfg(feature = "reqwest-native-tls")] -const REQWEST_DEFAULT_TLS_USER_AGENT: &str = concat!( - "rustup/", - env!("CARGO_PKG_VERSION"), - " (reqwest; default-tls)" -); - -#[cfg(feature = "reqwest-rustls-tls")] -const REQWEST_RUSTLS_TLS_USER_AGENT: &str = - concat!("rustup/", env!("CARGO_PKG_VERSION"), " (reqwest; rustls)"); - -#[derive(Debug, Copy, Clone)] -enum Tls { - #[cfg(feature = "reqwest-rustls-tls")] - Rustls, - #[cfg(feature = "reqwest-native-tls")] - NativeTls, -} -impl Tls { - async fn download_to_path( - self, - url: &Url, - path: &Path, - resume_from_partial: bool, - callback: Option>, - timeout: Duration, - ) -> anyhow::Result<()> { - let Err(err) = self - .download_impl(url, path, resume_from_partial, callback, timeout) - .await - else { + async fn download_to_path(&self, callback: Option>) -> anyhow::Result<()> { + let Err(err) = self.download_impl(callback).await else { return Ok(()); }; @@ -264,8 +213,9 @@ impl Tls { // if there was a network failure from the client side. // It may be worth looking for other cases where removal is also not desired. Err( - if !(resume_from_partial && is_network_failure(&err)) - && let Err(file_err) = remove_file(path).context("cleaning up cached downloads") + if !(self.resume && is_network_failure(&err)) + && let Err(file_err) = + remove_file(self.path).context("cleaning up cached downloads") { file_err.context(err) } else { @@ -274,17 +224,10 @@ impl Tls { ) } - async fn download_impl( - self, - url: &Url, - path: &Path, - resume_from_partial: bool, - callback: Option>, - timeout: Duration, - ) -> anyhow::Result<()> { - let (file, resume_from) = if resume_from_partial { + async fn download_impl(&self, callback: Option>) -> anyhow::Result<()> { + let (file, resume_from) = if self.resume { // TODO: blocking call - let possible_partial = OpenOptions::new().read(true).open(path); + let possible_partial = OpenOptions::new().read(true).open(self.path); let downloaded_so_far = if let Ok(mut partial) = possible_partial { if let Some(cb) = callback { @@ -315,7 +258,7 @@ impl Tls { .write(true) .create(true) .truncate(false) - .open(path) + .open(self.path) .context("error opening file for download")?; possible_partial.seek(SeekFrom::End(0))?; @@ -327,26 +270,37 @@ impl Tls { .write(true) .create(true) .truncate(true) - .open(path) + .open(self.path) .context("error creating file for download")?, 0, ) }; let file = RefCell::new(file); + let client = match self.options.tls { + #[cfg(feature = "reqwest-rustls-tls")] + Tls::Rustls => rustls_client(self.options.timeout)?, + #[cfg(feature = "reqwest-native-tls")] + Tls::NativeTls => native_tls_client(self.options.timeout)?, + }; // TODO: the sync callback will stall the async runtime if IO calls block, which is OS dependent. Rearrange. - self.download(url, resume_from, timeout, &|event| { - if let Event::DownloadDataReceived(data) = event { - file.borrow_mut() - .write_all(data) - .context("unable to write download to disk")?; - } - match callback { - Some(cb) => cb(event), - None => Ok(()), - } - }) + download( + self.url, + resume_from, + &|event| { + if let Event::DownloadDataReceived(data) = event { + file.borrow_mut() + .write_all(data) + .context("unable to write download to disk")?; + } + match callback { + Some(cb) => cb(event), + None => Ok(()), + } + }, + client, + ) .await?; file.borrow_mut() @@ -355,25 +309,37 @@ impl Tls { Ok::<(), anyhow::Error>(()) } +} - async fn download( - self, - url: &Url, - resume_from: u64, - timeout: Duration, - callback: DownloadCallback<'_>, - ) -> anyhow::Result<()> { - let client = match self { - #[cfg(feature = "reqwest-rustls-tls")] - Self::Rustls => rustls_client(timeout)?, - #[cfg(feature = "reqwest-native-tls")] - Self::NativeTls => native_tls_client(timeout)?, - }; - - download(url, resume_from, callback, client).await +pub(crate) fn is_network_failure(err: &anyhow::Error) -> bool { + match err.downcast_ref::() { + #[cfg(any(feature = "reqwest-rustls-tls", feature = "reqwest-native-tls"))] + Some(DownloadError::Reqwest(e)) => e.is_timeout() || e.is_connect(), + _ => false, } } +/// User agent header value for HTTP request. +/// See: https://github.com/rust-lang/rustup/issues/2860. +#[cfg(feature = "reqwest-native-tls")] +const REQWEST_DEFAULT_TLS_USER_AGENT: &str = concat!( + "rustup/", + env!("CARGO_PKG_VERSION"), + " (reqwest; default-tls)" +); + +#[cfg(feature = "reqwest-rustls-tls")] +const REQWEST_RUSTLS_TLS_USER_AGENT: &str = + concat!("rustup/", env!("CARGO_PKG_VERSION"), " (reqwest; rustls)"); + +#[derive(Debug, Copy, Clone)] +enum Tls { + #[cfg(feature = "reqwest-rustls-tls")] + Rustls, + #[cfg(feature = "reqwest-native-tls")] + NativeTls, +} + #[derive(Debug, Copy, Clone)] enum Event<'a> { ResumingPartialDownload, diff --git a/src/download/tests.rs b/src/download/tests.rs index 2c9c7aef45..0375cf9218 100644 --- a/src/download/tests.rs +++ b/src/download/tests.rs @@ -30,7 +30,12 @@ mod reqwest { use url::Url; use super::{scrub_env, serve_file, tmp_dir, write_file}; - use crate::download::{Event, Tls}; + use crate::download::{DownloadOptions, Event, Tls}; + + const OPTIONS: DownloadOptions = DownloadOptions { + tls: DOWNLOAD_BACKEND, + timeout: Duration::from_secs(180), + }; #[cfg(feature = "reqwest-rustls-tls")] const DOWNLOAD_BACKEND: Tls = Tls::Rustls; @@ -110,14 +115,10 @@ mod reqwest { write_file(&target_path, "123"); let from_url = Url::from_file_path(&from_path).unwrap(); - DOWNLOAD_BACKEND - .download_to_path( - &from_url, - &target_path, - true, - None, - Duration::from_secs(180), - ) + OPTIONS + .start(&from_url, &target_path) + .with_resume() + .download_to_path(None) .await .expect("Test download failed"); @@ -139,33 +140,29 @@ mod reqwest { let callback_len = Mutex::new(None); let received_in_callback = Mutex::new(Vec::new()); - DOWNLOAD_BACKEND - .download_to_path( - &from_url, - &target_path, - true, - Some(&|msg| { - match msg { - Event::ResumingPartialDownload => { - assert!(!callback_partial.load(Ordering::SeqCst)); - callback_partial.store(true, Ordering::SeqCst); - } - Event::DownloadContentLengthReceived(len) => { - let mut flag = callback_len.lock().unwrap(); - assert!(flag.is_none()); - *flag = Some(len); - } - Event::DownloadDataReceived(data) => { - for b in data.iter() { - received_in_callback.lock().unwrap().push(*b); - } + OPTIONS + .start(&from_url, &target_path) + .with_resume() + .download_to_path(Some(&|msg| { + match msg { + Event::ResumingPartialDownload => { + assert!(!callback_partial.load(Ordering::SeqCst)); + callback_partial.store(true, Ordering::SeqCst); + } + Event::DownloadContentLengthReceived(len) => { + let mut flag = callback_len.lock().unwrap(); + assert!(flag.is_none()); + *flag = Some(len); + } + Event::DownloadDataReceived(data) => { + for b in data.iter() { + received_in_callback.lock().unwrap().push(*b); } } + } - Ok(()) - }), - Duration::from_secs(180), - ) + Ok(()) + })) .await .expect("Test download failed"); @@ -186,14 +183,10 @@ mod reqwest { let addr = serve_file(b"xxx45".to_vec(), false); let from_url = format!("http://{addr}").parse().unwrap(); - DOWNLOAD_BACKEND - .download_to_path( - &from_url, - &target_path, - true, - None, - Duration::from_secs(180), - ) + OPTIONS + .start(&from_url, &target_path) + .with_resume() + .download_to_path(None) .await .expect_err("download should fail if server ignores range"); @@ -211,10 +204,15 @@ mod reqwest { write_file(&target_path, "123"); let from_url = "http://240.0.0.0:1080".parse().unwrap(); - DOWNLOAD_BACKEND - .download_to_path(&from_url, &target_path, true, None, Duration::from_secs(1)) - .await - .expect_err("download should fail with a connect error"); + DownloadOptions { + tls: DOWNLOAD_BACKEND, + timeout: Duration::from_secs(1), + } + .start(&from_url, &target_path) + .with_resume() + .download_to_path(None) + .await + .expect_err("download should fail with a connect error"); assert!(target_path.exists(), "partial file should not be deleted"); assert_eq!(std::fs::read_to_string(&target_path).unwrap(), "123"); From 5f303440b292527f24eb383841c2cb73b9b909c1 Mon Sep 17 00:00:00 2001 From: Dirkjan Ochtman Date: Sun, 3 May 2026 22:28:17 +0200 Subject: [PATCH 11/11] download: move download() to Download::execute() --- src/download/mod.rs | 138 ++++++++++++++++++++++---------------------- 1 file changed, 69 insertions(+), 69 deletions(-) diff --git a/src/download/mod.rs b/src/download/mod.rs index 80e6990cca..ce3c69f16c 100644 --- a/src/download/mod.rs +++ b/src/download/mod.rs @@ -285,8 +285,7 @@ impl<'a> Download<'a> { }; // TODO: the sync callback will stall the async runtime if IO calls block, which is OS dependent. Rearrange. - download( - self.url, + self.execute( resume_from, &|event| { if let Event::DownloadDataReceived(data) = event { @@ -309,6 +308,74 @@ impl<'a> Download<'a> { Ok::<(), anyhow::Error>(()) } + + async fn execute( + &self, + resume_from: u64, + callback: &dyn Fn(Event<'_>) -> anyhow::Result<()>, + client: &Client, + ) -> anyhow::Result<()> { + // Short-circuit reqwest for the "file:" URL scheme + // The file scheme is mostly for use by tests to mock the dist server + let url = self.url; + if url.scheme() == "file" { + let src = url + .to_file_path() + .map_err(|_| DownloadError::Message(format!("bogus file url: '{url}'")))?; + if !src.is_file() { + // Because some of rustup's logic depends on checking + // the error when a downloaded file doesn't exist, make + // the file case return the same error value as the + // network case. + return Err(anyhow!(DownloadError::FileNotFound)); + } + + let mut f = fs::File::open(src).context("unable to open downloaded file")?; + Seek::seek(&mut f, SeekFrom::Start(resume_from))?; + + let mut buffer = vec![0u8; 0x10000]; + loop { + let bytes_read = Read::read(&mut f, &mut buffer)?; + if bytes_read == 0 { + break; + } + callback(Event::DownloadDataReceived(&buffer[0..bytes_read]))?; + } + + return Ok(()); + } + + let mut req = client.get(url.as_str()); + if resume_from != 0 { + req = req.header(header::RANGE, format!("bytes={resume_from}-")); + } + let res = req + .send() + .await + .map_err(DownloadError::Reqwest) + .context("error downloading file")?; + + // If a download is being resumed, we expect a 206 response; + // otherwise, if the server ignored the range header, + // an error is thrown preemptively to avoid corruption. + let status = res.status().into(); + match (resume_from > 0, status) { + (true, 206) | (false, 200..=299) => {} + _ => return Err(DownloadError::HttpStatus(u32::from(status)).into()), + } + + if let Some(len) = res.content_length() { + let len = len + resume_from; + callback(Event::DownloadContentLengthReceived(len))?; + } + + let mut stream = res.bytes_stream(); + while let Some(item) = stream.next().await { + let bytes = item.map_err(DownloadError::Reqwest)?; + callback(Event::DownloadDataReceived(&bytes))?; + } + Ok(()) + } } pub(crate) fn is_network_failure(err: &anyhow::Error) -> bool { @@ -351,73 +418,6 @@ enum Event<'a> { type DownloadCallback<'a> = &'a dyn Fn(Event<'_>) -> anyhow::Result<()>; -async fn download( - url: &Url, - resume_from: u64, - callback: &dyn Fn(Event<'_>) -> anyhow::Result<()>, - client: &Client, -) -> anyhow::Result<()> { - // Short-circuit reqwest for the "file:" URL scheme - // The file scheme is mostly for use by tests to mock the dist server - if url.scheme() == "file" { - let src = url - .to_file_path() - .map_err(|_| DownloadError::Message(format!("bogus file url: '{url}'")))?; - if !src.is_file() { - // Because some of rustup's logic depends on checking - // the error when a downloaded file doesn't exist, make - // the file case return the same error value as the - // network case. - return Err(anyhow!(DownloadError::FileNotFound)); - } - - let mut f = fs::File::open(src).context("unable to open downloaded file")?; - Seek::seek(&mut f, SeekFrom::Start(resume_from))?; - - let mut buffer = vec![0u8; 0x10000]; - loop { - let bytes_read = Read::read(&mut f, &mut buffer)?; - if bytes_read == 0 { - break; - } - callback(Event::DownloadDataReceived(&buffer[0..bytes_read]))?; - } - - return Ok(()); - } - - let mut req = client.get(url.as_str()); - if resume_from != 0 { - req = req.header(header::RANGE, format!("bytes={resume_from}-")); - } - let res = req - .send() - .await - .map_err(DownloadError::Reqwest) - .context("error downloading file")?; - - // If a download is being resumed, we expect a 206 response; - // otherwise, if the server ignored the range header, - // an error is thrown preemptively to avoid corruption. - let status = res.status().into(); - match (resume_from > 0, status) { - (true, 206) | (false, 200..=299) => {} - _ => return Err(DownloadError::HttpStatus(u32::from(status)).into()), - } - - if let Some(len) = res.content_length() { - let len = len + resume_from; - callback(Event::DownloadContentLengthReceived(len))?; - } - - let mut stream = res.bytes_stream(); - while let Some(item) = stream.next().await { - let bytes = item.map_err(DownloadError::Reqwest)?; - callback(Event::DownloadDataReceived(&bytes))?; - } - Ok(()) -} - fn client_generic() -> ClientBuilder { Client::builder() // HACK: set `pool_max_idle_per_host` to `0` to avoid an issue in the underlying