Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 0 additions & 36 deletions src/llm-coding-tools-core/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,39 +59,3 @@ impl From<globset::Error> for ToolError {
ToolError::InvalidPattern(e.to_string())
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn tool_error_displays_io_error() {
let io_err = std::io::Error::new(std::io::ErrorKind::NotFound, "file not found");
let err: ToolError = io_err.into();
assert!(err.to_string().contains("I/O error"));
}

#[test]
fn tool_error_displays_invalid_path() {
let err = ToolError::InvalidPath("not absolute".into());
assert!(err.to_string().contains("invalid path"));
}

#[test]
fn tool_error_from_glob_pattern_error() {
let glob_err = globset::Glob::new("[invalid").unwrap_err();
let err: ToolError = glob_err.into();
assert!(matches!(err, ToolError::InvalidPattern(_)));
}

#[test]
fn timeout_with_kill_failure_displays_both_contexts() {
let err = ToolError::TimeoutWithKillFailure {
message: "command timed out after 100ms".into(),
kill_error: "permission denied".into(),
};
let display = err.to_string();
assert!(display.contains("command timed out after 100ms"));
assert!(display.contains("kill failed: permission denied"));
}
}
33 changes: 17 additions & 16 deletions src/llm-coding-tools-core/src/internal/hash64.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,21 +37,22 @@ pub(crate) fn hash_u64_bytes(bytes: &[u8]) -> Hash64 {
#[cfg(test)]
mod tests {
use super::*;

#[test]
fn hash_is_deterministic() {
let hash1 = hash_u64("bash");
let hash2 = hash_u64("bash");
assert_eq!(hash1, hash2);
}

#[test]
fn different_inputs_produce_different_hashes() {
let h1 = hash_u64("bash");
let h2 = hash_u64("read");
let h3 = hash_u64("write");
assert_ne!(h1, h2);
assert_ne!(h1, h3);
assert_ne!(h2, h3);
use rstest::rstest;

/// Verifies that the hash function is deterministic for identical inputs
/// and produces different hashes for different inputs.
#[rstest]
#[case::same_input("bash", "bash", true)]
#[case::different_inputs("bash", "read", false)]
#[case::different_inputs_2("bash", "write", false)]
#[case::different_inputs_3("read", "write", false)]
fn hash_properties(#[case] a: &str, #[case] b: &str, #[case] should_equal: bool) {
let hash_a = hash_u64(a);
let hash_b = hash_u64(b);
if should_equal {
assert_eq!(hash_a, hash_b);
} else {
assert_ne!(hash_a, hash_b);
}
}
}
21 changes: 13 additions & 8 deletions src/llm-coding-tools-core/src/models/provider_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,19 +60,24 @@ impl ProviderType {
#[cfg(test)]
mod tests {
use super::ProviderType;
use rstest::rstest;

#[test]
fn unknown_is_default_variant() {
assert_eq!(ProviderType::default(), ProviderType::Unknown);
}

#[test]
fn azure_requires_base_url() {
assert!(ProviderType::Azure.requires_base_url());
}

#[test]
fn ollama_does_not_require_api_key() {
assert!(!ProviderType::Ollama.requires_api_key());
/// Verifies that provider type flags return expected values.
#[rstest]
#[case::azure_requires_base_url(ProviderType::Azure, true, true)]
#[case::ollama_no_api_key(ProviderType::Ollama, false, false)]
#[case::openai_requires_api_key(ProviderType::OpenAiCompletions, false, true)]
fn provider_type_flags(
#[case] provider: ProviderType,
#[case] requires_base_url: bool,
#[case] requires_api_key: bool,
) {
assert_eq!(provider.requires_base_url(), requires_base_url);
assert_eq!(provider.requires_api_key(), requires_api_key);
}
}
57 changes: 29 additions & 28 deletions src/llm-coding-tools-core/src/output.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,38 +59,39 @@ impl From<WebFetchOutput> for ToolOutput {
#[cfg(test)]
mod tests {
use super::*;
use rstest::rstest;

#[test]
fn tool_output_new_creates_non_truncated() {
let output = ToolOutput::new("content");
assert_eq!(output.content, "content");
assert!(!output.truncated);
}

#[test]
fn tool_output_truncated_marks_truncated() {
let output = ToolOutput::truncated("partial");
assert!(output.truncated);
}

#[test]
fn tool_output_from_string() {
let output: ToolOutput = "hello".into();
assert_eq!(output.content, "hello");
}

#[test]
fn tool_output_serializes_without_truncated_when_false() {
let output = ToolOutput::new("content");
let json = serde_json::to_string(&output).unwrap();
assert!(!json.contains("truncated"));
/// Verifies that ToolOutput constructors correctly set the truncated flag.
#[rstest]
#[case::new_creates_non_truncated(false, "content")]
#[case::truncated_marks_truncated(true, "partial")]
fn tool_output_creation(#[case] is_truncated: bool, #[case] content: &str) {
let output = if is_truncated {
ToolOutput::truncated(content)
} else {
ToolOutput::new(content)
};
assert_eq!(output.content, content);
assert_eq!(output.truncated, is_truncated);
}

#[test]
fn tool_output_serializes_with_truncated_when_true() {
let output = ToolOutput::truncated("content");
/// Verifies that the truncated field is only serialized when true.
/// ToolOutput uses `#[serde(skip_serializing_if)]` to omit the field
/// when false, producing cleaner JSON output.
///
/// We verify this behaviour specifically to ensure the LLM does not receive
/// unnecessary tokens for default values that provide no information.
#[rstest]
#[case::without_truncated_when_false(false)]
#[case::with_truncated_when_true(true)]
fn tool_output_serialization(#[case] truncated: bool) {
let output = if truncated {
ToolOutput::truncated("content")
} else {
ToolOutput::new("content")
};
let json = serde_json::to_string(&output).unwrap();
assert!(json.contains("truncated"));
assert_eq!(json.contains("truncated"), truncated);
}

#[cfg(any(feature = "tokio", feature = "blocking"))]
Expand Down
93 changes: 37 additions & 56 deletions src/llm-coding-tools-core/src/path/allowed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ impl PathResolver for AllowedPathResolver {
#[cfg(test)]
mod tests {
use super::*;
use rstest::rstest;
use std::fs;
use tempfile::TempDir;

Expand All @@ -136,64 +137,43 @@ mod tests {
dir
}

#[test]
fn resolves_relative_path_in_allowed_dir() {
let dir = setup_test_dir();
let resolver = AllowedPathResolver::new(vec![dir.path().to_path_buf()]).unwrap();

let result = resolver.resolve("file.txt");
assert!(result.is_ok());
assert!(result.unwrap().ends_with("file.txt"));
}

#[test]
fn resolves_nested_path() {
/// Verifies that valid paths resolve successfully, including both existing
/// files and new files that don't exist yet (important for write operations).
#[rstest]
#[case::existing_file_in_root("file.txt", "file.txt")] // exists: created by setup_test_dir()
#[case::nested_existing_file("subdir/nested.txt", "nested.txt")] // exists: created by setup_test_dir()
#[case::new_file_in_root("new_file.txt", "new_file.txt")] // does NOT exist: tests write path resolution
#[case::new_file_in_subdir("subdir/new_file.txt", "new_file.txt")] // does NOT exist: tests write path resolution
fn resolves_valid_paths_successfully(
#[case] input_path: &str,
#[case] expected_filename: &str,
) {
let dir = setup_test_dir();
let resolver = AllowedPathResolver::new(vec![dir.path().to_path_buf()]).unwrap();

let result = resolver.resolve("subdir/nested.txt");
assert!(result.is_ok());
let result = resolver.resolve(input_path);
let resolved = result.expect("path should resolve successfully");
assert!(
resolved.ends_with(expected_filename),
"resolved path should end with '{expected_filename}'"
);
}

#[test]
fn rejects_path_traversal() {
/// Verifies that path traversal attempts are blocked regardless of
/// how the escape is constructed.
#[rstest]
#[case::parent_traversal("../../../etc/passwd")]
#[case::nested_parent_traversal("subdir/../../../new_file.txt")]
fn rejects_paths_that_escape_allowed_directory(#[case] input_path: &str) {
let dir = setup_test_dir();
let resolver = AllowedPathResolver::new(vec![dir.path().to_path_buf()]).unwrap();

let result = resolver.resolve("../../../etc/passwd");
assert!(result.is_err());
assert!(result
.unwrap_err()
.to_string()
.contains("not within allowed"));
}

#[test]
fn allows_non_existent_path_for_write() {
let dir = setup_test_dir();
let resolver = AllowedPathResolver::new(vec![dir.path().to_path_buf()]).unwrap();

let result = resolver.resolve("new_file.txt");
assert!(result.is_ok());
}

#[test]
fn allows_nested_non_existent_path() {
let dir = setup_test_dir();
let resolver = AllowedPathResolver::new(vec![dir.path().to_path_buf()]).unwrap();

let result = resolver.resolve("subdir/new_file.txt");
assert!(result.is_ok());
}

#[test]
fn rejects_non_existent_path_outside_allowed() {
let dir = setup_test_dir();
let resolver = AllowedPathResolver::new(vec![dir.path().to_path_buf()]).unwrap();

// Parent traversal in non-existent path
let result = resolver.resolve("subdir/../../../new_file.txt");
assert!(result.is_err());
let result = resolver.resolve(input_path);
let err = result.expect_err("path should be rejected");
assert!(
err.to_string().contains("not within allowed"),
"error should mention 'not within allowed'"
);
}

#[test]
Expand All @@ -212,14 +192,15 @@ mod tests {
}

#[test]
fn returns_canonical_path() {
fn returns_canonical_path_without_dotdots() {
let dir = setup_test_dir();
let resolver = AllowedPathResolver::new(vec![dir.path().to_path_buf()]).unwrap();

let result = resolver.resolve("subdir/../file.txt");
assert!(result.is_ok());
// Should resolve to the canonical path without ../
let resolved = result.unwrap();
assert!(!resolved.to_string_lossy().contains(".."));
// Path with ".." should be normalized
let resolved = resolver.resolve("subdir/../file.txt").unwrap();
assert!(
!resolved.to_string_lossy().contains(".."),
"canonical path should not contain '..'"
);
}
}
Loading
Loading