diff --git a/guards/github-guard/rust-guard/src/labels/tool_rules.rs b/guards/github-guard/rust-guard/src/labels/tool_rules.rs index 1ed4a4e3..56238fe1 100644 --- a/guards/github-guard/rust-guard/src/labels/tool_rules.rs +++ b/guards/github-guard/rust-guard/src/labels/tool_rules.rs @@ -745,17 +745,25 @@ fn check_file_secrecy( ctx: &PolicyContext, ) -> Vec { let path_lower = path.to_lowercase(); - let segments: Vec<&str> = path_lower.split('/').collect(); // Check for sensitive file extensions/names - for pattern in SENSITIVE_FILE_PATTERNS { - if path_lower.ends_with(pattern) || segments.iter().any(|seg| seg.starts_with(*pattern)) { - return policy_private_scope_label(owner, repo, repo_id, ctx); - } + if SENSITIVE_FILE_PATTERNS + .iter() + .any(|pattern| path_lower.ends_with(pattern)) + { + return policy_private_scope_label(owner, repo, repo_id, ctx); + } + + if path_lower.split('/').any(|seg| { + SENSITIVE_FILE_PATTERNS + .iter() + .any(|pattern| seg.starts_with(*pattern)) + }) { + return policy_private_scope_label(owner, repo, repo_id, ctx); } // Get filename - let filename = segments.last().copied().unwrap_or(path_lower.as_str()); + let filename = path_lower.rsplit('/').next().unwrap_or(&path_lower); // Check for sensitive keywords in filename for keyword in SENSITIVE_FILE_KEYWORDS { diff --git a/guards/github-guard/rust-guard/src/lib.rs b/guards/github-guard/rust-guard/src/lib.rs index a4fab232..88a56a57 100644 --- a/guards/github-guard/rust-guard/src/lib.rs +++ b/guards/github-guard/rust-guard/src/lib.rs @@ -21,6 +21,7 @@ use labels::{ use serde::{Deserialize, Serialize}; use serde_json::Value; use std::alloc::{alloc as std_alloc, dealloc as std_dealloc, Layout}; +use std::borrow::Cow; use std::slice; use std::sync::Mutex; @@ -257,9 +258,13 @@ struct LabelResponseOutput { items: Vec, } -fn infer_scope_for_baseline(tool_name: &str, tool_args: &Value, repo_id: &str) -> String { +fn infer_scope_for_baseline<'a>( + tool_name: &str, + tool_args: &Value, + repo_id: &'a str, +) -> Cow<'a, str> { if !repo_id.is_empty() { - return repo_id.to_string(); + return Cow::Borrowed(repo_id); } match tool_name { @@ -268,16 +273,16 @@ fn infer_scope_for_baseline(tool_name: &str, tool_args: &Value, repo_id: &str) - | "manage_notification_subscription" | "manage_repository_notification_subscription" | "create_repository" - | "fork_repository" => scope_names::GITHUB.to_string(), + | "fork_repository" => Cow::Borrowed(scope_names::GITHUB), "search_code" | "search_issues" | "search_pull_requests" => { let query = tool_args .get("query") .and_then(|v| v.as_str()) .unwrap_or(""); let (_, _, repo_from_query) = extract_repo_info_from_search_query(query); - repo_from_query + Cow::Owned(repo_from_query) } - _ => String::new(), + _ => Cow::Borrowed(""), } } @@ -1096,9 +1101,34 @@ mod tests { let tool_args = json!({"query": "repo:lpcox/github-guard README"}); let inferred = infer_scope_for_baseline("search_code", &tool_args, ""); + assert!(matches!(&inferred, Cow::Owned(_))); assert_eq!(inferred, "lpcox/github-guard"); } + #[test] + fn infer_scope_for_baseline_borrows_repo_id_when_present() { + let tool_args = json!({}); + let inferred = infer_scope_for_baseline("get_file_contents", &tool_args, "octocat/hello-world"); + + assert!(matches!(inferred, Cow::Borrowed("octocat/hello-world"))); + } + + #[test] + fn infer_scope_for_baseline_borrows_github_scope_for_repo_creation() { + let tool_args = json!({}); + let inferred = infer_scope_for_baseline("create_repository", &tool_args, ""); + + assert!(matches!(inferred, Cow::Borrowed(scope_names::GITHUB))); + } + + #[test] + fn infer_scope_for_baseline_borrows_empty_scope_for_other_tools() { + let tool_args = json!({}); + let inferred = infer_scope_for_baseline("get_file_contents", &tool_args, ""); + + assert!(matches!(inferred, Cow::Borrowed(""))); + } + #[test] fn search_code_baseline_preserves_scoped_integrity() { let ctx = PolicyContext {