diff --git a/CHANGELOG.md b/CHANGELOG.md index a02b7b4..41c8bbb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,11 @@ and this project adheres to [Semantic Versioning](https://book.async.rs/overview ## [Unreleased] +### Additions +- `surf::Config::set_mandatory_base_origin()` - require all requests to go to + the origin of the base URL. Useful if the `Config` includes authentication + headers or similar. + ## [2.3.2] - 2021-11-01 ### Fixes diff --git a/src/client.rs b/src/client.rs index 9fba2e7..332abc8 100644 --- a/src/client.rs +++ b/src/client.rs @@ -575,10 +575,25 @@ impl Client { // private function to generate a url based on the base_path fn url(&self, uri: impl AsRef) -> Url { - match &self.config.base_url { + let ret = match &self.config.base_url { None => uri.as_ref().parse().unwrap(), Some(base) => base.join(uri.as_ref()).unwrap(), + }; + if self.config.mandatory_base_origin { + let base_url = self + .config + .base_url + .as_ref() + .expect("Config::set_mandatory_base_origin without base URL"); + if ret.origin() != base_url.origin() { + panic!( + "URL <{}> not relative to mandatory base origin {}", + ret, + base_url.origin().ascii_serialization(), + ); + } } + ret } } @@ -621,4 +636,32 @@ mod client_tests { let url = client.url("posts.json"); assert_eq!(url.as_str(), "http://example.com/api/v1/posts.json"); } + + #[test] + fn mandatory_base_origin_success() { + let base_url = Url::parse("http://example.com/api/v1/").unwrap(); + + let client: Client = Config::new() + .set_base_url(base_url) + .set_mandatory_base_origin() + .try_into() + .unwrap(); + let url = client.url("posts.json"); + assert_eq!(url.as_str(), "http://example.com/api/v1/posts.json"); + let url = client.url("/posts.json"); + assert_eq!(url.as_str(), "http://example.com/posts.json"); + } + + #[test] + #[should_panic] + fn mandatory_base_origin_fail() { + let base_url = Url::parse("http://example.com/api/v1/").unwrap(); + + let client: Client = Config::new() + .set_base_url(base_url) + .set_mandatory_base_origin() + .try_into() + .unwrap(); + let _ = client.url("https://another.example/some/path"); + } } diff --git a/src/config.rs b/src/config.rs index 8bbf5c4..62db737 100644 --- a/src/config.rs +++ b/src/config.rs @@ -36,6 +36,8 @@ pub struct Config { /// Without it, the last path component is considered to be a “file” name /// to be removed to get at the “directory” that is used as the base. pub base_url: Option, + /// Require this configuration to *only* be used for requests to the origin of `base_url`. + pub(crate) mandatory_base_origin: bool, /// Headers to be applied to every request made by this client. pub headers: HashMap, /// Underlying HTTP client config. @@ -110,6 +112,16 @@ impl Config { self } + /// Only allow requests to the same origin as the base URL. + /// + /// Useful if the headers include information that should only get sent to a specific origin. + /// Setting this avoids unintentionally making requests to a different origin using those + /// headers. + pub fn set_mandatory_base_origin(mut self) -> Self { + self.mandatory_base_origin = true; + self + } + /// Set HTTP/1.1 `keep-alive` (connection pooling). /// /// Default: `true`. @@ -223,6 +235,7 @@ impl From for Config { fn from(http_config: HttpConfig) -> Self { Self { base_url: None, + mandatory_base_origin: false, headers: HashMap::new(), http_config, http_client: None,