diff --git a/src/api.rs b/src/api.rs index 567f339..75a1599 100644 --- a/src/api.rs +++ b/src/api.rs @@ -14,8 +14,9 @@ pub struct ApiClient { } impl ApiClient { - /// Create a new API client. Loads config, validates auth. - /// Pass `workspace_id` for endpoints that require it, or `None` for workspace-less endpoints. + /// Create a new API client. Loads config, pre-flights a JWT session. + /// Pass `workspace_id` for endpoints that require it, or `None` for + /// workspace-less endpoints. pub fn new(workspace_id: Option<&str>) -> Self { let profile_config = match config::load("default") { Ok(c) => c, @@ -25,17 +26,27 @@ impl ApiClient { } }; - let api_key = match &profile_config.api_key { - Some(key) if key != "PLACEHOLDER" => key.clone(), - _ => { - eprintln!("error: not authenticated. Run 'hotdata auth login' (or 'hotdata auth') to log in."); + let api_key_fallback = profile_config + .api_key + .as_deref() + .filter(|k| !k.is_empty() && *k != "PLACEHOLDER"); + + // Pre-flight: return the cached JWT if valid, refresh it if + // close to expiry, or mint a new one from the API key. The + // returned string is a JWT — that's what we send on the wire. + let access_token = match crate::jwt::ensure_access_token(&profile_config, api_key_fallback) + { + Ok(t) => t, + Err(e) => { + eprintln!("{}", format!("error: {e}").red()); + eprintln!("Run {} to log in, or pass --api-key.", "hotdata auth".cyan()); std::process::exit(1); } }; Self { client: reqwest::blocking::Client::new(), - api_key, + api_key: access_token, api_url: profile_config.api_url.to_string(), workspace_id: workspace_id.map(String::from), sandbox_id: std::env::var("HOTDATA_SANDBOX").ok().or_else(|| { @@ -60,29 +71,6 @@ impl ApiClient { } } - fn debug_headers(&self) -> Vec<(&str, String)> { - let masked = if self.api_key.len() > 4 { - format!("Bearer ...{}", &self.api_key[self.api_key.len()-4..]) - } else { - "Bearer ***".to_string() - }; - let mut headers = vec![("Authorization", masked)]; - if let Some(ref ws) = self.workspace_id { - headers.push(("X-Workspace-Id", ws.clone())); - } - if let Some(ref sid) = self.sandbox_id { - // Send both headers during the session→sandbox migration window. - headers.push(("X-Session-Id", sid.clone())); - headers.push(("X-Sandbox-Id", sid.clone())); - } - headers - } - - fn log_request(&self, method: &str, url: &str, body: Option<&serde_json::Value>) { - let headers = self.debug_headers(); - let header_refs: Vec<(&str, &str)> = headers.iter().map(|(k, v)| (*k, v.as_str())).collect(); - util::debug_request(method, url, &header_refs, body); - } /// Prints an error for a non-2xx response and exits. On 4xx, first re-probes /// the API key: if it's actually invalid, a clear re-auth hint is shown @@ -111,29 +99,25 @@ impl ApiClient { req } - /// GET request with query parameters, returns parsed response. - /// Parameters with `None` values are omitted. - pub fn get_with_params(&self, path: &str, params: &[(&str, Option)]) -> T { - let filtered: Vec<(&str, &String)> = params.iter() - .filter_map(|(k, v)| v.as_ref().map(|val| (*k, val))) - .collect(); - let url = format!("{}{path}", self.api_url); - self.log_request("GET", &url, None); - - let resp = match self.build_request(reqwest::Method::GET, &url).query(&filtered).send() { - Ok(r) => r, + /// Send via `util::send_debug` and unwrap connection errors with the + /// CLI's standard "error connecting" exit. All public HTTP methods + /// route through here so debug logging is uniform. + fn send( + &self, + builder: reqwest::blocking::RequestBuilder, + body_for_log: Option<&serde_json::Value>, + ) -> (reqwest::StatusCode, String) { + match util::send_debug(&self.client, builder, body_for_log) { + Ok(pair) => pair, Err(e) => { eprintln!("error connecting to API: {e}"); std::process::exit(1); } - }; - - let (status, body) = util::debug_response(resp); - if !status.is_success() { - self.fail_response(status, body); } + } - match serde_json::from_str(&body) { + fn parse_json(body: &str) -> T { + match serde_json::from_str(body) { Ok(v) => v, Err(e) => { eprintln!("error parsing response: {e}"); @@ -142,92 +126,56 @@ impl ApiClient { } } - /// GET request, returns parsed response. - pub fn get(&self, path: &str) -> T { + /// GET request with query parameters, returns parsed response. + /// Parameters with `None` values are omitted. + pub fn get_with_params(&self, path: &str, params: &[(&str, Option)]) -> T { + let filtered: Vec<(&str, &String)> = params.iter() + .filter_map(|(k, v)| v.as_ref().map(|val| (*k, val))) + .collect(); let url = format!("{}{path}", self.api_url); - self.log_request("GET", &url, None); - - let resp = match self.build_request(reqwest::Method::GET, &url).send() { - Ok(r) => r, - Err(e) => { - eprintln!("error connecting to API: {e}"); - std::process::exit(1); - } - }; - - let (status, body) = util::debug_response(resp); + let req = self.build_request(reqwest::Method::GET, &url).query(&filtered); + let (status, body) = self.send(req, None); if !status.is_success() { self.fail_response(status, body); } + Self::parse_json(&body) + } - match serde_json::from_str(&body) { - Ok(v) => v, - Err(e) => { - eprintln!("error parsing response: {e}"); - std::process::exit(1); - } + /// GET request, returns parsed response. + pub fn get(&self, path: &str) -> T { + let url = format!("{}{path}", self.api_url); + let req = self.build_request(reqwest::Method::GET, &url); + let (status, body) = self.send(req, None); + if !status.is_success() { + self.fail_response(status, body); } + Self::parse_json(&body) } /// GET request; returns `None` on HTTP 404. Other status codes use the same handling as /// [`Self::get`]. Used when probing many paths where a missing resource is normal. pub fn get_none_if_not_found(&self, path: &str) -> Option { let url = format!("{}{path}", self.api_url); - self.log_request("GET", &url, None); - - let resp = match self.build_request(reqwest::Method::GET, &url).send() { - Ok(r) => r, - Err(e) => { - eprintln!("error connecting to API: {e}"); - std::process::exit(1); - } - }; - - let (status, body) = util::debug_response(resp); + let req = self.build_request(reqwest::Method::GET, &url); + let (status, body) = self.send(req, None); if status == reqwest::StatusCode::NOT_FOUND { return None; } if !status.is_success() { self.fail_response(status, body); } - - match serde_json::from_str(&body) { - Ok(v) => Some(v), - Err(e) => { - eprintln!("error parsing response: {e}"); - std::process::exit(1); - } - } + Some(Self::parse_json(&body)) } /// POST request with JSON body, returns parsed response. pub fn post(&self, path: &str, body: &serde_json::Value) -> T { let url = format!("{}{path}", self.api_url); - self.log_request("POST", &url, Some(body)); - - let resp = match self.build_request(reqwest::Method::POST, &url) - .json(body) - .send() - { - Ok(r) => r, - Err(e) => { - eprintln!("error connecting to API: {e}"); - std::process::exit(1); - } - }; - - let (status, resp_body) = util::debug_response(resp); + let req = self.build_request(reqwest::Method::POST, &url).json(body); + let (status, resp_body) = self.send(req, Some(body)); if !status.is_success() { self.fail_response(status, resp_body); } - - match serde_json::from_str(&resp_body) { - Ok(v) => v, - Err(e) => { - eprintln!("error parsing response: {e}"); - std::process::exit(1); - } - } + Self::parse_json(&resp_body) } /// GET request, exits only on connection error, returns raw (status, body). @@ -235,66 +183,26 @@ impl ApiClient { /// to handle non-2xx responses gracefully instead of aborting. pub fn get_raw(&self, path: &str) -> (reqwest::StatusCode, String) { let url = format!("{}{path}", self.api_url); - self.log_request("GET", &url, None); - - let resp = match self.build_request(reqwest::Method::GET, &url).send() { - Ok(r) => r, - Err(e) => { - eprintln!("error connecting to API: {e}"); - std::process::exit(1); - } - }; - - util::debug_response(resp) + let req = self.build_request(reqwest::Method::GET, &url); + self.send(req, None) } /// POST request with JSON body, exits on error, returns raw (status, body). pub fn post_raw(&self, path: &str, body: &serde_json::Value) -> (reqwest::StatusCode, String) { let url = format!("{}{path}", self.api_url); - self.log_request("POST", &url, Some(body)); - - let resp = match self.build_request(reqwest::Method::POST, &url) - .json(body) - .send() - { - Ok(r) => r, - Err(e) => { - eprintln!("error connecting to API: {e}"); - std::process::exit(1); - } - }; - - util::debug_response(resp) + let req = self.build_request(reqwest::Method::POST, &url).json(body); + self.send(req, Some(body)) } /// PATCH request with JSON body, returns parsed response. pub fn patch(&self, path: &str, body: &serde_json::Value) -> T { let url = format!("{}{path}", self.api_url); - self.log_request("PATCH", &url, Some(body)); - - let resp = match self.build_request(reqwest::Method::PATCH, &url) - .json(body) - .send() - { - Ok(r) => r, - Err(e) => { - eprintln!("error connecting to API: {e}"); - std::process::exit(1); - } - }; - - let (status, resp_body) = util::debug_response(resp); + let req = self.build_request(reqwest::Method::PATCH, &url).json(body); + let (status, resp_body) = self.send(req, Some(body)); if !status.is_success() { self.fail_response(status, resp_body); } - - match serde_json::from_str(&resp_body) { - Ok(v) => v, - Err(e) => { - eprintln!("error parsing response: {e}"); - std::process::exit(1); - } - } + Self::parse_json(&resp_body) } /// POST with a custom request body (for file uploads). Returns raw status and body. @@ -306,24 +214,16 @@ impl ApiClient { content_length: Option, ) -> (reqwest::StatusCode, String) { let url = format!("{}{path}", self.api_url); - self.log_request("POST", &url, None); - let mut req = self.build_request(reqwest::Method::POST, &url) .header("Content-Type", content_type); - if let Some(len) = content_length { req = req.header("Content-Length", len); } - - let resp = match req.body(reqwest::blocking::Body::new(reader)).send() { - Ok(r) => r, - Err(e) => { - eprintln!("error connecting to API: {e}"); - std::process::exit(1); - } - }; - - util::debug_response(resp) + let req = req.body(reqwest::blocking::Body::new(reader)); + // Body is an opaque stream — nothing meaningful to print under + // --debug, so pass `None`. Headers (including the masked + // Authorization) still log. + self.send(req, None) } } diff --git a/src/auth.rs b/src/auth.rs index 25483d5..72d0383 100644 --- a/src/auth.rs +++ b/src/auth.rs @@ -9,7 +9,8 @@ use std::collections::HashMap; use std::io::stdout; pub fn logout(profile: &str) { - if let Err(e) = config::remove_api_key(profile) { + crate::jwt::clear_session(); + if let Err(e) = config::clear_workspaces(profile) { eprintln!("error: {e}"); std::process::exit(1); } @@ -25,21 +26,31 @@ pub enum AuthStatus { } pub fn check_status(profile_config: &config::ProfileConfig) -> AuthStatus { - let api_key = match &profile_config.api_key { - Some(key) if key != "PLACEHOLDER" => key.clone(), - _ => return AuthStatus::NotConfigured, + let api_key_fallback = profile_config + .api_key + .as_deref() + .filter(|k| !k.is_empty() && *k != "PLACEHOLDER"); + + // PKCE-origin sessions don't write an api_key, so absence of a key + // alone isn't "not configured" — only true if there's also no + // cached JWT session to validate. + if api_key_fallback.is_none() && crate::jwt::load_session().is_none() { + return AuthStatus::NotConfigured; + } + + let access_token = match crate::jwt::ensure_access_token(profile_config, api_key_fallback) { + Ok(t) => t, + Err(_) => return AuthStatus::Invalid(401), }; let url = format!("{}/workspaces", profile_config.api_url); let client = reqwest::blocking::Client::new(); - - match client + let req = client .get(&url) - .header("Authorization", format!("Bearer {api_key}")) - .send() - { - Ok(resp) if resp.status().is_success() => AuthStatus::Authenticated, - Ok(resp) => AuthStatus::Invalid(resp.status().as_u16()), + .header("Authorization", format!("Bearer {access_token}")); + match crate::util::send_debug(&client, req, None) { + Ok((status, _)) if status.is_success() => AuthStatus::Authenticated, + Ok((status, _)) => AuthStatus::Invalid(status.as_u16()), Err(e) => AuthStatus::ConnectionError(e.to_string()), } } @@ -53,21 +64,42 @@ pub fn status(profile: &str) { } }; - let source_label = if profile_config.api_key_source == ApiKeySource::Env { - " (env override)" - } else { - "" + // The credential the CLI is *about to use*. Note: even when an + // override is set, the wire credential is still a JWT (minted on + // demand from the override) — but we report the user-visible source. + let method_label = match profile_config.api_key_source { + ApiKeySource::Flag => "API Key flag", + ApiKeySource::Env => "API Key env", + ApiKeySource::Config => "CLI Session", + }; + + // For Flag/Env we mask the api_key the user supplied. For the + // CLI session path we mask the refresh_token — it's stable across + // commands (unlike the 5-min access_token), so the tail stays + // recognizable between runs. + let credential_tail = match profile_config.api_key_source { + ApiKeySource::Flag | ApiKeySource::Env => profile_config + .api_key + .as_deref() + .map(crate::util::mask_credential), + ApiKeySource::Config => crate::jwt::load_session() + .map(|s| crate::util::mask_credential(&s.refresh_token)), + }; + let method_suffix = match credential_tail { + Some(tail) => format!(" - {method_label} [{tail}]"), + None => format!(" - {method_label}"), }; match check_status(&profile_config) { AuthStatus::NotConfigured => { print_row("Authenticated", &"No".red().to_string()); - print_row("API Key", &"Not configured".red().to_string()); } AuthStatus::Authenticated => { print_row("API URL", &profile_config.api_url.cyan().to_string()); - print_row("Authenticated", &"Yes".green().to_string()); - print_row("API Key", &format!("{}{source_label}", "Valid".green())); + print_row( + "Authenticated", + &format!("{}{}", "Yes".green(), method_suffix.dark_grey()), + ); match profile_config.workspaces.first() { Some(w) => { print_row("Workspace", &format!("{} {}", w.name.as_str().cyan(), format!("({})", w.public_id).dark_grey())); @@ -76,15 +108,11 @@ pub fn status(profile: &str) { None => print_row("Current Workspace", &"None".dark_grey().to_string()), } } - AuthStatus::Invalid(code) => { + AuthStatus::Invalid(_) => { print_row("API URL", &profile_config.api_url.cyan().to_string()); - print_row("Authenticated", &"No".red().to_string()); print_row( - "API Key", - &format!( - "{}{source_label}", - format!("Invalid (HTTP {})", code).red() - ), + "Authenticated", + &format!("{}{}", "No".red(), method_suffix.dark_grey()), ); } AuthStatus::ConnectionError(e) => { @@ -94,76 +122,12 @@ pub fn status(profile: &str) { } } -#[derive(Debug, PartialEq)] -pub enum LoginResult { - Success { token: String, workspace: Option }, - Forbidden, - Failed(String), - ConnectionError(String), -} - -#[derive(Deserialize)] -struct TokenResponse { - token: String, -} - #[derive(Deserialize)] struct WsListResponse { workspaces: Vec } #[derive(Deserialize)] struct WsItem { public_id: String, name: String } -/// Exchange an authorization code + PKCE verifier for an API token, -/// then fetch available workspaces. -fn exchange_and_save_token(api_url: &str, code: &str, code_verifier: &str) -> LoginResult { - let token_url = format!("{api_url}/auth/token"); - let client = reqwest::blocking::Client::new(); - - let resp = match client - .post(&token_url) - .json(&serde_json::json!({ "code": code, "code_verifier": code_verifier })) - .send() - { - Ok(r) => r, - Err(e) => return LoginResult::ConnectionError(e.to_string()), - }; - - if resp.status() == reqwest::StatusCode::FORBIDDEN { - return LoginResult::Forbidden; - } - - if !resp.status().is_success() { - return LoginResult::Failed(format!("HTTP {}", resp.status())); - } - - let body: TokenResponse = match resp.json() { - Ok(b) => b, - Err(e) => return LoginResult::Failed(format!("error parsing token response: {e}")), - }; - - // Save the token - if let Err(e) = config::save_api_key("default", &body.token) { - return LoginResult::Failed(format!("error saving token: {e}")); - } - - // Fetch and cache workspaces - let ws_url = format!("{api_url}/workspaces"); - let default_workspace = if let Ok(r) = client.get(&ws_url).header("Authorization", format!("Bearer {}", body.token)).send() { - if r.status().is_success() { - if let Ok(ws) = r.json::() { - let entries: Vec = ws.workspaces.into_iter() - .map(|w| config::WorkspaceEntry { public_id: w.public_id, name: w.name }) - .collect(); - let first = entries.first().cloned(); - let _ = config::save_workspaces("default", entries); - first - } else { None } - } else { None } - } else { None }; - - LoginResult::Success { token: body.token, workspace: default_workspace } -} - /// Wait for the browser callback, verify state, and extract the authorization code. fn receive_callback(server: &tiny_http::Server, expected_state: &str) -> Result { let request = server.recv().map_err(|e| format!("failed to receive callback: {e}"))?; @@ -252,7 +216,6 @@ fn is_already_signed_in(profile_config: &config::ProfileConfig) -> bool { pub fn login() { let profile_config = config::load("default").unwrap_or_default(); - let api_url = profile_config.api_url.to_string(); let app_url = profile_config.app_url.to_string(); // Check if already authenticated @@ -272,13 +235,27 @@ pub fn login() { let code_challenge = generate_code_challenge(&code_verifier); let state = generate_random_string(32); - // Bind to port 0 so the OS picks an available port + // Bind to port 0 so the OS picks an available port. DOT's consent + // page will redirect here with `?code=...&state=...`. let server = tiny_http::Server::http("127.0.0.1:0").expect("failed to start local callback server"); let port = server.server_addr().to_ip().unwrap().port(); + let redirect_uri = format!("http://127.0.0.1:{port}/"); + // DOT's `/o/authorize/` endpoint is mounted off the app URL (the + // browser-facing one; allauth session cookies live here). We send + // no `scope` parameter — the consent page picks permissions and + // workspace scope interactively, then composes the scope string + // server-side (see HotdataAllowForm). let login_url = format!( - "{app_url}/auth/cli-login?code_challenge={code_challenge}&code_challenge_method=S256&state={state}&callback_port={port}" + "{app_url}/o/authorize/\ + ?client_id=hotdata-cli\ + &response_type=code\ + &redirect_uri={redirect_uri}\ + &code_challenge={code_challenge}\ + &code_challenge_method=S256\ + &state={state}", + app_url = app_url.trim_end_matches('/'), ); println!("Opening browser to log in..."); @@ -306,8 +283,11 @@ pub fn login() { } }; - match exchange_and_save_token(&api_url, &code, &code_verifier) { - LoginResult::Success { workspace, .. } => { + match crate::jwt::mint_from_pkce_code(&profile_config, &code, &code_verifier, &redirect_uri) { + Ok(session) => { + if let Err(e) = crate::jwt::save_session(&session) { + eprintln!("warning: could not save session: {e}"); + } stdout() .execute(SetForegroundColor(Color::Green)) .unwrap() @@ -316,7 +296,11 @@ pub fn login() { .execute(ResetColor) .unwrap(); - match workspace { + // Best-effort workspace cache using the freshly minted JWT. + // Fall back to the existing on-disk list if the fetch fails. + let workspaces = cache_workspaces(&profile_config, &session.access_token) + .unwrap_or(profile_config.workspaces); + match workspaces.first() { Some(w) => { print_row("Workspace", &format!("{} {}", w.name.as_str().cyan(), format!("({})", w.public_id).dark_grey())); print_row("", &"use 'hotdata workspaces set' to switch workspaces".dark_grey().to_string()); @@ -324,21 +308,42 @@ pub fn login() { None => print_row("Workspace", &"None".dark_grey().to_string()), } } - LoginResult::Forbidden => { - eprintln!("{}", "You are not authorized to create a new API token.".red()); - std::process::exit(1); - } - LoginResult::Failed(msg) => { - eprintln!("token exchange failed: {msg}"); - std::process::exit(1); - } - LoginResult::ConnectionError(e) => { - eprintln!("error connecting to API: {e}"); + Err(msg) => { + eprintln!("{}", msg.red()); std::process::exit(1); } } } +/// Fetch workspaces with a freshly minted JWT and cache them in config. +/// Returns the freshly fetched list so callers can display it without +/// having to reload config from disk. +fn cache_workspaces( + profile: &config::ProfileConfig, + access_token: &str, +) -> Result, String> { + let url = format!("{}/workspaces", profile.api_url); + let client = reqwest::blocking::Client::new(); + let req = client + .get(&url) + .header("Authorization", format!("Bearer {access_token}")); + let (status, body) = crate::util::send_debug(&client, req, None).map_err(|e| format!("{e}"))?; + if !status.is_success() { + return Err(format!("HTTP {status}")); + } + let ws: WsListResponse = serde_json::from_str(&body).map_err(|e| format!("{e}"))?; + let entries: Vec = ws + .workspaces + .into_iter() + .map(|w| config::WorkspaceEntry { + public_id: w.public_id, + name: w.name, + }) + .collect(); + config::save_workspaces("default", entries.clone())?; + Ok(entries) +} + fn generate_code_verifier() -> String { generate_random_string(64) } @@ -370,251 +375,199 @@ fn parse_query_params(url: &str) -> HashMap { #[cfg(test)] mod tests { use super::*; - use config::{ApiUrl, ProfileConfig, test_helpers::with_temp_config_dir}; + use config::{ApiUrl, AppUrl, ProfileConfig, test_helpers::with_temp_config_dir}; - fn mock_profile(api_url: &str, api_key: Option<&str>) -> ProfileConfig { + fn mock_profile(url: &str, api_key: Option<&str>) -> ProfileConfig { ProfileConfig { api_key: api_key.map(String::from), - api_url: ApiUrl(Some(api_url.to_string())), + api_url: ApiUrl(Some(url.to_string())), + // Point app_url at the same server so any oauth path (e.g. + // ensure_access_token minting from an api_key) hits the + // mock instead of the real production app. + app_url: AppUrl(Some(url.to_string())), ..Default::default() } } + /// Persist a fully-valid session so check_status can short-circuit + /// the JWT mint/refresh path and go straight to the /workspaces + /// probe — mirrors the on-disk state immediately after a PKCE login. + fn save_test_session(token: &str) { + use std::time::{SystemTime, UNIX_EPOCH}; + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs(); + crate::jwt::save_session(&crate::jwt::Session { + access_token: token.to_string(), + access_expires_at: now + 3600, + refresh_token: "r".into(), + refresh_expires_at: now + 86400, + source: "pkce".into(), + }) + .unwrap(); + } + // --- check_status tests --- #[test] - fn status_not_configured_when_no_key() { + fn status_not_configured_when_no_key_no_session() { + let (_tmp, _guard) = with_temp_config_dir(); let profile = mock_profile("http://localhost", None); assert_eq!(check_status(&profile), AuthStatus::NotConfigured); } #[test] - fn status_not_configured_when_placeholder() { + fn status_not_configured_when_placeholder_no_session() { + let (_tmp, _guard) = with_temp_config_dir(); let profile = mock_profile("http://localhost", Some("PLACEHOLDER")); assert_eq!(check_status(&profile), AuthStatus::NotConfigured); } #[test] - fn status_authenticated_with_valid_key() { + fn status_authenticated_with_valid_session() { + let (_tmp, _guard) = with_temp_config_dir(); + save_test_session("valid-jwt"); let mut server = mockito::Server::new(); let mock = server .mock("GET", "/workspaces") - .match_header("Authorization", "Bearer valid-key") + .match_header("Authorization", "Bearer valid-jwt") .with_status(200) .with_body(r#"{"workspaces":[]}"#) .create(); - let profile = mock_profile(&server.url(), Some("valid-key")); + let profile = mock_profile(&server.url(), None); assert_eq!(check_status(&profile), AuthStatus::Authenticated); mock.assert(); } #[test] - fn status_invalid_with_bad_key() { + fn status_authenticated_via_api_token_fallback_when_no_session() { + // Realistic upgrade path: user has an api_key in config but no + // session.json yet. ensure_access_token must mint a JWT from + // the api_key, then check_status probes /workspaces with it. + let (_tmp, _guard) = with_temp_config_dir(); let mut server = mockito::Server::new(); - let mock = server + let mint_mock = server + .mock("POST", "/o/token/") + .match_body(mockito::Matcher::UrlEncoded( + "grant_type".into(), + "api_token".into(), + )) + .with_status(200) + .with_header("content-type", "application/json") + .with_body(r#"{"access_token":"minted-jwt","expires_in":300,"refresh_token":"r"}"#) + .create(); + let probe_mock = server .mock("GET", "/workspaces") - .with_status(401) + .match_header("Authorization", "Bearer minted-jwt") + .with_status(200) + .with_body(r#"{"workspaces":[]}"#) .create(); - let profile = mock_profile(&server.url(), Some("bad-key")); - assert_eq!(check_status(&profile), AuthStatus::Invalid(401)); - mock.assert(); + let profile = mock_profile(&server.url(), Some("hd_xyz")); + assert_eq!(check_status(&profile), AuthStatus::Authenticated); + mint_mock.assert(); + probe_mock.assert(); } #[test] - fn status_invalid_with_forbidden() { + fn status_invalid_when_session_revoked_server_side() { + let (_tmp, _guard) = with_temp_config_dir(); + save_test_session("revoked-jwt"); let mut server = mockito::Server::new(); let mock = server .mock("GET", "/workspaces") - .with_status(403) + .with_status(401) .create(); - let profile = mock_profile(&server.url(), Some("forbidden-key")); - assert_eq!(check_status(&profile), AuthStatus::Invalid(403)); + let profile = mock_profile(&server.url(), None); + assert_eq!(check_status(&profile), AuthStatus::Invalid(401)); mock.assert(); } #[test] - fn status_connection_error() { - let profile = mock_profile("http://127.0.0.1:1", Some("key")); - match check_status(&profile) { - AuthStatus::ConnectionError(_) => {} - other => panic!("expected ConnectionError, got {:?}", other), - } - } - - // --- is_already_signed_in tests --- - - #[test] - fn already_signed_in_when_key_valid() { + fn status_invalid_with_forbidden() { + let (_tmp, _guard) = with_temp_config_dir(); + save_test_session("jwt"); let mut server = mockito::Server::new(); let mock = server .mock("GET", "/workspaces") - .match_header("Authorization", "Bearer existing-key") - .with_status(200) - .with_body(r#"{"workspaces":[]}"#) + .with_status(403) .create(); - let profile = mock_profile(&server.url(), Some("existing-key")); - assert!(is_already_signed_in(&profile)); + let profile = mock_profile(&server.url(), None); + assert_eq!(check_status(&profile), AuthStatus::Invalid(403)); mock.assert(); } #[test] - fn not_signed_in_when_no_key() { - let profile = mock_profile("http://localhost", None); - assert!(!is_already_signed_in(&profile)); - } - - #[test] - fn not_signed_in_when_key_invalid() { + fn status_invalid_when_api_token_rejected_no_session() { + // No session, and the api_key fallback is rejected by the mint + // endpoint — collapse to Invalid(401) so `auth status` shows + // the user they need to re-auth. + let (_tmp, _guard) = with_temp_config_dir(); let mut server = mockito::Server::new(); let mock = server - .mock("GET", "/workspaces") + .mock("POST", "/o/token/") .with_status(401) .create(); - let profile = mock_profile(&server.url(), Some("expired-key")); - assert!(!is_already_signed_in(&profile)); + let profile = mock_profile(&server.url(), Some("hd_revoked")); + assert_eq!(check_status(&profile), AuthStatus::Invalid(401)); mock.assert(); } - // --- exchange_and_save_token tests --- - #[test] - fn exchange_and_save_token_success() { + fn status_connection_error_during_probe() { let (_tmp, _guard) = with_temp_config_dir(); - let mut server = mockito::Server::new(); - - let token_mock = server - .mock("POST", "/auth/token") - .with_status(200) - .with_header("content-type", "application/json") - .with_body(r#"{"token":"new-api-token-xyz"}"#) - .create(); - - let ws_mock = server - .mock("GET", "/workspaces") - .match_header("Authorization", "Bearer new-api-token-xyz") - .with_status(200) - .with_header("content-type", "application/json") - .with_body(r#"{"workspaces":[{"public_id":"ws-123","name":"My Workspace"}]}"#) - .create(); - - let result = exchange_and_save_token(&server.url(), "auth-code", "verifier"); - - token_mock.assert(); - ws_mock.assert(); - - match result { - LoginResult::Success { token, workspace } => { - assert_eq!(token, "new-api-token-xyz"); - let ws = workspace.expect("should have a workspace"); - assert_eq!(ws.public_id, "ws-123"); - assert_eq!(ws.name, "My Workspace"); - } - other => panic!("expected Success, got {:?}", other), + save_test_session("jwt"); + let profile = mock_profile("http://127.0.0.1:1", None); + match check_status(&profile) { + AuthStatus::ConnectionError(_) => {} + other => panic!("expected ConnectionError, got {:?}", other), } - - // Verify token was saved to config - let profile = config::load("default").unwrap(); - assert_eq!(profile.api_key, Some("new-api-token-xyz".to_string())); } + // --- is_already_signed_in tests --- + #[test] - fn exchange_and_save_token_success_no_workspaces() { + fn already_signed_in_when_session_valid() { let (_tmp, _guard) = with_temp_config_dir(); + save_test_session("session-jwt"); let mut server = mockito::Server::new(); - - let token_mock = server - .mock("POST", "/auth/token") - .with_status(200) - .with_header("content-type", "application/json") - .with_body(r#"{"token":"token-no-ws"}"#) - .create(); - - let ws_mock = server + let mock = server .mock("GET", "/workspaces") + .match_header("Authorization", "Bearer session-jwt") .with_status(200) - .with_header("content-type", "application/json") .with_body(r#"{"workspaces":[]}"#) .create(); - let result = exchange_and_save_token(&server.url(), "code", "verifier"); - - token_mock.assert(); - ws_mock.assert(); - - match result { - LoginResult::Success { token, workspace } => { - assert_eq!(token, "token-no-ws"); - assert!(workspace.is_none()); - } - other => panic!("expected Success, got {:?}", other), - } - } - - #[test] - fn exchange_and_save_token_forbidden() { - let (_tmp, _guard) = with_temp_config_dir(); - let mut server = mockito::Server::new(); - - let mock = server - .mock("POST", "/auth/token") - .with_status(403) - .create(); - - let result = exchange_and_save_token(&server.url(), "code", "verifier"); + let profile = mock_profile(&server.url(), None); + assert!(is_already_signed_in(&profile)); mock.assert(); - assert_eq!(result, LoginResult::Forbidden); } #[test] - fn exchange_and_save_token_unauthorized() { + fn not_signed_in_when_no_key_no_session() { let (_tmp, _guard) = with_temp_config_dir(); - let mut server = mockito::Server::new(); - - let mock = server - .mock("POST", "/auth/token") - .with_status(401) - .create(); - - let result = exchange_and_save_token(&server.url(), "code", "verifier"); - mock.assert(); - match result { - LoginResult::Failed(msg) => assert!(msg.contains("401")), - other => panic!("expected Failed, got {:?}", other), - } + let profile = mock_profile("http://localhost", None); + assert!(!is_already_signed_in(&profile)); } #[test] - fn exchange_and_save_token_server_error() { + fn not_signed_in_when_session_invalid() { let (_tmp, _guard) = with_temp_config_dir(); + save_test_session("expired-jwt"); let mut server = mockito::Server::new(); - let mock = server - .mock("POST", "/auth/token") - .with_status(500) + .mock("GET", "/workspaces") + .with_status(401) .create(); - let result = exchange_and_save_token(&server.url(), "code", "verifier"); + let profile = mock_profile(&server.url(), None); + assert!(!is_already_signed_in(&profile)); mock.assert(); - match result { - LoginResult::Failed(msg) => assert!(msg.contains("500")), - other => panic!("expected Failed, got {:?}", other), - } - } - - #[test] - fn exchange_and_save_token_connection_error() { - let (_tmp, _guard) = with_temp_config_dir(); - - let result = exchange_and_save_token("http://127.0.0.1:1", "code", "verifier"); - match result { - LoginResult::ConnectionError(_) => {} - other => panic!("expected ConnectionError, got {:?}", other), - } } // --- receive_callback tests --- @@ -662,6 +615,69 @@ mod tests { assert!(result.unwrap_err().contains("state mismatch")); } + // --- cache_workspaces tests --- + + #[test] + fn cache_workspaces_persists_to_config() { + let (_tmp, _guard) = with_temp_config_dir(); + let mut server = mockito::Server::new(); + let m = server + .mock("GET", "/workspaces") + .match_header("Authorization", "Bearer jwt-xyz") + .with_status(200) + .with_header("content-type", "application/json") + .with_body( + r#"{"workspaces":[{"public_id":"ws-1","name":"My WS"},{"public_id":"ws-2","name":"Other"}]}"#, + ) + .create(); + + let profile = mock_profile(&server.url(), None); + let entries = cache_workspaces(&profile, "jwt-xyz").unwrap(); + m.assert(); + assert_eq!(entries.len(), 2); + assert_eq!(entries[0].public_id, "ws-1"); + assert_eq!(entries[0].name, "My WS"); + + // Reload from disk and confirm the cache survived. + let loaded = config::load("default").unwrap(); + assert_eq!(loaded.workspaces.len(), 2); + assert_eq!(loaded.workspaces[1].public_id, "ws-2"); + } + + #[test] + fn cache_workspaces_empty_list() { + let (_tmp, _guard) = with_temp_config_dir(); + let mut server = mockito::Server::new(); + let m = server + .mock("GET", "/workspaces") + .with_status(200) + .with_body(r#"{"workspaces":[]}"#) + .create(); + + let profile = mock_profile(&server.url(), None); + let entries = cache_workspaces(&profile, "jwt").unwrap(); + m.assert(); + assert!(entries.is_empty()); + } + + #[test] + fn cache_workspaces_http_error() { + let (_tmp, _guard) = with_temp_config_dir(); + let mut server = mockito::Server::new(); + let m = server.mock("GET", "/workspaces").with_status(500).create(); + let profile = mock_profile(&server.url(), None); + let err = cache_workspaces(&profile, "jwt").unwrap_err(); + m.assert(); + assert!(err.contains("500"), "got: {err}"); + } + + #[test] + fn cache_workspaces_connection_error() { + let (_tmp, _guard) = with_temp_config_dir(); + let profile = mock_profile("http://127.0.0.1:1", None); + assert!(cache_workspaces(&profile, "jwt").is_err()); + } + #[test] fn receive_callback_no_code() { let server = tiny_http::Server::http("127.0.0.1:0").unwrap(); diff --git a/src/config.rs b/src/config.rs index d4355f9..421d275 100644 --- a/src/config.rs +++ b/src/config.rs @@ -31,7 +31,7 @@ pub struct WorkspaceEntry { } #[derive(Debug, Clone, Serialize)] -pub struct AppUrl(Option); +pub struct AppUrl(pub(crate) Option); impl Default for AppUrl { fn default() -> Self { @@ -98,6 +98,10 @@ impl<'de> Deserialize<'de> for ApiUrl { #[derive(Debug, Clone, Default, Deserialize, Serialize)] pub struct ProfileConfig { + // Transient only: populated from `--api-key` and `HOTDATA_API_KEY`, + // never persisted to or read from YAML. Auth state on disk lives + // entirely in session.json. + #[serde(skip)] pub api_key: Option, #[serde(skip)] pub api_url: ApiUrl, @@ -123,32 +127,10 @@ fn write_config(config_path: &std::path::Path, content: &str) -> Result<(), Stri fs::write(config_path, content).map_err(|e| format!("error writing config file: {e}")) } -pub fn save_api_key(profile: &str, api_key: &str) -> Result<(), String> { - let config_path = config_path()?; - - let mut config_file: ConfigFile = if config_path.exists() { - let content = fs::read_to_string(&config_path) - .map_err(|e| format!("error reading config file: {e}"))?; - serde_yaml::from_str(&content).map_err(|e| format!("error parsing config file: {e}"))? - } else { - ConfigFile { - profiles: HashMap::new(), - } - }; - - config_file - .profiles - .entry(profile.to_string()) - .or_default() - .api_key = Some(api_key.to_string()); - - let content = serde_yaml::to_string(&config_file) - .map_err(|e| format!("error serializing config: {e}"))?; - - write_config(&config_path, &content) -} - -pub fn remove_api_key(profile: &str) -> Result<(), String> { +/// Wipe the workspace cache for a profile. Paired with +/// `jwt::clear_session()` in `auth::logout` — together they reset the +/// on-disk state that login populates. +pub fn clear_workspaces(profile: &str) -> Result<(), String> { let config_path = config_path()?; if !config_path.exists() { @@ -161,7 +143,6 @@ pub fn remove_api_key(profile: &str) -> Result<(), String> { serde_yaml::from_str(&content).map_err(|e| format!("error parsing config file: {e}"))?; if let Some(entry) = config_file.profiles.get_mut(profile) { - entry.api_key = None; entry.workspaces.clear(); } @@ -339,66 +320,42 @@ mod tests { use super::*; use super::test_helpers::with_temp_config_dir; - #[test] - fn save_and_load_api_key() { - let (_tmp, _guard) = with_temp_config_dir(); - - save_api_key("default", "test-key-123").unwrap(); - let profile = load("default").unwrap(); - assert_eq!(profile.api_key, Some("test-key-123".to_string())); + fn ws(id: &str, name: &str) -> WorkspaceEntry { + WorkspaceEntry { public_id: id.into(), name: name.into() } } #[test] - fn save_api_key_creates_config_dir() { + fn save_workspaces_creates_config_dir() { let (_tmp, _guard) = with_temp_config_dir(); - // Config file shouldn't exist yet let path = config_path().unwrap(); assert!(!path.exists()); - save_api_key("default", "key").unwrap(); + save_workspaces("default", vec![ws("ws-1", "WS")]).unwrap(); assert!(path.exists()); } #[test] - fn remove_api_key_clears_key_and_workspaces() { + fn clear_workspaces_empties_the_list() { let (_tmp, _guard) = with_temp_config_dir(); + save_workspaces("default", vec![ws("ws-1", "Test WS")]).unwrap(); - save_api_key("default", "key-to-remove").unwrap(); - save_workspaces( - "default", - vec![WorkspaceEntry { - public_id: "ws-1".into(), - name: "Test WS".into(), - }], - ) - .unwrap(); - - remove_api_key("default").unwrap(); + clear_workspaces("default").unwrap(); let profile = load("default").unwrap(); - assert_eq!(profile.api_key, None); assert!(profile.workspaces.is_empty()); } #[test] - fn remove_api_key_noop_when_no_config() { + fn clear_workspaces_noop_when_no_config() { let (_tmp, _guard) = with_temp_config_dir(); - - // Should not error when config file doesn't exist - assert!(remove_api_key("default").is_ok()); + assert!(clear_workspaces("default").is_ok()); } #[test] fn save_and_load_workspaces() { let (_tmp, _guard) = with_temp_config_dir(); - - save_api_key("default", "key").unwrap(); - let workspaces = vec![ - WorkspaceEntry { public_id: "ws-1".into(), name: "First".into() }, - WorkspaceEntry { public_id: "ws-2".into(), name: "Second".into() }, - ]; - save_workspaces("default", workspaces).unwrap(); + save_workspaces("default", vec![ws("ws-1", "First"), ws("ws-2", "Second")]).unwrap(); let profile = load("default").unwrap(); assert_eq!(profile.workspaces.len(), 2); @@ -409,20 +366,10 @@ mod tests { #[test] fn save_default_workspace_moves_to_front() { let (_tmp, _guard) = with_temp_config_dir(); - - save_api_key("default", "key").unwrap(); - let workspaces = vec![ - WorkspaceEntry { public_id: "ws-1".into(), name: "First".into() }, - WorkspaceEntry { public_id: "ws-2".into(), name: "Second".into() }, - ]; - save_workspaces("default", workspaces).unwrap(); + save_workspaces("default", vec![ws("ws-1", "First"), ws("ws-2", "Second")]).unwrap(); // Set ws-2 as default — should move to front - save_default_workspace( - "default", - WorkspaceEntry { public_id: "ws-2".into(), name: "Second".into() }, - ) - .unwrap(); + save_default_workspace("default", ws("ws-2", "Second")).unwrap(); let profile = load("default").unwrap(); assert_eq!(profile.workspaces[0].public_id, "ws-2"); @@ -432,8 +379,7 @@ mod tests { #[test] fn load_missing_profile_returns_default() { let (_tmp, _guard) = with_temp_config_dir(); - - save_api_key("default", "key").unwrap(); + save_workspaces("default", vec![ws("ws-1", "WS")]).unwrap(); let profile = load("nonexistent").unwrap(); assert_eq!(profile.api_key, None); @@ -449,16 +395,49 @@ mod tests { } #[test] - fn multiple_profiles() { + fn multiple_profiles_keep_independent_workspaces() { let (_tmp, _guard) = with_temp_config_dir(); - - save_api_key("default", "key-default").unwrap(); - save_api_key("staging", "key-staging").unwrap(); + save_workspaces("default", vec![ws("ws-default", "Default WS")]).unwrap(); + save_workspaces("staging", vec![ws("ws-staging", "Staging WS")]).unwrap(); let default = load("default").unwrap(); let staging = load("staging").unwrap(); - assert_eq!(default.api_key, Some("key-default".to_string())); - assert_eq!(staging.api_key, Some("key-staging".to_string())); + assert_eq!(default.workspaces[0].public_id, "ws-default"); + assert_eq!(staging.workspaces[0].public_id, "ws-staging"); + } + + #[test] + fn legacy_api_key_in_yaml_is_ignored_on_load() { + // Older configs (pre-jwt-branch) had `api_key: hd_xxx` written + // to disk. After the migration, the api_key field is purely + // transient — `#[serde(skip)]` must drop any value present in + // YAML on load. This pins down the migration behavior so a + // stale entry can't silently reappear in profile.api_key. + let (_tmp, _guard) = with_temp_config_dir(); + let path = config_path().unwrap(); + fs::create_dir_all(path.parent().unwrap()).unwrap(); + fs::write( + &path, + "profiles:\n default:\n api_key: legacy-hd-token\n", + ) + .unwrap(); + + let profile = load("default").unwrap(); + assert_eq!(profile.api_key, None); + } + + #[test] + fn save_does_not_persist_transient_api_key() { + // Even if api_key was set in-memory (e.g. via env var), saving + // workspaces must NOT round-trip api_key into YAML. + let (_tmp, _guard) = with_temp_config_dir(); + save_workspaces("default", vec![ws("ws-1", "WS")]).unwrap(); + + let yaml = fs::read_to_string(config_path().unwrap()).unwrap(); + assert!( + !yaml.contains("api_key"), + "api_key must not appear in YAML, got:\n{yaml}" + ); } #[test] diff --git a/src/embedding.rs b/src/embedding.rs index 11ae4d4..1149f90 100644 --- a/src/embedding.rs +++ b/src/embedding.rs @@ -76,32 +76,28 @@ pub fn openai_embed(text: &str, model: &str) -> Vec { }); let client = reqwest::blocking::Client::new(); - let resp = match client + let req = client .post("https://api.openai.com/v1/embeddings") .header("Authorization", format!("Bearer {api_key}")) - .header("Content-Type", "application/json") - .json(&body) - .send() - { - Ok(r) => r, + .json(&body); + let (status, resp_body) = match crate::util::send_debug(&client, req, Some(&body)) { + Ok(pair) => pair, Err(e) => { eprintln!("error connecting to OpenAI API: {e}"); std::process::exit(1); } }; - if !resp.status().is_success() { - let status = resp.status(); - let body = resp.text().unwrap_or_default(); - let message = serde_json::from_str::(&body) + if !status.is_success() { + let message = serde_json::from_str::(&resp_body) .ok() .and_then(|v| v["error"]["message"].as_str().map(str::to_string)) - .unwrap_or(body); + .unwrap_or(resp_body); eprintln!("error from OpenAI API ({status}): {message}"); std::process::exit(1); } - let parsed: Value = match resp.json() { + let parsed: Value = match serde_json::from_str(&resp_body) { Ok(v) => v, Err(e) => { eprintln!("error parsing OpenAI response: {e}"); diff --git a/src/jwt.rs b/src/jwt.rs new file mode 100644 index 0000000..d6c46c3 --- /dev/null +++ b/src/jwt.rs @@ -0,0 +1,939 @@ +//! JWT session management for the CLI. +//! +//! A *session* is the `{access_token, refresh_token}` pair returned by +//! `/o/token/`. Access tokens are short-lived (5 min); refresh tokens +//! last 7 days for PKCE-origin sessions, 36 h for api-token-origin. +//! +//! The session is cached in `~/.hotdata/session.json` (mode 0600). +//! Before every API call, [`ensure_access_token`] decides what to do: +//! +//! | Cached state | Action | +//! |---|---| +//! | Access token valid for > 30 s | return it directly | +//! | Access expiring or expired, refresh token valid | call `/o/token/` with `grant_type=refresh_token` | +//! | Refresh token dead, `api_key` present | re-mint via `grant_type=api_token` | +//! | Refresh token dead, no `api_key` | return an error — user must `hotdata auth` again | +//! +//! The raw `hd_...` API token (flow 3 in the design doc) is *never* +//! persisted to the session file — it stays in the main config or the +//! `HOTDATA_API_KEY` env var and is only used transiently to mint. + +use crate::config; +use crate::util; +use serde::{Deserialize, Serialize}; +use std::fs; +use std::io::Write; +use std::path::PathBuf; +use std::time::{SystemTime, UNIX_EPOCH}; + +const CLIENT_ID: &str = "hotdata-cli"; +/// Refresh early so callers don't race an expiring token. +const REFRESH_LEEWAY_SECONDS: u64 = 30; + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct Session { + pub access_token: String, + /// Unix timestamp when `access_token` expires. + pub access_expires_at: u64, + pub refresh_token: String, + /// Unix timestamp when `refresh_token` hits its absolute TTL. Not + /// precisely enforced client-side (server will reject); stored as + /// a soft hint so we know when to skip the refresh attempt and go + /// straight to the re-mint path. + pub refresh_expires_at: u64, + /// How this session was originally minted. Informational. + #[serde(default)] + pub source: String, +} + +/// Path to the session cache file. Returns `None` if the home +/// directory can't be resolved — in which case we operate without +/// caching. +pub fn session_path() -> Option { + config::config_dir().ok().map(|d| d.join("session.json")) +} + +pub fn load_session() -> Option { + let path = session_path()?; + let raw = fs::read_to_string(&path).ok()?; + serde_json::from_str(&raw).ok() +} + +pub fn save_session(session: &Session) -> Result<(), String> { + let path = session_path().ok_or_else(|| "no session path available".to_string())?; + if let Some(parent) = path.parent() { + fs::create_dir_all(parent).map_err(|e| format!("mkdir failed: {e}"))?; + } + let json = serde_json::to_string_pretty(session).map_err(|e| format!("serialize failed: {e}"))?; + + // mode 0600 — session file contains a refresh token, treat it like a + // credential on disk. + use std::os::unix::fs::OpenOptionsExt; + let mut f = fs::OpenOptions::new() + .create(true) + .write(true) + .truncate(true) + .mode(0o600) + .open(&path) + .map_err(|e| format!("open failed: {e}"))?; + f.write_all(json.as_bytes()) + .map_err(|e| format!("write failed: {e}"))?; + Ok(()) +} + +pub fn clear_session() { + if let Some(path) = session_path() { + let _ = fs::remove_file(path); + } +} + +fn now_unix() -> u64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .map(|d| d.as_secs()) + .unwrap_or(0) +} + +#[derive(Deserialize)] +struct TokenResponse { + access_token: String, + expires_in: u64, + refresh_token: Option, +} + +fn session_from_response(resp: TokenResponse, fallback_refresh: Option, source: &str) -> Session { + let refresh_token = resp.refresh_token.or(fallback_refresh).unwrap_or_default(); + // We don't know the exact refresh TTL server-side (7 d or 36 h + // depending on origin). Store a conservative estimate so we don't + // refresh-attempt with a known-dead token; server enforces the + // real deadline. + let refresh_ttl = if source == "api_token" { + 36 * 60 * 60 + } else { + 7 * 24 * 60 * 60 + }; + Session { + access_token: resp.access_token, + access_expires_at: now_unix() + resp.expires_in, + refresh_token, + refresh_expires_at: now_unix() + refresh_ttl, + source: source.to_string(), + } +} + +fn oauth_base(profile: &config::ProfileConfig) -> String { + // DOT (`/o/authorize/`, `/o/token/`, …) is mounted on the webapp + // (app_url), not the API. The api_url host typically only serves + // the `/v1` runtimedb routes. + profile.app_url.to_string().trim_end_matches('/').to_string() +} + +/// Build a redacted JSON view of a form body for `--debug` printing. +/// `util::send_debug` takes the printable body separately from the +/// wire body, so we hand it this masked view while the actual `.form()` +/// payload sends real values. +fn redacted_form_body(params: &[(&str, &str)]) -> serde_json::Value { + let masked: serde_json::Map = params + .iter() + .map(|(k, v)| { + let display = match *k { + "code" | "code_verifier" | "api_token" | "refresh_token" => { + util::mask_credential(v) + } + _ => v.to_string(), + }; + (k.to_string(), serde_json::Value::String(display)) + }) + .collect(); + serde_json::Value::Object(masked) +} + +/// Token-endpoint responses contain the access + refresh JWTs in +/// plaintext. Mask both before printing, but return the unredacted +/// body so the caller can still parse real values out of it. +const TOKEN_REDACT_KEYS: &[&str] = &["access_token", "refresh_token"]; + +/// Exchange a PKCE authorization code for a session. +pub fn mint_from_pkce_code( + profile: &config::ProfileConfig, + code: &str, + code_verifier: &str, + redirect_uri: &str, +) -> Result { + let url = format!("{}/o/token/", oauth_base(profile)); + let params = [ + ("grant_type", "authorization_code"), + ("code", code), + ("code_verifier", code_verifier), + ("redirect_uri", redirect_uri), + ("client_id", CLIENT_ID), + ]; + + let client = reqwest::blocking::Client::new(); + let req = client.post(&url).form(¶ms); + let body_log = redacted_form_body(¶ms); + let (status, body_text) = util::send_debug_with_redaction( + &client, + req, + Some(&body_log), + TOKEN_REDACT_KEYS, + ) + .map_err(|e| format!("connection error: {e}"))?; + if !status.is_success() { + return Err(format!("token exchange failed: HTTP {status}: {body_text}")); + } + let body: TokenResponse = serde_json::from_str(&body_text) + .map_err(|e| format!("malformed token response: {e}"))?; + Ok(session_from_response(body, None, "pkce")) +} + +/// Exchange an opaque API token for a session. +pub fn mint_from_api_token( + profile: &config::ProfileConfig, + api_token: &str, +) -> Result { + let url = format!("{}/o/token/", oauth_base(profile)); + let params = [ + ("grant_type", "api_token"), + ("api_token", api_token), + ("client_id", CLIENT_ID), + ]; + + let client = reqwest::blocking::Client::new(); + let req = client.post(&url).form(¶ms); + let body_log = redacted_form_body(¶ms); + let (status, body_text) = util::send_debug_with_redaction( + &client, + req, + Some(&body_log), + TOKEN_REDACT_KEYS, + ) + .map_err(|e| format!("connection error: {e}"))?; + if !status.is_success() { + return Err(format!("api_token exchange failed: HTTP {status}: {body_text}")); + } + let body: TokenResponse = serde_json::from_str(&body_text) + .map_err(|e| format!("malformed token response: {e}"))?; + Ok(session_from_response(body, None, "api_token")) +} + +/// Refresh an existing session via the refresh-token grant. +pub fn refresh(profile: &config::ProfileConfig, session: &Session) -> Result { + let url = format!("{}/o/token/", oauth_base(profile)); + let params = [ + ("grant_type", "refresh_token"), + ("refresh_token", session.refresh_token.as_str()), + ("client_id", CLIENT_ID), + ]; + + let client = reqwest::blocking::Client::new(); + let req = client.post(&url).form(¶ms); + let body_log = redacted_form_body(¶ms); + let (status, body_text) = util::send_debug_with_redaction( + &client, + req, + Some(&body_log), + TOKEN_REDACT_KEYS, + ) + .map_err(|e| format!("connection error: {e}"))?; + if !status.is_success() { + return Err(format!("refresh failed: HTTP {status}: {body_text}")); + } + let body: TokenResponse = serde_json::from_str(&body_text) + .map_err(|e| format!("malformed token response: {e}"))?; + Ok(session_from_response( + body, + // Rotation is off server-side, so the same refresh token + // should come back — but fall back to the old one if the + // server decides to drop it from the response. + Some(session.refresh_token.clone()), + &session.source, + )) +} + +/// Return a valid access token, minting or refreshing as needed. +/// +/// The caller passes in whatever credential they want to fall back on +/// (an `hd_...` API key from `--api-key`, env var, or config). If the +/// cached session is usable it's returned without touching the API; +/// otherwise the session is refreshed/re-minted and persisted. +pub fn ensure_access_token( + profile: &config::ProfileConfig, + api_key_fallback: Option<&str>, +) -> Result { + // 0) An explicit identity override (`--api-key`, `HOTDATA_API_KEY`, + // or `.env`) is asserting a specific identity for *this invocation*. + // The on-disk session may belong to a completely different user + // from a prior `hotdata auth` and must not be reused. Mint fresh + // and deliberately skip persisting so we don't clobber the + // interactive session. Surface the real mint error here too — if + // the override key is bad, "HTTP 401" is more useful than the + // generic "session expired" message the cache-fallthrough returns. + // + // Only `ApiKeySource::Config` continues to honor the cache: that's + // a stable identity persisted in config.yml, paired with a session + // minted from that same identity. + if matches!( + profile.api_key_source, + config::ApiKeySource::Flag | config::ApiKeySource::Env + ) { + if let Some(api_key) = api_key_fallback { + let session = mint_from_api_token(profile, api_key)?; + return Ok(session.access_token); + } + } + + let now = now_unix(); + + // 1) Cached session is still good. + if let Some(session) = load_session() { + if !session.access_token.is_empty() && now + REFRESH_LEEWAY_SECONDS < session.access_expires_at { + return Ok(session.access_token); + } + + // 2) Access expired but refresh might still work. + if !session.refresh_token.is_empty() && now < session.refresh_expires_at { + match refresh(profile, &session) { + Ok(new_session) => { + let tok = new_session.access_token.clone(); + let _ = save_session(&new_session); + return Ok(tok); + } + Err(_) => { + // Refresh rejected — fall through to re-mint. + clear_session(); + } + } + } + } + + // 3) No cache, or refresh is dead → need a fresh mint. + if let Some(api_key) = api_key_fallback { + match mint_from_api_token(profile, api_key) { + Ok(session) => { + let tok = session.access_token.clone(); + save_session(&session)?; + return Ok(tok); + } + Err(_) => { + // API token rejected (revoked, expired, or invalid). + // Fall through to the re-auth hint — hide the raw HTTP + // error from the user; the api.rs caller appends a + // `hotdata auth` hint. + } + } + } + + Err("session expired or revoked".into()) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::config::{ApiUrl, AppUrl, ProfileConfig, test_helpers::with_temp_config_dir}; + + fn mock_profile(url: &str) -> ProfileConfig { + ProfileConfig { + app_url: AppUrl(Some(url.to_string())), + api_url: ApiUrl(Some(url.to_string())), + ..Default::default() + } + } + + fn cached_session(access_offset: i64, refresh_offset: i64) -> Session { + let now = now_unix() as i64; + Session { + access_token: "cached-jwt".into(), + access_expires_at: (now + access_offset).max(0) as u64, + refresh_token: "cached-refresh".into(), + refresh_expires_at: (now + refresh_offset).max(0) as u64, + source: "pkce".into(), + } + } + + // --- session persistence --- + + #[test] + fn session_round_trip() { + let (_tmp, _guard) = with_temp_config_dir(); + let s = Session { + access_token: "a".into(), + access_expires_at: 100, + refresh_token: "r".into(), + refresh_expires_at: 200, + source: "pkce".into(), + }; + save_session(&s).unwrap(); + let loaded = load_session().unwrap(); + assert_eq!(loaded.access_token, "a"); + assert_eq!(loaded.access_expires_at, 100); + assert_eq!(loaded.refresh_token, "r"); + assert_eq!(loaded.refresh_expires_at, 200); + assert_eq!(loaded.source, "pkce"); + } + + #[test] + fn session_file_is_mode_0600() { + use std::os::unix::fs::PermissionsExt; + let (_tmp, _guard) = with_temp_config_dir(); + save_session(&Session::default()).unwrap(); + let path = session_path().unwrap(); + let mode = fs::metadata(&path).unwrap().permissions().mode() & 0o777; + assert_eq!(mode, 0o600, "session file must be 0600 (contains refresh token)"); + } + + #[test] + fn load_session_returns_none_when_missing() { + let (_tmp, _guard) = with_temp_config_dir(); + assert!(load_session().is_none()); + } + + #[test] + fn load_session_returns_none_when_corrupt() { + let (_tmp, _guard) = with_temp_config_dir(); + let path = session_path().unwrap(); + fs::create_dir_all(path.parent().unwrap()).unwrap(); + fs::write(&path, "not json").unwrap(); + assert!(load_session().is_none()); + } + + #[test] + fn clear_session_removes_file() { + let (_tmp, _guard) = with_temp_config_dir(); + save_session(&Session::default()).unwrap(); + assert!(load_session().is_some()); + clear_session(); + assert!(load_session().is_none()); + // Idempotent — clearing again is a no-op. + clear_session(); + } + + // --- mint_from_pkce_code --- + + #[test] + fn mint_from_pkce_code_success() { + let mut server = mockito::Server::new(); + let m = server + .mock("POST", "/o/token/") + .match_body(mockito::Matcher::AllOf(vec![ + mockito::Matcher::UrlEncoded("grant_type".into(), "authorization_code".into()), + mockito::Matcher::UrlEncoded("code".into(), "auth-code".into()), + mockito::Matcher::UrlEncoded("code_verifier".into(), "verifier".into()), + mockito::Matcher::UrlEncoded( + "redirect_uri".into(), + "http://127.0.0.1:1234/".into(), + ), + mockito::Matcher::UrlEncoded("client_id".into(), "hotdata-cli".into()), + ])) + .with_status(200) + .with_header("content-type", "application/json") + .with_body( + r#"{"access_token":"jwt-abc","expires_in":300,"refresh_token":"refresh-xyz"}"#, + ) + .create(); + + let profile = mock_profile(&server.url()); + let session = + mint_from_pkce_code(&profile, "auth-code", "verifier", "http://127.0.0.1:1234/") + .unwrap(); + m.assert(); + assert_eq!(session.access_token, "jwt-abc"); + assert_eq!(session.refresh_token, "refresh-xyz"); + assert_eq!(session.source, "pkce"); + assert!(session.access_expires_at > now_unix()); + // PKCE-origin sessions get the 7-day refresh TTL hint. + let ttl = session.refresh_expires_at.saturating_sub(now_unix()); + assert!(ttl >= 7 * 24 * 60 * 60 - 5 && ttl <= 7 * 24 * 60 * 60 + 5); + } + + #[test] + fn mint_from_pkce_code_trims_trailing_slash_in_app_url() { + let mut server = mockito::Server::new(); + let m = server + .mock("POST", "/o/token/") + .with_status(200) + .with_header("content-type", "application/json") + .with_body(r#"{"access_token":"a","expires_in":1,"refresh_token":"r"}"#) + .create(); + + // Append a trailing slash — oauth_base must strip it so we don't + // end up POSTing to `//o/token/`. + let url = format!("{}/", server.url()); + let profile = mock_profile(&url); + mint_from_pkce_code(&profile, "c", "v", "uri").unwrap(); + m.assert(); + } + + #[test] + fn mint_from_pkce_code_http_error_includes_status_and_body() { + let mut server = mockito::Server::new(); + let m = server + .mock("POST", "/o/token/") + .with_status(403) + .with_body("forbidden by policy") + .create(); + + let profile = mock_profile(&server.url()); + let err = mint_from_pkce_code(&profile, "c", "v", "uri").unwrap_err(); + m.assert(); + assert!(err.contains("403"), "got: {err}"); + assert!(err.contains("forbidden by policy"), "got: {err}"); + } + + #[test] + fn mint_from_pkce_code_malformed_response() { + let mut server = mockito::Server::new(); + let m = server + .mock("POST", "/o/token/") + .with_status(200) + .with_body("not json") + .create(); + + let profile = mock_profile(&server.url()); + let err = mint_from_pkce_code(&profile, "c", "v", "uri").unwrap_err(); + m.assert(); + assert!(err.contains("malformed"), "got: {err}"); + } + + #[test] + fn mint_from_pkce_code_connection_error() { + let profile = mock_profile("http://127.0.0.1:1"); + let err = mint_from_pkce_code(&profile, "c", "v", "uri").unwrap_err(); + assert!(err.contains("connection"), "got: {err}"); + } + + // --- mint_from_api_token --- + + #[test] + fn mint_from_api_token_success_uses_36h_refresh_ttl() { + let mut server = mockito::Server::new(); + let m = server + .mock("POST", "/o/token/") + .match_body(mockito::Matcher::AllOf(vec![ + mockito::Matcher::UrlEncoded("grant_type".into(), "api_token".into()), + mockito::Matcher::UrlEncoded("api_token".into(), "hd_xyz".into()), + mockito::Matcher::UrlEncoded("client_id".into(), "hotdata-cli".into()), + ])) + .with_status(200) + .with_header("content-type", "application/json") + .with_body(r#"{"access_token":"jwt-1","expires_in":300,"refresh_token":"r1"}"#) + .create(); + + let profile = mock_profile(&server.url()); + let session = mint_from_api_token(&profile, "hd_xyz").unwrap(); + m.assert(); + assert_eq!(session.access_token, "jwt-1"); + assert_eq!(session.refresh_token, "r1"); + assert_eq!(session.source, "api_token"); + // api_token-origin sessions get the shorter 36h refresh TTL hint. + let ttl = session.refresh_expires_at.saturating_sub(now_unix()); + assert!(ttl >= 36 * 60 * 60 - 5 && ttl <= 36 * 60 * 60 + 5); + } + + #[test] + fn mint_from_api_token_http_error() { + let mut server = mockito::Server::new(); + let m = server.mock("POST", "/o/token/").with_status(401).create(); + + let profile = mock_profile(&server.url()); + let err = mint_from_api_token(&profile, "bad-key").unwrap_err(); + m.assert(); + assert!(err.contains("401"), "got: {err}"); + } + + // --- refresh --- + + #[test] + fn refresh_keeps_old_refresh_token_when_server_omits_it() { + // Rotation-off case: server returns no refresh_token, and we + // must carry the old one forward so the next refresh works. + let mut server = mockito::Server::new(); + let m = server + .mock("POST", "/o/token/") + .match_body(mockito::Matcher::AllOf(vec![ + mockito::Matcher::UrlEncoded("grant_type".into(), "refresh_token".into()), + mockito::Matcher::UrlEncoded("refresh_token".into(), "stable-refresh".into()), + ])) + .with_status(200) + .with_header("content-type", "application/json") + .with_body(r#"{"access_token":"new-jwt","expires_in":300}"#) + .create(); + + let profile = mock_profile(&server.url()); + let session = Session { + refresh_token: "stable-refresh".into(), + source: "pkce".into(), + ..Default::default() + }; + let new_session = refresh(&profile, &session).unwrap(); + m.assert(); + assert_eq!(new_session.access_token, "new-jwt"); + assert_eq!(new_session.refresh_token, "stable-refresh"); + // Source is carried over from the original session. + assert_eq!(new_session.source, "pkce"); + } + + #[test] + fn refresh_uses_rotated_refresh_token_when_server_returns_one() { + let mut server = mockito::Server::new(); + let m = server + .mock("POST", "/o/token/") + .with_status(200) + .with_header("content-type", "application/json") + .with_body( + r#"{"access_token":"new-jwt","expires_in":300,"refresh_token":"rotated"}"#, + ) + .create(); + + let profile = mock_profile(&server.url()); + let session = Session { + refresh_token: "old".into(), + source: "api_token".into(), + ..Default::default() + }; + let new_session = refresh(&profile, &session).unwrap(); + m.assert(); + assert_eq!(new_session.refresh_token, "rotated"); + assert_eq!(new_session.source, "api_token"); + } + + #[test] + fn refresh_http_error() { + let mut server = mockito::Server::new(); + let m = server.mock("POST", "/o/token/").with_status(400).create(); + + let profile = mock_profile(&server.url()); + let session = Session { + refresh_token: "x".into(), + ..Default::default() + }; + let err = refresh(&profile, &session).unwrap_err(); + m.assert(); + assert!(err.contains("400"), "got: {err}"); + } + + // --- ensure_access_token: each branch of the decision table --- + + #[test] + fn ensure_returns_cached_token_without_http_when_valid() { + let (_tmp, _guard) = with_temp_config_dir(); + // 10 min into the future, well past REFRESH_LEEWAY_SECONDS. + save_session(&cached_session(600, 7 * 24 * 3600)).unwrap(); + + // Profile points at a port that's not listening — if the code + // reached out to the network this would surface as an error. + let profile = mock_profile("http://127.0.0.1:1"); + let token = ensure_access_token(&profile, None).unwrap(); + assert_eq!(token, "cached-jwt"); + } + + #[test] + fn ensure_refreshes_when_inside_leeway_window() { + // Token still has a few seconds left but is inside the 30s + // leeway, so the orchestrator should refresh proactively. + let (_tmp, _guard) = with_temp_config_dir(); + let mut server = mockito::Server::new(); + let m = server + .mock("POST", "/o/token/") + .match_body(mockito::Matcher::UrlEncoded( + "grant_type".into(), + "refresh_token".into(), + )) + .with_status(200) + .with_header("content-type", "application/json") + .with_body(r#"{"access_token":"refreshed-jwt","expires_in":300}"#) + .create(); + + save_session(&cached_session(5, 86400)).unwrap(); + let profile = mock_profile(&server.url()); + let token = ensure_access_token(&profile, None).unwrap(); + m.assert(); + assert_eq!(token, "refreshed-jwt"); + // New session was persisted to disk. + assert_eq!(load_session().unwrap().access_token, "refreshed-jwt"); + } + + #[test] + fn ensure_refreshes_when_access_expired() { + let (_tmp, _guard) = with_temp_config_dir(); + let mut server = mockito::Server::new(); + let m = server + .mock("POST", "/o/token/") + .match_body(mockito::Matcher::UrlEncoded( + "grant_type".into(), + "refresh_token".into(), + )) + .with_status(200) + .with_header("content-type", "application/json") + .with_body(r#"{"access_token":"refreshed-jwt","expires_in":300}"#) + .create(); + + save_session(&cached_session(-10, 86400)).unwrap(); + let profile = mock_profile(&server.url()); + let token = ensure_access_token(&profile, None).unwrap(); + m.assert(); + assert_eq!(token, "refreshed-jwt"); + } + + #[test] + fn ensure_falls_back_to_api_token_mint_when_refresh_rejected() { + let (_tmp, _guard) = with_temp_config_dir(); + let mut server = mockito::Server::new(); + let refresh_mock = server + .mock("POST", "/o/token/") + .match_body(mockito::Matcher::UrlEncoded( + "grant_type".into(), + "refresh_token".into(), + )) + .with_status(400) + .with_body("invalid_grant") + .create(); + let mint_mock = server + .mock("POST", "/o/token/") + .match_body(mockito::Matcher::UrlEncoded( + "grant_type".into(), + "api_token".into(), + )) + .with_status(200) + .with_header("content-type", "application/json") + .with_body(r#"{"access_token":"reminted-jwt","expires_in":300,"refresh_token":"r2"}"#) + .create(); + + save_session(&cached_session(-10, 86400)).unwrap(); + let profile = mock_profile(&server.url()); + let token = ensure_access_token(&profile, Some("hd_xyz")).unwrap(); + refresh_mock.assert(); + mint_mock.assert(); + assert_eq!(token, "reminted-jwt"); + let loaded = load_session().unwrap(); + assert_eq!(loaded.access_token, "reminted-jwt"); + assert_eq!(loaded.source, "api_token"); + } + + #[test] + fn ensure_mints_from_api_token_when_no_session() { + let (_tmp, _guard) = with_temp_config_dir(); + let mut server = mockito::Server::new(); + let m = server + .mock("POST", "/o/token/") + .match_body(mockito::Matcher::UrlEncoded( + "grant_type".into(), + "api_token".into(), + )) + .with_status(200) + .with_header("content-type", "application/json") + .with_body(r#"{"access_token":"fresh-jwt","expires_in":300,"refresh_token":"r"}"#) + .create(); + + let profile = mock_profile(&server.url()); + let token = ensure_access_token(&profile, Some("hd_xyz")).unwrap(); + m.assert(); + assert_eq!(token, "fresh-jwt"); + assert_eq!(load_session().unwrap().access_token, "fresh-jwt"); + } + + #[test] + fn ensure_skips_refresh_when_refresh_ttl_expired() { + // Refresh token is past its soft TTL — the orchestrator should + // skip the refresh attempt entirely and go straight to the + // api_token re-mint path. + let (_tmp, _guard) = with_temp_config_dir(); + let mut server = mockito::Server::new(); + let mint_mock = server + .mock("POST", "/o/token/") + .match_body(mockito::Matcher::UrlEncoded( + "grant_type".into(), + "api_token".into(), + )) + .with_status(200) + .with_header("content-type", "application/json") + .with_body(r#"{"access_token":"reminted","expires_in":300,"refresh_token":"r"}"#) + .expect(1) + .create(); + // Refresh path must NOT be hit. + let refresh_mock = server + .mock("POST", "/o/token/") + .match_body(mockito::Matcher::UrlEncoded( + "grant_type".into(), + "refresh_token".into(), + )) + .expect(0) + .create(); + + save_session(&cached_session(-10, -10)).unwrap(); + let profile = mock_profile(&server.url()); + let token = ensure_access_token(&profile, Some("hd_xyz")).unwrap(); + mint_mock.assert(); + refresh_mock.assert(); + assert_eq!(token, "reminted"); + } + + #[test] + fn ensure_errors_when_no_session_and_no_api_key() { + let (_tmp, _guard) = with_temp_config_dir(); + let profile = mock_profile("http://127.0.0.1:1"); + let err = ensure_access_token(&profile, None).unwrap_err(); + assert!(err.contains("session"), "got: {err}"); + } + + #[test] + fn ensure_errors_when_api_token_rejected() { + let (_tmp, _guard) = with_temp_config_dir(); + let mut server = mockito::Server::new(); + let m = server + .mock("POST", "/o/token/") + .match_body(mockito::Matcher::UrlEncoded( + "grant_type".into(), + "api_token".into(), + )) + .with_status(401) + .create(); + + let profile = mock_profile(&server.url()); + let err = ensure_access_token(&profile, Some("revoked")).unwrap_err(); + m.assert(); + // Error is the generic "session expired or revoked" — the raw + // HTTP status is suppressed so api.rs can append a clean + // re-auth hint. + assert!(err.contains("session"), "got: {err}"); + } + + // --- ensure_access_token: --api-key (Flag source) overrides cache --- + + #[test] + fn ensure_with_flag_source_bypasses_valid_cached_session() { + // A perfectly valid PKCE session is on disk, but the user + // passed --api-key — we must mint a fresh JWT from that key + // instead of reusing the cached session. + let (_tmp, _guard) = with_temp_config_dir(); + save_session(&cached_session(3600, 7 * 24 * 3600)).unwrap(); + + let mut server = mockito::Server::new(); + let mint_mock = server + .mock("POST", "/o/token/") + .match_body(mockito::Matcher::AllOf(vec![ + mockito::Matcher::UrlEncoded("grant_type".into(), "api_token".into()), + mockito::Matcher::UrlEncoded("api_token".into(), "hd_flag_token".into()), + ])) + .with_status(200) + .with_header("content-type", "application/json") + .with_body(r#"{"access_token":"flag-jwt","expires_in":300,"refresh_token":"r"}"#) + .create(); + + let mut profile = mock_profile(&server.url()); + profile.api_key_source = config::ApiKeySource::Flag; + let token = ensure_access_token(&profile, Some("hd_flag_token")).unwrap(); + mint_mock.assert(); + assert_eq!(token, "flag-jwt"); + } + + #[test] + fn ensure_with_flag_source_does_not_overwrite_cached_session() { + // Flag-driven mints are for one-shot CLI invocations; persisting + // them would silently log the interactive user out. + let (_tmp, _guard) = with_temp_config_dir(); + let original = cached_session(3600, 7 * 24 * 3600); + save_session(&original).unwrap(); + + let mut server = mockito::Server::new(); + let _mint = server + .mock("POST", "/o/token/") + .with_status(200) + .with_header("content-type", "application/json") + .with_body(r#"{"access_token":"flag-jwt","expires_in":300,"refresh_token":"r"}"#) + .create(); + + let mut profile = mock_profile(&server.url()); + profile.api_key_source = config::ApiKeySource::Flag; + ensure_access_token(&profile, Some("hd_flag_token")).unwrap(); + + // session.json must still hold the original PKCE session. + let after = load_session().unwrap(); + assert_eq!(after.access_token, original.access_token); + assert_eq!(after.refresh_token, original.refresh_token); + } + + #[test] + fn ensure_with_flag_source_surfaces_mint_error() { + // When --api-key was passed explicitly, the user wants the real + // failure reason, not the generic "session expired or revoked" + // message that the cache-fall-through path returns. + let (_tmp, _guard) = with_temp_config_dir(); + let mut server = mockito::Server::new(); + let m = server + .mock("POST", "/o/token/") + .with_status(401) + .with_body("invalid api token") + .create(); + + let mut profile = mock_profile(&server.url()); + profile.api_key_source = config::ApiKeySource::Flag; + let err = ensure_access_token(&profile, Some("bad")).unwrap_err(); + m.assert(); + assert!(err.contains("401"), "got: {err}"); + } + + #[test] + fn ensure_with_env_source_bypasses_valid_cached_session() { + // HOTDATA_API_KEY (whether exported in the shell or loaded + // from .env) must override a cached session for the same + // reason --api-key does: the env var asserts a specific + // identity for this invocation. + let (_tmp, _guard) = with_temp_config_dir(); + let original = cached_session(3600, 7 * 24 * 3600); + save_session(&original).unwrap(); + + let mut server = mockito::Server::new(); + let mint_mock = server + .mock("POST", "/o/token/") + .match_body(mockito::Matcher::AllOf(vec![ + mockito::Matcher::UrlEncoded("grant_type".into(), "api_token".into()), + mockito::Matcher::UrlEncoded("api_token".into(), "hd_env_token".into()), + ])) + .with_status(200) + .with_header("content-type", "application/json") + .with_body(r#"{"access_token":"env-jwt","expires_in":300,"refresh_token":"r"}"#) + .create(); + + let mut profile = mock_profile(&server.url()); + profile.api_key_source = config::ApiKeySource::Env; + let token = ensure_access_token(&profile, Some("hd_env_token")).unwrap(); + mint_mock.assert(); + assert_eq!(token, "env-jwt"); + + // Cached session must remain untouched — same no-clobber + // guarantee as the Flag path. + let after = load_session().unwrap(); + assert_eq!(after.access_token, original.access_token); + } + + #[test] + fn ensure_with_config_source_still_uses_cached_session() { + // Regression guard: api_key_source = Config (the default) must + // continue to short-circuit on a valid cache, not mint. + let (_tmp, _guard) = with_temp_config_dir(); + save_session(&cached_session(3600, 7 * 24 * 3600)).unwrap(); + + let profile = mock_profile("http://127.0.0.1:1"); + // Config source — even with an api_key fallback, the cache wins. + assert_eq!(profile.api_key_source, config::ApiKeySource::Config); + let token = ensure_access_token(&profile, Some("hd_config_key")).unwrap(); + assert_eq!(token, "cached-jwt"); + } + + #[test] + fn ensure_clears_session_when_refresh_dies_with_no_fallback() { + let (_tmp, _guard) = with_temp_config_dir(); + let mut server = mockito::Server::new(); + let m = server.mock("POST", "/o/token/").with_status(400).create(); + + save_session(&cached_session(-10, 86400)).unwrap(); + let profile = mock_profile(&server.url()); + let err = ensure_access_token(&profile, None).unwrap_err(); + m.assert(); + assert!(err.contains("session"), "got: {err}"); + // Stale session must be cleared so the next attempt doesn't + // burn a network call on the same dead refresh token. + assert!(load_session().is_none()); + } +} diff --git a/src/main.rs b/src/main.rs index 4b47770..6b04cd1 100644 --- a/src/main.rs +++ b/src/main.rs @@ -9,6 +9,7 @@ mod datasets; mod embedding; mod indexes; mod jobs; +mod jwt; mod queries; mod query; mod results; diff --git a/src/skill.rs b/src/skill.rs index a3a6d87..7b65a0f 100644 --- a/src/skill.rs +++ b/src/skill.rs @@ -85,6 +85,10 @@ fn download_and_extract() -> Result<(), String> { let url = download_url(); println!("Downloading skill..."); + // Binary download — can't route through `send_debug` (which calls + // `resp.text()` and would corrupt the gzip stream). Log the + // request line manually so `--debug` still shows the URL. + crate::util::debug_request("GET", &url, &[], None); let client = reqwest::blocking::Client::new(); let resp = client .get(&url) diff --git a/src/util.rs b/src/util.rs index 8204f1f..7a74f07 100644 --- a/src/util.rs +++ b/src/util.rs @@ -36,21 +36,35 @@ pub fn debug_request(method: &str, url: &str, headers: &[(&str, &str)], body: Op } } -/// Log response status and body when debug mode is enabled. -/// Consumes the response and returns the status + body text for the caller to parse. -pub fn debug_response(resp: reqwest::blocking::Response) -> (reqwest::StatusCode, String) { +/// Log response status and body when debug mode is enabled. Consumes +/// the response and returns the status + body text for the caller to +/// parse. `redact_keys` masks the named JSON fields in the printed +/// body (last 4 chars only) — pass `&[]` for no redaction. The +/// returned body string is *unredacted* so the caller can still parse +/// real values out of it. +pub fn debug_response_redacted( + resp: reqwest::blocking::Response, + redact_keys: &[&str], +) -> (reqwest::StatusCode, String) { let status = resp.status(); let body = resp.text().unwrap_or_default(); if is_debug() { use crossterm::style::Stylize; - let status_str = format!("<<< {} {}", status.as_u16(), status.canonical_reason().unwrap_or("")); + let status_str = format!( + "<<< {} {}", + status.as_u16(), + status.canonical_reason().unwrap_or("") + ); if status.is_success() { eprintln!("{}", status_str.dark_green()); } else { eprintln!("{}", status_str.dark_red()); } - if let Ok(v) = serde_json::from_str::(&body) { + if let Ok(mut v) = serde_json::from_str::(&body) { + if !redact_keys.is_empty() { + redact_json_fields(&mut v, redact_keys); + } eprintln!("{}", colorize_json(&serde_json::to_string_pretty(&v).unwrap())); } else if !body.is_empty() { eprintln!("{}", body.to_string().dark_grey()); @@ -60,6 +74,118 @@ pub fn debug_response(resp: reqwest::blocking::Response) -> (reqwest::StatusCode (status, body) } +/// Mask a credential to its first 4 characters (`XXXX...`), or `***` +/// if the value is too short to safely reveal a head. +pub fn mask_credential(s: &str) -> String { + if s.len() > 4 { + format!("{}...", &s[..4]) + } else { + "***".into() + } +} + +/// Canonical wrapper for every outgoing HTTP call in the CLI. Builds +/// the request, logs it under `--debug` (with `Authorization` auto- +/// masked), executes, and prints + returns the response. Callers stay +/// minimal: +/// +/// ```ignore +/// let req = client.get(&url).header("Authorization", bearer); +/// let (status, body) = util::send_debug(&client, req, None)?; +/// ``` +/// +/// `body_for_log` is the *printable* form of the request body — pass +/// `None` for GET, the JSON `Value` for `.json(...)` calls, or a hand- +/// rolled redacted `Value` for form bodies. +pub fn send_debug( + client: &reqwest::blocking::Client, + builder: reqwest::blocking::RequestBuilder, + body_for_log: Option<&serde_json::Value>, +) -> reqwest::Result<(reqwest::StatusCode, String)> { + send_debug_with_redaction(client, builder, body_for_log, &[]) +} + +/// Like `send_debug` but masks the named JSON keys in the printed +/// response body. The returned body string is always unredacted. +pub fn send_debug_with_redaction( + client: &reqwest::blocking::Client, + builder: reqwest::blocking::RequestBuilder, + body_for_log: Option<&serde_json::Value>, + response_redact_keys: &[&str], +) -> reqwest::Result<(reqwest::StatusCode, String)> { + let request = builder.build()?; + if is_debug() { + log_request_struct(&request, body_for_log); + } + let resp = client.execute(request)?; + Ok(debug_response_redacted(resp, response_redact_keys)) +} + +fn log_request_struct( + req: &reqwest::blocking::Request, + body: Option<&serde_json::Value>, +) { + let method = req.method().as_str(); + let url = req.url().as_str(); + // Materialize masked header pairs as owned strings, then re-borrow + // for `debug_request` (which takes &[(&str, &str)]). + let pairs: Vec<(String, String)> = req + .headers() + .iter() + .filter_map(|(k, v)| { + v.to_str().ok().map(|s| { + let key = k.as_str(); + let val = if key.eq_ignore_ascii_case("authorization") { + mask_auth_value(s) + } else { + s.to_string() + }; + (key.to_string(), val) + }) + }) + .collect(); + let refs: Vec<(&str, &str)> = pairs + .iter() + .map(|(k, v)| (k.as_str(), v.as_str())) + .collect(); + debug_request(method, url, &refs, body); +} + +/// Mask an `Authorization` header value. Preserves the scheme prefix +/// (`Bearer`, `Basic`, …) so the log still makes sense. +fn mask_auth_value(value: &str) -> String { + if let Some(token) = value.strip_prefix("Bearer ") { + format!("Bearer {}", mask_credential(token)) + } else { + mask_credential(value) + } +} + +/// Walk a JSON value and replace string values under any of the named +/// keys with their masked form. Recurses into nested objects/arrays so +/// callers don't have to know the shape of the response. +fn redact_json_fields(v: &mut serde_json::Value, keys: &[&str]) { + match v { + serde_json::Value::Object(map) => { + for (k, val) in map.iter_mut() { + if keys.contains(&k.as_str()) { + if let Some(s) = val.as_str() { + *val = serde_json::Value::String(mask_credential(s)); + } + } else { + redact_json_fields(val, keys); + } + } + } + serde_json::Value::Array(arr) => { + for item in arr.iter_mut() { + redact_json_fields(item, keys); + } + } + _ => {} + } +} + /// Colorize a pretty-printed JSON string for terminal output. fn colorize_json(json: &str) -> String { use crossterm::style::Stylize; @@ -151,6 +277,70 @@ pub fn format_date(s: &str) -> String { s.chars().take(16).collect() } +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + #[test] + fn mask_credential_long() { + assert_eq!(mask_credential("abcdefgh"), "abcd..."); + } + + #[test] + fn mask_credential_short() { + assert_eq!(mask_credential("abcd"), "***"); + assert_eq!(mask_credential(""), "***"); + } + + #[test] + fn redact_json_fields_top_level() { + let mut v = json!({ + "access_token": "long-secret-token", + "expires_in": 300, + "refresh_token": "another-secret" + }); + redact_json_fields(&mut v, &["access_token", "refresh_token"]); + assert_eq!(v["access_token"], "long..."); + assert_eq!(v["refresh_token"], "anot..."); + // Non-redacted keys untouched. + assert_eq!(v["expires_in"], 300); + } + + #[test] + fn redact_json_fields_recurses_into_nested_objects_and_arrays() { + let mut v = json!({ + "data": { + "access_token": "secret-1234", + "items": [ + {"access_token": "nested-secret"} + ] + } + }); + redact_json_fields(&mut v, &["access_token"]); + assert_eq!(v["data"]["access_token"], "secr..."); + assert_eq!(v["data"]["items"][0]["access_token"], "nest..."); + } + + #[test] + fn redact_json_fields_no_match_is_noop() { + let mut v = json!({"foo": "bar"}); + let original = v.clone(); + redact_json_fields(&mut v, &["access_token"]); + assert_eq!(v, original); + } + + #[test] + fn redact_json_fields_skips_non_string_values() { + // If a key matches but the value isn't a string, leave it + // alone — we can't meaningfully mask a number/null/object. + let mut v = json!({"access_token": null, "refresh_token": 123}); + redact_json_fields(&mut v, &["access_token", "refresh_token"]); + assert_eq!(v["access_token"], serde_json::Value::Null); + assert_eq!(v["refresh_token"], 123); + } +} + pub fn api_error(body: String) -> String { serde_json::from_str::(&body) .ok()