diff --git a/pkg/aiusechat/tools.go b/pkg/aiusechat/tools.go index 02105a9155..8b9ba3f82b 100644 --- a/pkg/aiusechat/tools.go +++ b/pkg/aiusechat/tools.go @@ -127,6 +127,7 @@ func GenerateTabStateAndTools(ctx context.Context, tabid string, widgetAccess bo if widgetAccess { tools = append(tools, GetCaptureScreenshotToolDefinition(tabid)) tools = append(tools, GetReadTextFileToolDefinition()) + tools = append(tools, GetReadDirToolDefinition()) viewTypes := make(map[string]bool) for _, block := range blocks { if block.Meta == nil { diff --git a/pkg/aiusechat/tools_readdir.go b/pkg/aiusechat/tools_readdir.go new file mode 100644 index 0000000000..a0689904e4 --- /dev/null +++ b/pkg/aiusechat/tools_readdir.go @@ -0,0 +1,220 @@ +// Copyright 2025, Command Line Inc. +// SPDX-License-Identifier: Apache-2.0 + +package aiusechat + +import ( + "fmt" + "io/fs" + "os" + "path/filepath" + "sort" + "time" + + "github.com/wavetermdev/waveterm/pkg/aiusechat/uctypes" + "github.com/wavetermdev/waveterm/pkg/util/utilfn" + "github.com/wavetermdev/waveterm/pkg/wavebase" +) + +const ReadDirDefaultMaxEntries = 500 +const ReadDirHardMaxEntries = 10000 + +type readDirParams struct { + Path string `json:"path"` + MaxEntries *int `json:"max_entries"` +} + +type DirEntryOut struct { + Name string `json:"name"` + Dir bool `json:"dir,omitempty"` + Symlink bool `json:"symlink,omitempty"` + Size int64 `json:"size,omitempty"` + Mode string `json:"mode"` + Modified string `json:"modified"` + ModifiedTime string `json:"modified_time"` +} + +func parseReadDirInput(input any) (*readDirParams, error) { + result := &readDirParams{} + + if input == nil { + return nil, fmt.Errorf("input is required") + } + + if err := utilfn.ReUnmarshal(result, input); err != nil { + return nil, fmt.Errorf("invalid input format: %w", err) + } + + if result.Path == "" { + return nil, fmt.Errorf("missing path parameter") + } + + if result.MaxEntries == nil { + maxEntries := ReadDirDefaultMaxEntries + result.MaxEntries = &maxEntries + } + + if *result.MaxEntries < 1 { + return nil, fmt.Errorf("max_entries must be at least 1, got %d", *result.MaxEntries) + } + + if *result.MaxEntries > ReadDirHardMaxEntries { + return nil, fmt.Errorf("max_entries cannot exceed %d, got %d", ReadDirHardMaxEntries, *result.MaxEntries) + } + + return result, nil +} + +func readDirCallback(input any) (any, error) { + params, err := parseReadDirInput(input) + if err != nil { + return nil, err + } + + expandedPath, err := wavebase.ExpandHomeDir(params.Path) + if err != nil { + return nil, fmt.Errorf("failed to expand path: %w", err) + } + + fileInfo, err := os.Stat(expandedPath) + if err != nil { + return nil, fmt.Errorf("failed to stat path: %w", err) + } + + if !fileInfo.IsDir() { + return nil, fmt.Errorf("path is not a directory, cannot be read with the read_dir tool. use the read_text_file tool to read files") + } + + entries, err := os.ReadDir(expandedPath) + if err != nil { + return nil, fmt.Errorf("failed to read directory: %w", err) + } + + // Keep track of the original total before truncation + totalEntries := len(entries) + + // Build a map of actual directory status, checking symlink targets + isDirMap := make(map[string]bool) + symlinkCount := 0 + for _, entry := range entries { + name := entry.Name() + if entry.Type()&fs.ModeSymlink != 0 { + if symlinkCount < 1000 { + symlinkCount++ + fullPath := filepath.Join(expandedPath, name) + if info, err := os.Stat(fullPath); err == nil { + isDirMap[name] = info.IsDir() + } else { + isDirMap[name] = entry.IsDir() + } + } else { + isDirMap[name] = entry.IsDir() + } + } else { + isDirMap[name] = entry.IsDir() + } + } + + // Sort entries: directories first, then files, alphabetically within each group + sort.Slice(entries, func(i, j int) bool { + iIsDir := isDirMap[entries[i].Name()] + jIsDir := isDirMap[entries[j].Name()] + if iIsDir != jIsDir { + return iIsDir + } + return entries[i].Name() < entries[j].Name() + }) + + // Truncate after sorting to ensure directories come first + maxEntries := *params.MaxEntries + var truncated bool + if len(entries) > maxEntries { + entries = entries[:maxEntries] + truncated = true + } + + var entryList []DirEntryOut + for _, entry := range entries { + info, err := entry.Info() + if err != nil { + continue + } + + isDir := isDirMap[entry.Name()] + isSymlink := entry.Type()&fs.ModeSymlink != 0 + + entryData := DirEntryOut{ + Name: entry.Name(), + Dir: isDir, + Symlink: isSymlink, + Mode: info.Mode().String(), + Modified: utilfn.FormatRelativeTime(info.ModTime()), + ModifiedTime: info.ModTime().UTC().Format(time.RFC3339), + } + + if !isDir { + entryData.Size = info.Size() + } + + entryList = append(entryList, entryData) + } + + result := map[string]any{ + "path": params.Path, + "absolute_path": expandedPath, + "entry_count": len(entryList), + "total_entries": totalEntries, + "entries": entryList, + } + + if truncated { + result["truncated"] = true + result["truncated_message"] = fmt.Sprintf("Directory listing truncated to %d entries (out of %d total). Increase max_entries to see more.", len(entryList), totalEntries) + } + + parentDir := filepath.Dir(expandedPath) + if parentDir != expandedPath { + result["parent_dir"] = parentDir + } + + return result, nil +} + +func GetReadDirToolDefinition() uctypes.ToolDefinition { + return uctypes.ToolDefinition{ + Name: "read_dir", + DisplayName: "Read Directory", + Description: "Read a directory from the filesystem and list its contents. Returns information about files and subdirectories including names, types, sizes, permissions, and modification times. Requires user approval.", + ToolLogName: "gen:readdir", + Strict: false, + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "path": map[string]any{ + "type": "string", + "description": "Path to the directory to read", + }, + "max_entries": map[string]any{ + "type": "integer", + "minimum": 1, + "maximum": 10000, + "default": 500, + "description": "Maximum number of entries to return. Defaults to 500, max 10000.", + }, + }, + "required": []string{"path"}, + "additionalProperties": false, + }, + ToolInputDesc: func(input any) string { + parsed, err := parseReadDirInput(input) + if err != nil { + return fmt.Sprintf("error parsing input: %v", err) + } + return fmt.Sprintf("reading directory %q (max_entries: %d)", parsed.Path, *parsed.MaxEntries) + }, + ToolAnyCallback: readDirCallback, + ToolApproval: func(input any) string { + return uctypes.ApprovalNeedsApproval + }, + } +} diff --git a/pkg/aiusechat/tools_readdir_test.go b/pkg/aiusechat/tools_readdir_test.go new file mode 100644 index 0000000000..305c0bfcbd --- /dev/null +++ b/pkg/aiusechat/tools_readdir_test.go @@ -0,0 +1,291 @@ +// Copyright 2025, Command Line Inc. +// SPDX-License-Identifier: Apache-2.0 + +package aiusechat + +import ( + "fmt" + "os" + "path/filepath" + "strings" + "testing" +) + +func TestReadDirCallback(t *testing.T) { + // Create a temporary test directory + tmpDir, err := os.MkdirTemp("", "readdir_test") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + // Create test files and directories + testFile1 := filepath.Join(tmpDir, "file1.txt") + testFile2 := filepath.Join(tmpDir, "file2.log") + testSubDir := filepath.Join(tmpDir, "subdir") + + if err := os.WriteFile(testFile1, []byte("test content 1"), 0644); err != nil { + t.Fatalf("Failed to create test file 1: %v", err) + } + if err := os.WriteFile(testFile2, []byte("test content 2"), 0644); err != nil { + t.Fatalf("Failed to create test file 2: %v", err) + } + if err := os.Mkdir(testSubDir, 0755); err != nil { + t.Fatalf("Failed to create test subdir: %v", err) + } + + // Test reading the directory + input := map[string]any{ + "path": tmpDir, + } + + result, err := readDirCallback(input) + if err != nil { + t.Fatalf("readDirCallback failed: %v", err) + } + + resultMap, ok := result.(map[string]any) + if !ok { + t.Fatalf("Result is not a map") + } + + // Verify the result contains expected fields + if resultMap["path"] != tmpDir { + t.Errorf("Expected path %q, got %q", tmpDir, resultMap["path"]) + } + + entryCount, ok := resultMap["entry_count"].(int) + if !ok { + t.Fatalf("entry_count is not an int") + } + if entryCount != 3 { + t.Errorf("Expected 3 entries, got %d", entryCount) + } + + entries, ok := resultMap["entries"].([]map[string]any) + if !ok { + t.Fatalf("entries is not a slice of maps") + } + + // Check that we have the expected entries + foundFiles := 0 + foundDirs := 0 + for _, entry := range entries { + if entry["is_dir"].(bool) { + foundDirs++ + } else { + foundFiles++ + } + } + + if foundFiles != 2 { + t.Errorf("Expected 2 files, got %d", foundFiles) + } + if foundDirs != 1 { + t.Errorf("Expected 1 directory, got %d", foundDirs) + } +} + +func TestReadDirOnFile(t *testing.T) { + // Create a temporary test file + tmpFile, err := os.CreateTemp("", "readdir_test_file") + if err != nil { + t.Fatalf("Failed to create temp file: %v", err) + } + defer os.Remove(tmpFile.Name()) + tmpFile.Close() + + // Test reading a file (should fail) + input := map[string]any{ + "path": tmpFile.Name(), + } + + _, err = readDirCallback(input) + if err == nil { + t.Fatalf("Expected error when reading a file with read_dir, got nil") + } + + expectedErrSubstr := "path is not a directory" + if err.Error()[:len(expectedErrSubstr)] != expectedErrSubstr { + t.Errorf("Expected error containing %q, got %q", expectedErrSubstr, err.Error()) + } +} + +func TestReadDirMaxEntries(t *testing.T) { + // Create a temporary test directory with many files + tmpDir, err := os.MkdirTemp("", "readdir_test_max") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + // Create 10 test files + for i := 0; i < 10; i++ { + testFile := filepath.Join(tmpDir, filepath.Base(tmpDir)+string(rune('a'+i))+".txt") + if err := os.WriteFile(testFile, []byte("test"), 0644); err != nil { + t.Fatalf("Failed to create test file: %v", err) + } + } + + // Test reading with max_entries=5 + maxEntries := 5 + input := map[string]any{ + "path": tmpDir, + "max_entries": maxEntries, + } + + result, err := readDirCallback(input) + if err != nil { + t.Fatalf("readDirCallback failed: %v", err) + } + + resultMap := result.(map[string]any) + entryCount := resultMap["entry_count"].(int) + totalEntries := resultMap["total_entries"].(int) + + if entryCount != maxEntries { + t.Errorf("Expected %d entries, got %d", maxEntries, entryCount) + } + + // Verify total_entries reports the original count, not the truncated count + if totalEntries != 10 { + t.Errorf("Expected total_entries to be 10, got %d", totalEntries) + } + + if _, ok := resultMap["truncated"]; !ok { + t.Error("Expected truncated field to be present") + } + + // Verify the truncation message includes the correct total + truncMsg, ok := resultMap["truncated_message"].(string) + if !ok { + t.Error("Expected truncated_message to be present") + } + expectedMsg := fmt.Sprintf("Directory listing truncated to %d entries (out of %d total)", maxEntries, 10) + if !strings.Contains(truncMsg, expectedMsg[:len(expectedMsg)-1]) { + t.Errorf("Expected truncated_message to contain %q, got %q", expectedMsg, truncMsg) + } +} + +func TestReadDirSortBeforeTruncate(t *testing.T) { + // Create a temporary test directory + tmpDir, err := os.MkdirTemp("", "readdir_test_sort") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + // Create files with names that would sort alphabetically before directories + // but we want directories to appear first + for i := 0; i < 5; i++ { + testFile := filepath.Join(tmpDir, fmt.Sprintf("a_file_%d.txt", i)) + if err := os.WriteFile(testFile, []byte("test"), 0644); err != nil { + t.Fatalf("Failed to create test file: %v", err) + } + } + + // Create directories with names that sort alphabetically after the files + for i := 0; i < 3; i++ { + testDir := filepath.Join(tmpDir, fmt.Sprintf("z_dir_%d", i)) + if err := os.Mkdir(testDir, 0755); err != nil { + t.Fatalf("Failed to create test dir: %v", err) + } + } + + // Test with max_entries=5 (less than total of 8) + // All 3 directories should still appear because they're sorted first + maxEntries := 5 + input := map[string]any{ + "path": tmpDir, + "max_entries": maxEntries, + } + + result, err := readDirCallback(input) + if err != nil { + t.Fatalf("readDirCallback failed: %v", err) + } + + resultMap := result.(map[string]any) + entries := resultMap["entries"].([]map[string]any) + + // Count directories in the result + dirCount := 0 + for _, entry := range entries { + if entry["is_dir"].(bool) { + dirCount++ + } + } + + // All 3 directories should be present because sorting happens before truncation + if dirCount != 3 { + t.Errorf("Expected 3 directories in truncated results, got %d", dirCount) + } + + // First 3 entries should be directories + for i := 0; i < 3; i++ { + if !entries[i]["is_dir"].(bool) { + t.Errorf("Expected entry %d to be a directory, but it was a file", i) + } + } +} + +func TestParseReadDirInput(t *testing.T) { + // Test valid input + input := map[string]any{ + "path": "/tmp/test", + } + + params, err := parseReadDirInput(input) + if err != nil { + t.Fatalf("parseReadDirInput failed on valid input: %v", err) + } + + if params.Path != "/tmp/test" { + t.Errorf("Expected path '/tmp/test', got %q", params.Path) + } + + if *params.MaxEntries != ReadDirDefaultMaxEntries { + t.Errorf("Expected default max_entries %d, got %d", ReadDirDefaultMaxEntries, *params.MaxEntries) + } + + // Test missing path + input = map[string]any{} + _, err = parseReadDirInput(input) + if err == nil { + t.Error("Expected error for missing path, got nil") + } + + // Test invalid max_entries + input = map[string]any{ + "path": "/tmp/test", + "max_entries": 0, + } + _, err = parseReadDirInput(input) + if err == nil { + t.Error("Expected error for max_entries < 1, got nil") + } +} + +func TestGetReadDirToolDefinition(t *testing.T) { + toolDef := GetReadDirToolDefinition() + + if toolDef.Name != "read_dir" { + t.Errorf("Expected tool name 'read_dir', got %q", toolDef.Name) + } + + if toolDef.ToolLogName != "gen:readdir" { + t.Errorf("Expected tool log name 'gen:readdir', got %q", toolDef.ToolLogName) + } + + if toolDef.ToolAnyCallback == nil { + t.Error("ToolAnyCallback should not be nil") + } + + if toolDef.ToolApproval == nil { + t.Error("ToolApproval should not be nil") + } + + if toolDef.ToolInputDesc == nil { + t.Error("ToolInputDesc should not be nil") + } +} diff --git a/pkg/aiusechat/tools_readfile.go b/pkg/aiusechat/tools_readfile.go index 9f5e38b047..3aab51f730 100644 --- a/pkg/aiusechat/tools_readfile.go +++ b/pkg/aiusechat/tools_readfile.go @@ -8,6 +8,7 @@ import ( "io" "os" "strings" + "time" "github.com/wavetermdev/waveterm/pkg/aiusechat/uctypes" "github.com/wavetermdev/waveterm/pkg/util/readutil" @@ -179,7 +180,7 @@ func readTextFileCallback(input any) (any, error) { "total_size": totalSize, "data": data, "modified": utilfn.FormatRelativeTime(modTime), - "modified_time": modTime.UTC().Format("2006-01-02 15:04:05 UTC"), + "modified_time": modTime.UTC().Format(time.RFC3339), "mode": fileInfo.Mode().String(), } if stopReason != "" {