diff --git a/cmd/wsh/cmd/wshcmd-connserver.go b/cmd/wsh/cmd/wshcmd-connserver.go index 678ea77cc5..d9c7a8bc07 100644 --- a/cmd/wsh/cmd/wshcmd-connserver.go +++ b/cmd/wsh/cmd/wshcmd-connserver.go @@ -21,6 +21,7 @@ import ( "github.com/wavetermdev/waveterm/pkg/remote/fileshare/wshfs" "github.com/wavetermdev/waveterm/pkg/util/packetparser" "github.com/wavetermdev/waveterm/pkg/util/sigutil" + "github.com/wavetermdev/waveterm/pkg/util/utilfn" "github.com/wavetermdev/waveterm/pkg/wavebase" "github.com/wavetermdev/waveterm/pkg/wshrpc" "github.com/wavetermdev/waveterm/pkg/wshrpc/wshclient" @@ -162,9 +163,7 @@ func serverRunRouter(jwtToken string) error { // just ignore and drain the rawCh (stdin) // when stdin is closed, shutdown defer wshutil.DoShutdown("", 0, true) - for range rawCh { - // ignore - } + utilfn.DrainChannelSafe(rawCh, "serverRunRouter:stdin") }() go func() { for msg := range termProxy.FromRemoteCh { diff --git a/cmd/wsh/cmd/wshcmd-file-util.go b/cmd/wsh/cmd/wshcmd-file-util.go index 432cc1b1fd..811a196c23 100644 --- a/cmd/wsh/cmd/wshcmd-file-util.go +++ b/cmd/wsh/cmd/wshcmd-file-util.go @@ -4,6 +4,7 @@ package cmd import ( + "context" "encoding/base64" "fmt" "io" @@ -11,6 +12,7 @@ import ( "strings" "github.com/wavetermdev/waveterm/pkg/remote/connparse" + "github.com/wavetermdev/waveterm/pkg/remote/fileshare/fsutil" "github.com/wavetermdev/waveterm/pkg/util/fileutil" "github.com/wavetermdev/waveterm/pkg/util/wavefileutil" "github.com/wavetermdev/waveterm/pkg/wshrpc" @@ -27,15 +29,15 @@ func convertNotFoundErr(err error) error { return err } -func ensureFile(origName string, fileData wshrpc.FileData) (*wshrpc.FileInfo, error) { - info, err := wshclient.FileInfoCommand(RpcClient, fileData, &wshrpc.RpcOpts{Timeout: DefaultFileTimeout}) +func ensureFile(fileData wshrpc.FileData) (*wshrpc.FileInfo, error) { + info, err := wshclient.FileInfoCommand(RpcClient, fileData, &wshrpc.RpcOpts{Timeout: fileTimeout}) err = convertNotFoundErr(err) if err == fs.ErrNotExist { - err = wshclient.FileCreateCommand(RpcClient, fileData, &wshrpc.RpcOpts{Timeout: DefaultFileTimeout}) + err = wshclient.FileCreateCommand(RpcClient, fileData, &wshrpc.RpcOpts{Timeout: fileTimeout}) if err != nil { return nil, fmt.Errorf("creating file: %w", err) } - info, err = wshclient.FileInfoCommand(RpcClient, fileData, &wshrpc.RpcOpts{Timeout: DefaultFileTimeout}) + info, err = wshclient.FileInfoCommand(RpcClient, fileData, &wshrpc.RpcOpts{Timeout: fileTimeout}) if err != nil { return nil, fmt.Errorf("getting file info: %w", err) } @@ -51,12 +53,12 @@ func streamWriteToFile(fileData wshrpc.FileData, reader io.Reader) error { // First truncate the file with an empty write emptyWrite := fileData emptyWrite.Data64 = "" - err := wshclient.FileWriteCommand(RpcClient, emptyWrite, &wshrpc.RpcOpts{Timeout: DefaultFileTimeout}) + err := wshclient.FileWriteCommand(RpcClient, emptyWrite, &wshrpc.RpcOpts{Timeout: fileTimeout}) if err != nil { return fmt.Errorf("initializing file with empty write: %w", err) } - const chunkSize = 32 * 1024 // 32KB chunks + const chunkSize = wshrpc.FileChunkSize // 32KB chunks buf := make([]byte, chunkSize) totalWritten := int64(0) @@ -89,40 +91,9 @@ func streamWriteToFile(fileData wshrpc.FileData, reader io.Reader) error { return nil } -func streamReadFromFile(fileData wshrpc.FileData, size int64, writer io.Writer) error { - const chunkSize = 32 * 1024 // 32KB chunks - for offset := int64(0); offset < size; offset += chunkSize { - // Calculate the length of this chunk - length := chunkSize - if offset+int64(length) > size { - length = int(size - offset) - } - - // Set up the ReadAt request - fileData.At = &wshrpc.FileDataAt{ - Offset: offset, - Size: length, - } - - // Read the chunk - data, err := wshclient.FileReadCommand(RpcClient, fileData, &wshrpc.RpcOpts{Timeout: int64(fileTimeout)}) - if err != nil { - return fmt.Errorf("reading chunk at offset %d: %w", offset, err) - } - - // Decode and write the chunk - chunk, err := base64.StdEncoding.DecodeString(data.Data64) - if err != nil { - return fmt.Errorf("decoding chunk at offset %d: %w", offset, err) - } - - _, err = writer.Write(chunk) - if err != nil { - return fmt.Errorf("writing chunk at offset %d: %w", offset, err) - } - } - - return nil +func streamReadFromFile(ctx context.Context, fileData wshrpc.FileData, writer io.Writer) error { + ch := wshclient.FileReadStreamCommand(RpcClient, fileData, &wshrpc.RpcOpts{Timeout: fileTimeout}) + return fsutil.ReadFileStreamToWriter(ctx, ch, writer) } type fileListResult struct { diff --git a/cmd/wsh/cmd/wshcmd-file.go b/cmd/wsh/cmd/wshcmd-file.go index 7c1c60ded0..a0ca112a2a 100644 --- a/cmd/wsh/cmd/wshcmd-file.go +++ b/cmd/wsh/cmd/wshcmd-file.go @@ -9,7 +9,6 @@ import ( "encoding/base64" "fmt" "io" - "io/fs" "log" "os" "path" @@ -31,8 +30,7 @@ const ( WaveFileScheme = "wavefile" WaveFilePrefix = "wavefile://" - DefaultFileTimeout = 5000 - TimeoutYear = int64(365) * 24 * 60 * 60 * 1000 + TimeoutYear = int64(365) * 24 * 60 * 60 * 1000 UriHelpText = ` @@ -83,12 +81,12 @@ Wave Terminal is capable of managing files from remote SSH hosts, S3-compatible systems, and the internal Wave filesystem. Files are addressed via URIs, which vary depending on the storage system.` + UriHelpText} -var fileTimeout int +var fileTimeout int64 func init() { rootCmd.AddCommand(fileCmd) - fileCmd.PersistentFlags().IntVarP(&fileTimeout, "timeout", "t", 15000, "timeout in milliseconds for long operations") + fileCmd.PersistentFlags().Int64VarP(&fileTimeout, "timeout", "t", 15000, "timeout in milliseconds for long operations") fileListCmd.Flags().BoolP("recursive", "r", false, "list subdirectories recursively") fileListCmd.Flags().BoolP("long", "l", false, "use long listing format") @@ -103,7 +101,6 @@ func init() { fileCmd.AddCommand(fileInfoCmd) fileCmd.AddCommand(fileAppendCmd) fileCpCmd.Flags().BoolP("merge", "m", false, "merge directories") - fileCpCmd.Flags().BoolP("recursive", "r", false, "copy directories recursively") fileCpCmd.Flags().BoolP("force", "f", false, "force overwrite of existing files") fileCmd.AddCommand(fileCpCmd) fileMvCmd.Flags().BoolP("recursive", "r", false, "move directories recursively") @@ -174,7 +171,7 @@ var fileAppendCmd = &cobra.Command{ var fileCpCmd = &cobra.Command{ Use: "cp [source-uri] [destination-uri]" + UriHelpText, Aliases: []string{"copy"}, - Short: "copy files between storage systems", + Short: "copy files between storage systems, recursively if needed", Long: "Copy files between different storage systems." + UriHelpText, Example: " wsh file cp wavefile://block/config.txt ./local-config.txt\n wsh file cp ./local-config.txt wavefile://block/config.txt\n wsh file cp wsh://user@ec2/home/user/config.txt wavefile://client/config.txt", Args: cobra.ExactArgs(2), @@ -202,17 +199,7 @@ func fileCatRun(cmd *cobra.Command, args []string) error { Info: &wshrpc.FileInfo{ Path: path}} - // Get file info first to check existence and get size - info, err := wshclient.FileInfoCommand(RpcClient, fileData, &wshrpc.RpcOpts{Timeout: 2000}) - err = convertNotFoundErr(err) - if err == fs.ErrNotExist { - return fmt.Errorf("%s: no such file", path) - } - if err != nil { - return fmt.Errorf("getting file info: %w", err) - } - - err = streamReadFromFile(fileData, info.Size, os.Stdout) + err = streamReadFromFile(cmd.Context(), fileData, os.Stdout) if err != nil { return fmt.Errorf("reading file: %w", err) } @@ -229,7 +216,7 @@ func fileInfoRun(cmd *cobra.Command, args []string) error { Info: &wshrpc.FileInfo{ Path: path}} - info, err := wshclient.FileInfoCommand(RpcClient, fileData, &wshrpc.RpcOpts{Timeout: DefaultFileTimeout}) + info, err := wshclient.FileInfoCommand(RpcClient, fileData, &wshrpc.RpcOpts{Timeout: fileTimeout}) err = convertNotFoundErr(err) if err != nil { return fmt.Errorf("getting file info: %w", err) @@ -265,20 +252,8 @@ func fileRmRun(cmd *cobra.Command, args []string) error { if err != nil { return err } - fileData := wshrpc.FileData{ - Info: &wshrpc.FileInfo{ - Path: path}} - - _, err = wshclient.FileInfoCommand(RpcClient, fileData, &wshrpc.RpcOpts{Timeout: DefaultFileTimeout}) - err = convertNotFoundErr(err) - if err == fs.ErrNotExist { - return fmt.Errorf("%s: no such file", path) - } - if err != nil { - return fmt.Errorf("getting file info: %w", err) - } - err = wshclient.FileDeleteCommand(RpcClient, wshrpc.CommandDeleteFileData{Path: path, Recursive: recursive}, &wshrpc.RpcOpts{Timeout: DefaultFileTimeout}) + err = wshclient.FileDeleteCommand(RpcClient, wshrpc.CommandDeleteFileData{Path: path, Recursive: recursive}, &wshrpc.RpcOpts{Timeout: fileTimeout}) if err != nil { return fmt.Errorf("removing file: %w", err) } @@ -295,14 +270,31 @@ func fileWriteRun(cmd *cobra.Command, args []string) error { Info: &wshrpc.FileInfo{ Path: path}} - _, err = ensureFile(path, fileData) + capability, err := wshclient.FileShareCapabilityCommand(RpcClient, fileData.Info.Path, &wshrpc.RpcOpts{Timeout: fileTimeout}) if err != nil { - return err + return fmt.Errorf("getting fileshare capability: %w", err) } - - err = streamWriteToFile(fileData, WrappedStdin) - if err != nil { - return fmt.Errorf("writing file: %w", err) + if capability.CanAppend { + err = streamWriteToFile(fileData, WrappedStdin) + if err != nil { + return fmt.Errorf("writing file: %w", err) + } + } else { + buf := make([]byte, MaxFileSize) + n, err := WrappedStdin.Read(buf) + if err != nil && err != io.EOF { + return fmt.Errorf("reading input: %w", err) + } + if int64(n) == MaxFileSize { + if _, err := WrappedStdin.Read(make([]byte, 1)); err != io.EOF { + return fmt.Errorf("input exceeds maximum file size of %d bytes", MaxFileSize) + } + } + fileData.Data64 = base64.StdEncoding.EncodeToString(buf[:n]) + err = wshclient.FileWriteCommand(RpcClient, fileData, &wshrpc.RpcOpts{Timeout: fileTimeout}) + if err != nil { + return fmt.Errorf("writing file: %w", err) + } } return nil @@ -317,7 +309,7 @@ func fileAppendRun(cmd *cobra.Command, args []string) error { Info: &wshrpc.FileInfo{ Path: path}} - info, err := ensureFile(path, fileData) + info, err := ensureFile(fileData) if err != nil { return err } @@ -346,7 +338,7 @@ func fileAppendRun(cmd *cobra.Command, args []string) error { if buf.Len() >= 8192 { // 8KB batch size fileData.Data64 = base64.StdEncoding.EncodeToString(buf.Bytes()) - err = wshclient.FileAppendCommand(RpcClient, fileData, &wshrpc.RpcOpts{Timeout: int64(fileTimeout)}) + err = wshclient.FileAppendCommand(RpcClient, fileData, &wshrpc.RpcOpts{Timeout: fileTimeout}) if err != nil { return fmt.Errorf("appending to file: %w", err) } @@ -357,7 +349,7 @@ func fileAppendRun(cmd *cobra.Command, args []string) error { if buf.Len() > 0 { fileData.Data64 = base64.StdEncoding.EncodeToString(buf.Bytes()) - err = wshclient.FileAppendCommand(RpcClient, fileData, &wshrpc.RpcOpts{Timeout: int64(fileTimeout)}) + err = wshclient.FileAppendCommand(RpcClient, fileData, &wshrpc.RpcOpts{Timeout: fileTimeout}) if err != nil { return fmt.Errorf("appending to file: %w", err) } @@ -398,10 +390,6 @@ func getTargetPath(src, dst string) (string, error) { func fileCpRun(cmd *cobra.Command, args []string) error { src, dst := args[0], args[1] - recursive, err := cmd.Flags().GetBool("recursive") - if err != nil { - return err - } merge, err := cmd.Flags().GetBool("merge") if err != nil { return err @@ -419,9 +407,9 @@ func fileCpRun(cmd *cobra.Command, args []string) error { if err != nil { return fmt.Errorf("unable to parse dest path: %w", err) } - log.Printf("Copying %s to %s; recursive: %v, merge: %v, force: %v", srcPath, destPath, recursive, merge, force) + log.Printf("Copying %s to %s; merge: %v, force: %v", srcPath, destPath, merge, force) rpcOpts := &wshrpc.RpcOpts{Timeout: TimeoutYear} - err = wshclient.FileCopyCommand(RpcClient, wshrpc.CommandFileCopyData{SrcUri: srcPath, DestUri: destPath, Opts: &wshrpc.FileCopyOpts{Recursive: recursive, Merge: merge, Overwrite: force, Timeout: TimeoutYear}}, rpcOpts) + err = wshclient.FileCopyCommand(RpcClient, wshrpc.CommandFileCopyData{SrcUri: srcPath, DestUri: destPath, Opts: &wshrpc.FileCopyOpts{Merge: merge, Overwrite: force, Timeout: TimeoutYear}}, rpcOpts) if err != nil { return fmt.Errorf("copying file: %w", err) } @@ -449,7 +437,7 @@ func fileMvRun(cmd *cobra.Command, args []string) error { } log.Printf("Moving %s to %s; recursive: %v, force: %v", srcPath, destPath, recursive, force) rpcOpts := &wshrpc.RpcOpts{Timeout: TimeoutYear} - err = wshclient.FileMoveCommand(RpcClient, wshrpc.CommandFileCopyData{SrcUri: srcPath, DestUri: destPath, Opts: &wshrpc.FileCopyOpts{Recursive: recursive, Overwrite: force, Timeout: TimeoutYear}}, rpcOpts) + err = wshclient.FileMoveCommand(RpcClient, wshrpc.CommandFileCopyData{SrcUri: srcPath, DestUri: destPath, Opts: &wshrpc.FileCopyOpts{Overwrite: force, Timeout: TimeoutYear, Recursive: recursive}}, rpcOpts) if err != nil { return fmt.Errorf("moving file: %w", err) } @@ -562,10 +550,7 @@ func fileListRun(cmd *cobra.Command, args []string) error { filesChan := wshclient.FileListStreamCommand(RpcClient, wshrpc.FileListData{Path: path, Opts: &wshrpc.FileListOpts{All: recursive}}, &wshrpc.RpcOpts{Timeout: 2000}) // Drain the channel when done - defer func() { - for range filesChan { - } - }() + defer utilfn.DrainChannelSafe(filesChan, "fileListRun") if longForm { return filePrintLong(filesChan) } diff --git a/cmd/wsh/cmd/wshcmd-view.go b/cmd/wsh/cmd/wshcmd-view.go index 97ee8ffdb6..a2f8f86394 100644 --- a/cmd/wsh/cmd/wshcmd-view.go +++ b/cmd/wsh/cmd/wshcmd-view.go @@ -19,6 +19,7 @@ var viewMagnified bool var viewCmd = &cobra.Command{ Use: "view {file|directory|URL}", + Aliases: []string{"preview", "open"}, Short: "preview/edit a file or directory", RunE: viewRun, PreRunE: preRunSetupRpcClient, diff --git a/frontend/app/block/blockframe.tsx b/frontend/app/block/blockframe.tsx index 47410b19c6..c5722b9920 100644 --- a/frontend/app/block/blockframe.tsx +++ b/frontend/app/block/blockframe.tsx @@ -604,7 +604,8 @@ const BlockFrame_Default_Component = (props: BlockFrameProps) => { "--magnified-block-blur": `${magnifiedBlockBlur}px`, } as React.CSSProperties } - {...({ inert: preview ? "1" : undefined } as any)} // sets insert="1" ... but tricks TS into accepting it + // @ts-ignore: inert does exist in the DOM, just not in react + inert={preview ? "1" : undefined} // > {preview || viewModel == null ? null : ( diff --git a/frontend/app/modals/conntypeahead.tsx b/frontend/app/modals/conntypeahead.tsx index 5a9831d062..77df024433 100644 --- a/frontend/app/modals/conntypeahead.tsx +++ b/frontend/app/modals/conntypeahead.tsx @@ -377,13 +377,10 @@ const ChangeConnectionBlockModal = React.memo( // typeahead was opened. good candidate for verbose log level. //console.log("unable to load wsl list from backend. using blank list: ", e) }); - ///////// - // TODO-S3 - // this needs an rpc call to generate a list of s3 profiles - const newS3List = []; - setS3List(newS3List); - ///////// - }, [changeConnModalOpen, setConnList]); + RpcApi.ConnListAWSCommand(TabRpcClient, { timeout: 2000 }) + .then((s3List) => setS3List(s3List ?? [])) + .catch((e) => console.log("unable to load s3 list from backend:", e)); + }, [changeConnModalOpen]); const changeConnection = React.useCallback( async (connName: string) => { @@ -393,10 +390,13 @@ const ChangeConnectionBlockModal = React.memo( if (connName == blockData?.meta?.connection) { return; } + const isAws = connName?.startsWith("aws:"); const oldCwd = blockData?.meta?.file ?? ""; let newCwd: string; if (oldCwd == "") { newCwd = ""; + } else if (isAws) { + newCwd = "/"; } else { newCwd = "~"; } diff --git a/frontend/app/store/global.ts b/frontend/app/store/global.ts index 4b4b6afee4..00a386f223 100644 --- a/frontend/app/store/global.ts +++ b/frontend/app/store/global.ts @@ -672,6 +672,17 @@ function getConnStatusAtom(conn: string): PrimitiveAtom { wshenabled: false, }; rtn = atom(connStatus); + } else if (conn.startsWith("aws:")) { + const connStatus: ConnStatus = { + connection: conn, + connected: true, + error: null, + status: "connected", + hasconnected: true, + activeconnnum: 0, + wshenabled: false, + }; + rtn = atom(connStatus); } else { const connStatus: ConnStatus = { connection: conn, diff --git a/frontend/app/store/wshclientapi.ts b/frontend/app/store/wshclientapi.ts index f27f4d0ee9..a73bded771 100644 --- a/frontend/app/store/wshclientapi.ts +++ b/frontend/app/store/wshclientapi.ts @@ -52,6 +52,11 @@ class RpcApiType { return client.wshRpcCall("connlist", null, opts); } + // command "connlistaws" [call] + ConnListAWSCommand(client: WshClient, opts?: RpcOpts): Promise { + return client.wshRpcCall("connlistaws", null, opts); + } + // command "connreinstallwsh" [call] ConnReinstallWshCommand(client: WshClient, data: ConnExtData, opts?: RpcOpts): Promise { return client.wshRpcCall("connreinstallwsh", data, opts); @@ -182,6 +187,11 @@ class RpcApiType { return client.wshRpcCall("fileinfo", data, opts); } + // command "filejoin" [call] + FileJoinCommand(client: WshClient, data: string[], opts?: RpcOpts): Promise { + return client.wshRpcCall("filejoin", data, opts); + } + // command "filelist" [call] FileListCommand(client: WshClient, data: FileListData, opts?: RpcOpts): Promise { return client.wshRpcCall("filelist", data, opts); @@ -207,6 +217,16 @@ class RpcApiType { return client.wshRpcCall("fileread", data, opts); } + // command "filereadstream" [responsestream] + FileReadStreamCommand(client: WshClient, data: FileData, opts?: RpcOpts): AsyncGenerator { + return client.wshRpcStream("filereadstream", data, opts); + } + + // command "filesharecapability" [call] + FileShareCapabilityCommand(client: WshClient, data: string, opts?: RpcOpts): Promise { + return client.wshRpcCall("filesharecapability", data, opts); + } + // command "filestreamtar" [responsestream] FileStreamTarCommand(client: WshClient, data: CommandRemoteStreamTarData, opts?: RpcOpts): AsyncGenerator { return client.wshRpcStream("filestreamtar", data, opts); @@ -268,7 +288,7 @@ class RpcApiType { } // command "remotefilecopy" [call] - RemoteFileCopyCommand(client: WshClient, data: CommandRemoteFileCopyData, opts?: RpcOpts): Promise { + RemoteFileCopyCommand(client: WshClient, data: CommandFileCopyData, opts?: RpcOpts): Promise { return client.wshRpcCall("remotefilecopy", data, opts); } @@ -288,7 +308,7 @@ class RpcApiType { } // command "remotefilemove" [call] - RemoteFileMoveCommand(client: WshClient, data: CommandRemoteFileCopyData, opts?: RpcOpts): Promise { + RemoteFileMoveCommand(client: WshClient, data: CommandFileCopyData, opts?: RpcOpts): Promise { return client.wshRpcCall("remotefilemove", data, opts); } diff --git a/frontend/app/suggestion/suggestion.tsx b/frontend/app/suggestion/suggestion.tsx index 84b6340204..9a4e77ede1 100644 --- a/frontend/app/suggestion/suggestion.tsx +++ b/frontend/app/suggestion/suggestion.tsx @@ -232,22 +232,41 @@ const SuggestionControlInner: React.FC = ({ return () => document.removeEventListener("mousedown", handleClickOutside); }, [onClose, anchorRef]); + useEffect(() => { + if (dropdownRef.current) { + const children = dropdownRef.current.children; + if (children[selectedIndex]) { + (children[selectedIndex] as HTMLElement).scrollIntoView({ + behavior: "auto", + block: "nearest", + }); + } + } + }, [selectedIndex]); + const handleKeyDown = (e: React.KeyboardEvent) => { if (e.key === "ArrowDown") { e.preventDefault(); + e.stopPropagation(); setSelectedIndex((prev) => Math.min(prev + 1, suggestions.length - 1)); } else if (e.key === "ArrowUp") { e.preventDefault(); + e.stopPropagation(); setSelectedIndex((prev) => Math.max(prev - 1, 0)); - } else if (e.key === "Enter" && selectedIndex >= 0) { + } else if (e.key === "Enter") { e.preventDefault(); - onSelect(suggestions[selectedIndex], query); - onClose(); + e.stopPropagation(); + if (selectedIndex >= 0 && selectedIndex < suggestions.length) { + onSelect(suggestions[selectedIndex], query); + onClose(); + } } else if (e.key === "Escape") { e.preventDefault(); + e.stopPropagation(); onClose(); } else if (e.key === "Tab") { e.preventDefault(); + e.stopPropagation(); const suggestion = suggestions[selectedIndex]; if (suggestion != null) { const tabResult = onTab?.(suggestion, query); @@ -255,6 +274,14 @@ const SuggestionControlInner: React.FC = ({ setQuery(tabResult); } } + } else if (e.key === "PageDown") { + e.preventDefault(); + e.stopPropagation(); + setSelectedIndex((prev) => Math.min(prev + 10, suggestions.length - 1)); + } else if (e.key === "PageUp") { + e.preventDefault(); + e.stopPropagation(); + setSelectedIndex((prev) => Math.max(prev - 10, 0)); } }; return ( diff --git a/frontend/app/view/preview/directorypreview.tsx b/frontend/app/view/preview/directorypreview.tsx index 845097be3b..907f4de059 100644 --- a/frontend/app/view/preview/directorypreview.tsx +++ b/frontend/app/view/preview/directorypreview.tsx @@ -9,9 +9,9 @@ import { ContextMenuModel } from "@/app/store/contextmenu"; import { PLATFORM, atoms, createBlock, getApi, globalStore } from "@/app/store/global"; import { RpcApi } from "@/app/store/wshclientapi"; import { TabRpcClient } from "@/app/store/wshrpcutil"; -import type { PreviewModel } from "@/app/view/preview/preview"; +import { formatRemoteUri, type PreviewModel } from "@/app/view/preview/preview"; import { checkKeyPressed, isCharacterKeyEvent } from "@/util/keyutil"; -import { fireAndForget, isBlank, makeConnRoute, makeNativeLabel } from "@/util/util"; +import { fireAndForget, isBlank, makeNativeLabel } from "@/util/util"; import { offset, useDismiss, useFloating, useInteractions } from "@floating-ui/react"; import { Column, @@ -528,8 +528,10 @@ function TableBody({ const fileName = finfo.path.split("/").pop(); let parentFileInfo: FileInfo; try { - parentFileInfo = await RpcApi.RemoteFileJoinCommand(TabRpcClient, [normPath, ".."], { - route: makeConnRoute(conn), + parentFileInfo = await RpcApi.FileInfoCommand(TabRpcClient, { + info: { + path: await model.formatRemoteUri(finfo.dir, globalStore.get), + }, }); } catch (e) { console.log("could not get parent file info. using child file info as fallback"); @@ -683,7 +685,6 @@ function TableBody({ setSearch={setSearch} idx={idx} handleFileContextMenu={handleFileContextMenu} - ref={(el) => (rowRefs.current[idx] = el)} key={idx} /> ))} @@ -696,7 +697,6 @@ function TableBody({ setSearch={setSearch} idx={idx + table.getTopRows().length} handleFileContextMenu={handleFileContextMenu} - ref={(el) => (rowRefs.current[idx] = el)} key={idx} /> ))} @@ -715,40 +715,28 @@ type TableRowProps = { handleFileContextMenu: (e: any, finfo: FileInfo) => Promise; }; -const TableRow = React.forwardRef(function ( - { model, row, focusIndex, setFocusIndex, setSearch, idx, handleFileContextMenu }: TableRowProps, - ref: React.RefObject -) { +const TableRow = React.forwardRef(function ({ + model, + row, + focusIndex, + setFocusIndex, + setSearch, + idx, + handleFileContextMenu, +}: TableRowProps) { const dirPath = useAtomValue(model.normFilePath); const connection = useAtomValue(model.connection); - const formatRemoteUri = useCallback( - (path: string) => { - let conn: string; - if (!connection) { - conn = "local"; - } else { - conn = connection; - } - return `wsh://${conn}/${path}`; - }, - [connection] - ); const dragItem: DraggedFile = { relName: row.getValue("name") as string, absParent: dirPath, - uri: formatRemoteUri(row.getValue("path") as string), + uri: formatRemoteUri(row.getValue("path") as string, connection), }; - const [{ isDragging }, drag, dragPreview] = useDrag( + const [_, drag] = useDrag( () => ({ type: "FILE_ITEM", canDrag: true, item: () => dragItem, - collect: (monitor) => { - return { - isDragging: monitor.isDragging(), - }; - }, }), [dragItem] ); diff --git a/frontend/app/view/preview/preview.tsx b/frontend/app/view/preview/preview.tsx index 908fd71c60..d6565f39d3 100644 --- a/frontend/app/view/preview/preview.tsx +++ b/frontend/app/view/preview/preview.tsx @@ -248,7 +248,7 @@ export class PreviewModel implements ViewModel { if (loadableFileInfo.state == "hasData") { headerPath = loadableFileInfo.data?.path; if (headerPath == "~") { - headerPath = `~ (${loadableFileInfo.data?.dir})`; + headerPath = `~ (${loadableFileInfo.data?.dir + "/" + loadableFileInfo.data?.name})`; } } @@ -386,13 +386,7 @@ export class PreviewModel implements ViewModel { }); this.normFilePath = atom>(async (get) => { const fileInfo = await get(this.statFile); - if (fileInfo == null) { - return null; - } - if (fileInfo.isdir) { - return fileInfo.dir + "/"; - } - return fileInfo.dir + "/" + fileInfo.name; + return fileInfo?.path; }); this.loadableStatFilePath = loadable(this.statFilePath); this.connection = atom>(async (get) => { @@ -410,12 +404,14 @@ export class PreviewModel implements ViewModel { }); this.statFile = atom>(async (get) => { const fileName = get(this.metaFilePath); + console.log("stat file", fileName); + const path = await this.formatRemoteUri(fileName, get); if (fileName == null) { return null; } const statFile = await RpcApi.FileInfoCommand(TabRpcClient, { info: { - path: await this.formatRemoteUri(fileName, get), + path, }, }); console.log("stat file", statFile); @@ -431,12 +427,14 @@ export class PreviewModel implements ViewModel { const fullFileAtom = atom>(async (get) => { const fileName = get(this.metaFilePath); + const path = await this.formatRemoteUri(fileName, get); if (fileName == null) { return null; } + console.log("full file path", path); const file = await RpcApi.FileReadCommand(TabRpcClient, { info: { - path: await this.formatRemoteUri(fileName, get), + path, }, }); console.log("full file", file); @@ -446,7 +444,6 @@ export class PreviewModel implements ViewModel { this.fileContentSaved = atom(null) as PrimitiveAtom; const fileContentAtom = atom( async (get) => { - const _ = get(this.metaFilePath); const newContent = get(this.newFileContent); if (newContent != null) { return newContent; @@ -691,21 +688,16 @@ export class PreviewModel implements ViewModel { async handleOpenFile(filePath: string) { const fileInfo = await globalStore.get(this.statFile); + this.updateOpenFileModalAndError(false); if (fileInfo == null) { - this.updateOpenFileModalAndError(false); return true; } - const conn = await globalStore.get(this.connection); try { - const newFileInfo = await RpcApi.RemoteFileJoinCommand(TabRpcClient, [fileInfo.dir, filePath], { - route: makeConnRoute(conn), - }); - this.updateOpenFileModalAndError(false); - this.goHistory(newFileInfo.path); + this.goHistory(filePath); refocusNode(this.blockId); } catch (e) { globalStore.set(this.openFileError, e.message); - console.error("Error opening file", fileInfo.dir, filePath, e); + console.error("Error opening file", filePath, e); } } @@ -724,7 +716,14 @@ export class PreviewModel implements ViewModel { if (filePath == null) { return; } - await navigator.clipboard.writeText(filePath); + const conn = await globalStore.get(this.connection); + if (conn) { + // remote path + await navigator.clipboard.writeText(formatRemoteUri(filePath, conn)); + } else { + // local path + await navigator.clipboard.writeText(filePath); + } }), }); menuItems.push({ @@ -868,8 +867,7 @@ export class PreviewModel implements ViewModel { } async formatRemoteUri(path: string, get: Getter): Promise { - const conn = (await get(this.connection)) ?? "local"; - return `wsh://${conn}/${path}`; + return formatRemoteUri(path, await get(this.connection)); } } @@ -1116,7 +1114,6 @@ const fetchSuggestions = async ( }; function PreviewView({ - blockId, blockRef, contentRef, model, @@ -1304,4 +1301,16 @@ const ErrorOverlay = memo(({ errorMsg, resetOverlay }: { errorMsg: ErrorMsg; res ); }); -export { PreviewView }; +function formatRemoteUri(path: string, connection: string): string { + connection = connection ?? "local"; + // TODO: We need a better way to handle s3 paths + let retVal: string; + if (connection.startsWith("aws:")) { + retVal = `${connection}:s3://${path ?? ""}`; + } else { + retVal = `wsh://${connection}/${path}`; + } + return retVal; +} + +export { formatRemoteUri, PreviewView }; diff --git a/frontend/types/gotypes.d.ts b/frontend/types/gotypes.d.ts index b6e1ca2ba9..2570f4b4de 100644 --- a/frontend/types/gotypes.d.ts +++ b/frontend/types/gotypes.d.ts @@ -204,13 +204,6 @@ declare global { message: string; }; - // wshrpc.CommandRemoteFileCopyData - type CommandRemoteFileCopyData = { - srcuri: string; - desturi: string; - opts?: FileCopyOpts; - }; - // wshrpc.CommandRemoteListEntriesData type CommandRemoteListEntriesData = { path: string; @@ -460,6 +453,12 @@ declare global { append?: boolean; }; + // wshrpc.FileShareCapability + type FileShareCapability = { + canappend: boolean; + canmkdir: boolean; + }; + // wconfig.FullConfigType type FullConfigType = { settings: SettingsType; diff --git a/pkg/remote/awsconn/awsconn.go b/pkg/remote/awsconn/awsconn.go index ff0deaedaf..5c84532b7f 100644 --- a/pkg/remote/awsconn/awsconn.go +++ b/pkg/remote/awsconn/awsconn.go @@ -17,9 +17,9 @@ import ( "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/service/s3" "github.com/aws/aws-sdk-go-v2/service/s3/types" + "github.com/aws/smithy-go" "github.com/wavetermdev/waveterm/pkg/waveobj" - "github.com/wavetermdev/waveterm/pkg/wconfig" "gopkg.in/ini.v1" ) @@ -44,24 +44,27 @@ func GetConfig(ctx context.Context, profile string) (*aws.Config, error) { } profile = connMatch[1] log.Printf("GetConfig: profile=%s", profile) - profiles, cerrs := wconfig.ReadWaveHomeConfigFile(wconfig.ProfilesFile) - if len(cerrs) > 0 { - return nil, fmt.Errorf("error reading config file: %v", cerrs[0]) - } - if profiles[profile] != nil { - configfilepath, _ := getTempFileFromConfig(profiles, ProfileConfigKey, profile) - credentialsfilepath, _ := getTempFileFromConfig(profiles, ProfileCredentialsKey, profile) - if configfilepath != "" { - log.Printf("configfilepath: %s", configfilepath) - optfns = append(optfns, config.WithSharedConfigFiles([]string{configfilepath})) - tempfiles[profile+"_config"] = configfilepath - } - if credentialsfilepath != "" { - log.Printf("credentialsfilepath: %s", credentialsfilepath) - optfns = append(optfns, config.WithSharedCredentialsFiles([]string{credentialsfilepath})) - tempfiles[profile+"_credentials"] = credentialsfilepath - } - } + + // TODO: Reimplement generic profile support + // profiles, cerrs := wconfig.ReadWaveHomeConfigFile(wconfig.ProfilesFile) + // if len(cerrs) > 0 { + // return nil, fmt.Errorf("error reading config file: %v", cerrs[0]) + // } + // if profiles[profile] != nil { + // configfilepath, _ := getTempFileFromConfig(profiles, ProfileConfigKey, profile) + // credentialsfilepath, _ := getTempFileFromConfig(profiles, ProfileCredentialsKey, profile) + // if configfilepath != "" { + // log.Printf("configfilepath: %s", configfilepath) + // optfns = append(optfns, config.WithSharedConfigFiles([]string{configfilepath})) + // tempfiles[profile+"_config"] = configfilepath + // } + // if credentialsfilepath != "" { + // log.Printf("credentialsfilepath: %s", credentialsfilepath) + // optfns = append(optfns, config.WithSharedCredentialsFiles([]string{credentialsfilepath})) + // tempfiles[profile+"_credentials"] = credentialsfilepath + // } + // } + optfns = append(optfns, config.WithRegion("us-west-2")) trimmedProfile := strings.TrimPrefix(profile, ProfilePrefix) optfns = append(optfns, config.WithSharedConfigProfile(trimmedProfile)) } @@ -112,10 +115,7 @@ func ParseProfiles() map[string]struct{} { f, err = ini.Load(fname) if err != nil { log.Printf("error reading aws credentials file: %v", err) - if profiles == nil { - profiles = make(map[string]struct{}) - } - return profiles + return nil } for _, v := range f.Sections() { profiles[ProfilePrefix+v.Name()] = struct{}{} @@ -124,13 +124,27 @@ func ParseProfiles() map[string]struct{} { } func ListBuckets(ctx context.Context, client *s3.Client) ([]types.Bucket, error) { - output, err := client.ListBuckets(ctx, &s3.ListBucketsInput{}) - if err != nil { - var apiErr smithy.APIError - if errors.As(err, &apiErr) { - return nil, fmt.Errorf("error listing buckets: %v", apiErr) + var err error + var output *s3.ListBucketsOutput + var buckets []types.Bucket + bucketPaginator := s3.NewListBucketsPaginator(client, &s3.ListBucketsInput{}) + for bucketPaginator.HasMorePages() { + output, err = bucketPaginator.NextPage(ctx) + if err != nil { + CheckAccessDeniedErr(&err) + return nil, fmt.Errorf("error listing buckets: %v", err) + } else { + buckets = append(buckets, output.Buckets...) } - return nil, fmt.Errorf("error listing buckets: %v", err) } - return output.Buckets, nil + return buckets, nil +} + +func CheckAccessDeniedErr(err *error) bool { + var apiErr smithy.APIError + if err != nil && errors.As(*err, &apiErr) && apiErr.ErrorCode() == "AccessDenied" { + *err = apiErr + return true + } + return false } diff --git a/pkg/remote/connparse/connparse.go b/pkg/remote/connparse/connparse.go index b099d1c0a9..18c4e5e274 100644 --- a/pkg/remote/connparse/connparse.go +++ b/pkg/remote/connparse/connparse.go @@ -47,6 +47,9 @@ func (c *Connection) GetPathWithHost() string { if c.Host == "" { return "" } + if c.Path == "" { + return c.Host + } if strings.HasPrefix(c.Path, "/") { return c.Host + c.Path } @@ -91,12 +94,12 @@ func GetConnNameFromContext(ctx context.Context) (string, error) { // ParseURI parses a connection URI and returns the connection type, host/path, and parameters. func ParseURI(uri string) (*Connection, error) { - split := strings.SplitN(uri, "//", 2) + split := strings.SplitN(uri, "://", 2) var scheme string var rest string if len(split) > 1 { - scheme = strings.TrimSuffix(split[0], ":") - rest = split[1] + scheme = split[0] + rest = strings.TrimPrefix(split[1], "//") } else { rest = split[0] } @@ -107,16 +110,13 @@ func ParseURI(uri string) (*Connection, error) { parseGenericPath := func() { split = strings.SplitN(rest, "/", 2) host = split[0] - if len(split) > 1 { + if len(split) > 1 && split[1] != "" { remotePath = split[1] + } else if strings.HasSuffix(rest, "/") { + // preserve trailing slash + remotePath = "/" } else { - split = strings.SplitN(rest, "/", 2) - host = split[0] - if len(split) > 1 { - remotePath = split[1] - } else { - remotePath = "/" - } + remotePath = "" } } parseWshPath := func() { diff --git a/pkg/remote/connparse/connparse_test.go b/pkg/remote/connparse/connparse_test.go index 82ccc83625..e883ef3fb6 100644 --- a/pkg/remote/connparse/connparse_test.go +++ b/pkg/remote/connparse/connparse_test.go @@ -17,20 +17,20 @@ func TestParseURI_WSHWithScheme(t *testing.T) { } expected := "/path/to/file" if c.Path != expected { - t.Fatalf("expected path to be %q, got %q", expected, c.Path) + t.Fatalf("expected path to be \"%q\", got \"%q\"", expected, c.Path) } expected = "user@localhost:8080" if c.Host != expected { - t.Fatalf("expected host to be %q, got %q", expected, c.Host) + t.Fatalf("expected host to be \"%q\", got \"%q\"", expected, c.Host) } expected = "user@localhost:8080/path/to/file" pathWithHost := c.GetPathWithHost() if pathWithHost != expected { - t.Fatalf("expected path with host to be %q, got %q", expected, pathWithHost) + t.Fatalf("expected path with host to be \"%q\", got \"%q\"", expected, pathWithHost) } expected = "wsh" if c.Scheme != expected { - t.Fatalf("expected scheme to be %q, got %q", expected, c.Scheme) + t.Fatalf("expected scheme to be \"%q\", got \"%q\"", expected, c.Scheme) } if len(c.GetSchemeParts()) != 1 { t.Fatalf("expected scheme parts to be 1, got %d", len(c.GetSchemeParts())) @@ -44,27 +44,27 @@ func TestParseURI_WSHWithScheme(t *testing.T) { } expected = "/path/to/file" if c.Path != expected { - t.Fatalf("expected path to be %q, got %q", expected, c.Path) + t.Fatalf("expected path to be \"%q\", got \"%q\"", expected, c.Path) } expected = "user@192.168.0.1:22" if c.Host != expected { - t.Fatalf("expected host to be %q, got %q", expected, c.Host) + t.Fatalf("expected host to be \"%q\", got \"%q\"", expected, c.Host) } expected = "user@192.168.0.1:22/path/to/file" pathWithHost = c.GetPathWithHost() if pathWithHost != expected { - t.Fatalf("expected path with host to be %q, got %q", expected, pathWithHost) + t.Fatalf("expected path with host to be \"%q\", got \"%q\"", expected, pathWithHost) } expected = "wsh" if c.GetType() != expected { - t.Fatalf("expected conn type to be %q, got %q", expected, c.Scheme) + t.Fatalf("expected conn type to be \"%q\", got \"%q\"", expected, c.Scheme) } if len(c.GetSchemeParts()) != 1 { t.Fatalf("expected scheme parts to be 1, got %d", len(c.GetSchemeParts())) } got := c.GetFullURI() if got != cstr { - t.Fatalf("expected full URI to be %q, got %q", cstr, got) + t.Fatalf("expected full URI to be \"%q\", got \"%q\"", cstr, got) } } @@ -77,20 +77,20 @@ func TestParseURI_WSHRemoteShorthand(t *testing.T) { if err != nil { t.Fatalf("failed to parse URI: %v", err) } - expected := "/path/to/file" + expected := "path/to/file" if c.Path != expected { - t.Fatalf("expected path to be %q, got %q", expected, c.Path) + t.Fatalf("expected path to be \"%q\", got \"%q\"", expected, c.Path) } if c.Host != "conn" { - t.Fatalf("expected host to be empty, got %q", c.Host) + t.Fatalf("expected host to be empty, got \"%q\"", c.Host) } expected = "wsh" if c.Scheme != expected { - t.Fatalf("expected scheme to be %q, got %q", expected, c.Scheme) + t.Fatalf("expected scheme to be \"%q\", got \"%q\"", expected, c.Scheme) } expected = "wsh://conn/path/to/file" if c.GetFullURI() != expected { - t.Fatalf("expected full URI to be %q, got %q", expected, c.GetFullURI()) + t.Fatalf("expected full URI to be \"%q\", got \"%q\"", expected, c.GetFullURI()) } // Test with a complex remote path @@ -99,21 +99,21 @@ func TestParseURI_WSHRemoteShorthand(t *testing.T) { if err != nil { t.Fatalf("failed to parse URI: %v", err) } - expected = "/path/to/file" + expected = "path/to/file" if c.Path != expected { - t.Fatalf("expected path to be %q, got %q", expected, c.Path) + t.Fatalf("expected path to be \"%q\", got \"%q\"", expected, c.Path) } expected = "user@localhost:8080" if c.Host != expected { - t.Fatalf("expected host to be %q, got %q", expected, c.Host) + t.Fatalf("expected host to be \"%q\", got \"%q\"", expected, c.Host) } expected = "wsh" if c.Scheme != expected { - t.Fatalf("expected scheme to be %q, got %q", expected, c.Scheme) + t.Fatalf("expected scheme to be \"%q\", got \"%q\"", expected, c.Scheme) } expected = "wsh://user@localhost:8080/path/to/file" if c.GetFullURI() != expected { - t.Fatalf("expected full URI to be %q, got %q", expected, c.GetFullURI()) + t.Fatalf("expected full URI to be \"%q\", got \"%q\"", expected, c.GetFullURI()) } // Test with an IP address @@ -122,21 +122,21 @@ func TestParseURI_WSHRemoteShorthand(t *testing.T) { if err != nil { t.Fatalf("failed to parse URI: %v", err) } - expected = "/path/to/file" + expected = "path/to/file" if c.Path != expected { - t.Fatalf("expected path to be %q, got %q", expected, c.Path) + t.Fatalf("expected path to be \"%q\", got \"%q\"", expected, c.Path) } expected = "user@192.168.0.1:8080" if c.Host != expected { - t.Fatalf("expected host to be %q, got %q", expected, c.Host) + t.Fatalf("expected host to be \"%q\", got \"%q\"", expected, c.Host) } expected = "wsh" if c.Scheme != expected { - t.Fatalf("expected scheme to be %q, got %q", expected, c.Scheme) + t.Fatalf("expected scheme to be \"%q\", got \"%q\"", expected, c.Scheme) } expected = "wsh://user@192.168.0.1:8080/path/to/file" if c.GetFullURI() != expected { - t.Fatalf("expected full URI to be %q, got %q", expected, c.GetFullURI()) + t.Fatalf("expected full URI to be \"%q\", got \"%q\"", expected, c.GetFullURI()) } } @@ -151,19 +151,19 @@ func TestParseURI_WSHCurrentPathShorthand(t *testing.T) { } expected := "~/path/to/file" if c.Path != expected { - t.Fatalf("expected path to be %q, got %q", expected, c.Path) + t.Fatalf("expected path to be \"%q\", got \"%q\"", expected, c.Path) } expected = "current" if c.Host != expected { - t.Fatalf("expected host to be %q, got %q", expected, c.Host) + t.Fatalf("expected host to be \"%q\", got \"%q\"", expected, c.Host) } expected = "wsh" if c.Scheme != expected { - t.Fatalf("expected scheme to be %q, got %q", expected, c.Scheme) + t.Fatalf("expected scheme to be \"%q\", got \"%q\"", expected, c.Scheme) } expected = "wsh://current/~/path/to/file" if c.GetFullURI() != expected { - t.Fatalf("expected full URI to be %q, got %q", expected, c.GetFullURI()) + t.Fatalf("expected full URI to be \"%q\", got \"%q\"", expected, c.GetFullURI()) } // Test with a absolute path @@ -174,19 +174,19 @@ func TestParseURI_WSHCurrentPathShorthand(t *testing.T) { } expected = "/path/to/file" if c.Path != expected { - t.Fatalf("expected path to be %q, got %q", expected, c.Path) + t.Fatalf("expected path to be \"%q\", got \"%q\"", expected, c.Path) } expected = "current" if c.Host != expected { - t.Fatalf("expected host to be %q, got %q", expected, c.Host) + t.Fatalf("expected host to be \"%q\", got \"%q\"", expected, c.Host) } expected = "wsh" if c.Scheme != expected { - t.Fatalf("expected scheme to be %q, got %q", expected, c.Scheme) + t.Fatalf("expected scheme to be \"%q\", got \"%q\"", expected, c.Scheme) } expected = "wsh://current/path/to/file" if c.GetFullURI() != expected { - t.Fatalf("expected full URI to be %q, got %q", expected, c.GetFullURI()) + t.Fatalf("expected full URI to be \"%q\", got \"%q\"", expected, c.GetFullURI()) } } @@ -198,19 +198,19 @@ func TestParseURI_WSHCurrentPath(t *testing.T) { } expected := "./Documents/path/to/file" if c.Path != expected { - t.Fatalf("expected path to be %q, got %q", expected, c.Path) + t.Fatalf("expected path to be \"%q\", got \"%q\"", expected, c.Path) } expected = "current" if c.Host != expected { - t.Fatalf("expected host to be %q, got %q", expected, c.Host) + t.Fatalf("expected host to be \"%q\", got \"%q\"", expected, c.Host) } expected = "wsh" if c.Scheme != expected { - t.Fatalf("expected scheme to be %q, got %q", expected, c.Scheme) + t.Fatalf("expected scheme to be \"%q\", got \"%q\"", expected, c.Scheme) } expected = "wsh://current/./Documents/path/to/file" if c.GetFullURI() != expected { - t.Fatalf("expected full URI to be %q, got %q", expected, c.GetFullURI()) + t.Fatalf("expected full URI to be \"%q\", got \"%q\"", expected, c.GetFullURI()) } cstr = "path/to/file" @@ -266,19 +266,19 @@ func TestParseURI_WSHCurrentPathWindows(t *testing.T) { } expected := ".\\Documents\\path\\to\\file" if c.Path != expected { - t.Fatalf("expected path to be %q, got %q", expected, c.Path) + t.Fatalf("expected path to be \"%q\", got \"%q\"", expected, c.Path) } expected = "current" if c.Host != expected { - t.Fatalf("expected host to be %q, got %q", expected, c.Host) + t.Fatalf("expected host to be \"%q\", got \"%q\"", expected, c.Host) } expected = "wsh" if c.Scheme != expected { - t.Fatalf("expected scheme to be %q, got %q", expected, c.Scheme) + t.Fatalf("expected scheme to be \"%q\", got \"%q\"", expected, c.Scheme) } expected = "wsh://current/.\\Documents\\path\\to\\file" if c.GetFullURI() != expected { - t.Fatalf("expected full URI to be %q, got %q", expected, c.GetFullURI()) + t.Fatalf("expected full URI to be \"%q\", got \"%q\"", expected, c.GetFullURI()) } } @@ -291,14 +291,14 @@ func TestParseURI_WSHLocalShorthand(t *testing.T) { } expected := "~/path/to/file" if c.Path != expected { - t.Fatalf("expected path to be %q, got %q", expected, c.Path) + t.Fatalf("expected path to be \"%q\", got \"%q\"", expected, c.Path) } if c.Host != "local" { - t.Fatalf("expected host to be empty, got %q", c.Host) + t.Fatalf("expected host to be empty, got \"%q\"", c.Host) } expected = "wsh" if c.Scheme != expected { - t.Fatalf("expected scheme to be %q, got %q", expected, c.Scheme) + t.Fatalf("expected scheme to be \"%q\", got \"%q\"", expected, c.Scheme) } cstr = "wsh:///~/path/to/file" @@ -308,18 +308,18 @@ func TestParseURI_WSHLocalShorthand(t *testing.T) { } expected = "~/path/to/file" if c.Path != expected { - t.Fatalf("expected path to be %q, got %q", expected, c.Path) + t.Fatalf("expected path to be \"%q\", got \"%q\"", expected, c.Path) } if c.Host != "local" { - t.Fatalf("expected host to be empty, got %q", c.Host) + t.Fatalf("expected host to be empty, got \"%q\"", c.Host) } expected = "wsh" if c.Scheme != expected { - t.Fatalf("expected scheme to be %q, got %q", expected, c.Scheme) + t.Fatalf("expected scheme to be \"%q\", got \"%q\"", expected, c.Scheme) } expected = "wsh://local/~/path/to/file" if c.GetFullURI() != expected { - t.Fatalf("expected full URI to be %q, got %q", expected, c.GetFullURI()) + t.Fatalf("expected full URI to be \"%q\", got \"%q\"", expected, c.GetFullURI()) } } @@ -334,19 +334,19 @@ func TestParseURI_WSHWSL(t *testing.T) { } expected := "/path/to/file" if c.Path != expected { - t.Fatalf("expected path to be %q, got %q", expected, c.Path) + t.Fatalf("expected path to be \"%q\", got \"%q\"", expected, c.Path) } expected = "wsl://Ubuntu" if c.Host != expected { - t.Fatalf("expected host to be %q, got %q", expected, c.Host) + t.Fatalf("expected host to be \"%q\", got \"%q\"", expected, c.Host) } expected = "wsh" if c.Scheme != expected { - t.Fatalf("expected scheme to be %q, got %q", expected, c.Scheme) + t.Fatalf("expected scheme to be \"%q\", got \"%q\"", expected, c.Scheme) } expected = "wsh://wsl://Ubuntu/path/to/file" if expected != c.GetFullURI() { - t.Fatalf("expected full URI to be %q, got %q", expected, c.GetFullURI()) + t.Fatalf("expected full URI to be \"%q\", got \"%q\"", expected, c.GetFullURI()) } } t.Log("Testing with scheme") @@ -368,19 +368,19 @@ func TestParseUri_LocalWindowsAbsPath(t *testing.T) { } expected := "C:\\path\\to\\file" if c.Path != expected { - t.Fatalf("expected path to be %q, got %q", expected, c.Path) + t.Fatalf("expected path to be \"%q\", got \"%q\"", expected, c.Path) } expected = "local" if c.Host != expected { - t.Fatalf("expected host to be %q, got %q", expected, c.Host) + t.Fatalf("expected host to be \"%q\", got \"%q\"", expected, c.Host) } expected = "wsh" if c.Scheme != expected { - t.Fatalf("expected scheme to be %q, got %q", expected, c.Scheme) + t.Fatalf("expected scheme to be \"%q\", got \"%q\"", expected, c.Scheme) } expected = "wsh://local/C:\\path\\to\\file" if c.GetFullURI() != expected { - t.Fatalf("expected full URI to be %q, got %q", expected, c.GetFullURI()) + t.Fatalf("expected full URI to be \"%q\", got \"%q\"", expected, c.GetFullURI()) } } @@ -399,19 +399,19 @@ func TestParseURI_LocalWindowsRelativeShorthand(t *testing.T) { } expected := "~\\path\\to\\file" if c.Path != expected { - t.Fatalf("expected path to be %q, got %q", expected, c.Path) + t.Fatalf("expected path to be \"%q\", got \"%q\"", expected, c.Path) } expected = "local" if c.Host != expected { - t.Fatalf("expected host to be %q, got %q", expected, c.Host) + t.Fatalf("expected host to be \"%q\", got \"%q\"", expected, c.Host) } expected = "wsh" if c.Scheme != expected { - t.Fatalf("expected scheme to be %q, got %q", expected, c.Scheme) + t.Fatalf("expected scheme to be \"%q\", got \"%q\"", expected, c.Scheme) } expected = "wsh://local/~\\path\\to\\file" if c.GetFullURI() != expected { - t.Fatalf("expected full URI to be %q, got %q", expected, c.GetFullURI()) + t.Fatalf("expected full URI to be \"%q\", got \"%q\"", expected, c.GetFullURI()) } } @@ -424,22 +424,60 @@ func TestParseURI_BasicS3(t *testing.T) { } expected := "path/to/file" if c.Path != expected { - t.Fatalf("expected path to be %q, got %q", expected, c.Path) + t.Fatalf("expected path to be \"%q\", got \"%q\"", expected, c.Path) } expected = "bucket" if c.Host != expected { - t.Fatalf("expected host to be %q, got %q", expected, c.Host) + t.Fatalf("expected host to be \"%q\", got \"%q\"", expected, c.Host) } expected = "bucket/path/to/file" pathWithHost := c.GetPathWithHost() if pathWithHost != expected { - t.Fatalf("expected path with host to be %q, got %q", expected, pathWithHost) + t.Fatalf("expected path with host to be \"%q\", got \"%q\"", expected, pathWithHost) } expected = "s3" if c.GetType() != expected { - t.Fatalf("expected conn type to be %q, got %q", expected, c.GetType()) + t.Fatalf("expected conn type to be \"%q\", got \"%q\"", expected, c.GetType()) } if len(c.GetSchemeParts()) != 2 { t.Fatalf("expected scheme parts to be 2, got %d", len(c.GetSchemeParts())) } } + +func TestParseURI_S3BucketOnly(t *testing.T) { + t.Parallel() + + testUri := func(cstr string, pathExpected string, pathWithHostExpected string) { + c, err := connparse.ParseURI(cstr) + if err != nil { + t.Fatalf("failed to parse URI: %v", err) + } + if c.Path != pathExpected { + t.Fatalf("expected path to be \"%q\", got \"%q\"", pathExpected, c.Path) + } + expected := "bucket" + if c.Host != expected { + t.Fatalf("expected host to be \"%q\", got \"%q\"", expected, c.Host) + } + pathWithHost := c.GetPathWithHost() + if pathWithHost != pathWithHostExpected { + t.Fatalf("expected path with host to be \"%q\", got \"%q\"", expected, pathWithHost) + } + expected = "s3" + if c.GetType() != expected { + t.Fatalf("expected conn type to be \"%q\", got \"%q\"", expected, c.GetType()) + } + if len(c.GetSchemeParts()) != 2 { + t.Fatalf("expected scheme parts to be 2, got %d", len(c.GetSchemeParts())) + } + fullUri := c.GetFullURI() + if fullUri != cstr { + t.Fatalf("expected full URI to be \"%q\", got \"%q\"", cstr, fullUri) + } + } + + t.Log("Testing with no trailing slash") + testUri("profile:s3://bucket", "", "bucket") + t.Log("Testing with trailing slash") + testUri("profile:s3://bucket/", "/", "bucket/") +} diff --git a/pkg/remote/fileshare/fileshare.go b/pkg/remote/fileshare/fileshare.go index 9473db55a2..558da7e551 100644 --- a/pkg/remote/fileshare/fileshare.go +++ b/pkg/remote/fileshare/fileshare.go @@ -5,8 +5,10 @@ import ( "fmt" "log" + "github.com/wavetermdev/waveterm/pkg/remote/awsconn" "github.com/wavetermdev/waveterm/pkg/remote/connparse" "github.com/wavetermdev/waveterm/pkg/remote/fileshare/fstype" + "github.com/wavetermdev/waveterm/pkg/remote/fileshare/s3fs" "github.com/wavetermdev/waveterm/pkg/remote/fileshare/wavefs" "github.com/wavetermdev/waveterm/pkg/remote/fileshare/wshfs" "github.com/wavetermdev/waveterm/pkg/util/iochan/iochantypes" @@ -28,12 +30,12 @@ func CreateFileShareClient(ctx context.Context, connection string) (fstype.FileS } conntype := conn.GetType() if conntype == connparse.ConnectionTypeS3 { - // config, err := awsconn.GetConfig(ctx, connection) - // if err != nil { - // log.Printf("error getting aws config: %v", err) - // return nil, nil - // } - return nil, nil + config, err := awsconn.GetConfig(ctx, connection) + if err != nil { + log.Printf("error getting aws config: %v", err) + return nil, nil + } + return s3fs.NewS3Client(config), conn } else if conntype == connparse.ConnectionTypeWave { return wavefs.NewWaveClient(), conn } else if conntype == connparse.ConnectionTypeWsh { @@ -45,6 +47,7 @@ func CreateFileShareClient(ctx context.Context, connection string) (fstype.FileS } func Read(ctx context.Context, data wshrpc.FileData) (*wshrpc.FileData, error) { + log.Printf("Read: %v", data.Info.Path) client, conn := CreateFileShareClient(ctx, data.Info.Path) if conn == nil || client == nil { return nil, fmt.Errorf(ErrorParsingConnection, data.Info.Path) @@ -118,11 +121,19 @@ func Move(ctx context.Context, data wshrpc.CommandFileCopyData) error { return fmt.Errorf("error creating fileshare client, could not parse destination connection %s", data.DestUri) } if srcConn.Host != destConn.Host { - err := destClient.CopyRemote(ctx, srcConn, destConn, srcClient, data.Opts) + finfo, err := srcClient.Stat(ctx, srcConn) + if err != nil { + return fmt.Errorf("cannot stat %q: %w", data.SrcUri, err) + } + recursive := data.Opts != nil && data.Opts.Recursive + if finfo.IsDir && data.Opts != nil && !recursive { + return fmt.Errorf("cannot move directory %q to %q without recursive flag", data.SrcUri, data.DestUri) + } + err = destClient.CopyRemote(ctx, srcConn, destConn, srcClient, data.Opts) if err != nil { return fmt.Errorf("cannot copy %q to %q: %w", data.SrcUri, data.DestUri, err) } - return srcClient.Delete(ctx, srcConn, data.Opts.Recursive) + return srcClient.Delete(ctx, srcConn, recursive) } else { return srcClient.MoveInternal(ctx, srcConn, destConn, data.Opts) } @@ -152,10 +163,10 @@ func Delete(ctx context.Context, data wshrpc.CommandDeleteFileData) error { return client.Delete(ctx, conn, data.Recursive) } -func Join(ctx context.Context, path string, parts ...string) (string, error) { +func Join(ctx context.Context, path string, parts ...string) (*wshrpc.FileInfo, error) { client, conn := CreateFileShareClient(ctx, path) if conn == nil || client == nil { - return "", fmt.Errorf(ErrorParsingConnection, path) + return nil, fmt.Errorf(ErrorParsingConnection, path) } return client.Join(ctx, conn, parts...) } @@ -167,3 +178,11 @@ func Append(ctx context.Context, data wshrpc.FileData) error { } return client.AppendFile(ctx, conn, data) } + +func GetCapability(ctx context.Context, path string) (wshrpc.FileShareCapability, error) { + client, conn := CreateFileShareClient(ctx, path) + if conn == nil || client == nil { + return wshrpc.FileShareCapability{}, fmt.Errorf(ErrorParsingConnection, path) + } + return client.GetCapability(), nil +} diff --git a/pkg/remote/fileshare/fspath/fspath.go b/pkg/remote/fileshare/fspath/fspath.go new file mode 100644 index 0000000000..e97ed1230e --- /dev/null +++ b/pkg/remote/fileshare/fspath/fspath.go @@ -0,0 +1,37 @@ +package fspath + +import ( + pathpkg "path" + "strings" +) + +const ( + // Separator is the path separator + Separator = "/" +) + +func Dir(path string) string { + return pathpkg.Dir(ToSlash(path)) +} + +func Base(path string) string { + return pathpkg.Base(ToSlash(path)) +} + +func Join(elem ...string) string { + joined := pathpkg.Join(elem...) + return ToSlash(joined) +} + +// FirstLevelDir returns the first level directory of a path and a boolean indicating if the path has more than one level. +func FirstLevelDir(path string) (string, bool) { + if strings.Count(path, Separator) > 0 { + path = strings.SplitN(path, Separator, 2)[0] + return path, true + } + return path, false +} + +func ToSlash(path string) string { + return strings.ReplaceAll(path, "\\", Separator) +} diff --git a/pkg/remote/fileshare/fstype/fstype.go b/pkg/remote/fileshare/fstype/fstype.go index 3c3d6fceb3..cc67ddeab9 100644 --- a/pkg/remote/fileshare/fstype/fstype.go +++ b/pkg/remote/fileshare/fstype/fstype.go @@ -5,12 +5,20 @@ package fstype import ( "context" + "os" + "time" "github.com/wavetermdev/waveterm/pkg/remote/connparse" "github.com/wavetermdev/waveterm/pkg/util/iochan/iochantypes" "github.com/wavetermdev/waveterm/pkg/wshrpc" ) +const ( + DefaultTimeout = 30 * time.Second + FileMode os.FileMode = 0644 + DirMode os.FileMode = 0755 | os.ModeDir +) + type FileShareClient interface { // Stat returns the file info at the given parsed connection path Stat(ctx context.Context, conn *connparse.Connection) (*wshrpc.FileInfo, error) @@ -39,7 +47,9 @@ type FileShareClient interface { // Delete deletes the entry at the given path Delete(ctx context.Context, conn *connparse.Connection, recursive bool) error // Join joins the given parts to the connection path - Join(ctx context.Context, conn *connparse.Connection, parts ...string) (string, error) + Join(ctx context.Context, conn *connparse.Connection, parts ...string) (*wshrpc.FileInfo, error) // GetConnectionType returns the type of connection for the fileshare GetConnectionType() string + // GetCapability returns the capability of the fileshare + GetCapability() wshrpc.FileShareCapability } diff --git a/pkg/remote/fileshare/fsutil/fsutil.go b/pkg/remote/fileshare/fsutil/fsutil.go new file mode 100644 index 0000000000..a6b6660557 --- /dev/null +++ b/pkg/remote/fileshare/fsutil/fsutil.go @@ -0,0 +1,344 @@ +package fsutil + +import ( + "archive/tar" + "bytes" + "context" + "encoding/base64" + "errors" + "fmt" + "io" + "io/fs" + "log" + "strings" + + "github.com/wavetermdev/waveterm/pkg/remote/connparse" + "github.com/wavetermdev/waveterm/pkg/remote/fileshare/fspath" + "github.com/wavetermdev/waveterm/pkg/remote/fileshare/fstype" + "github.com/wavetermdev/waveterm/pkg/remote/fileshare/pathtree" + "github.com/wavetermdev/waveterm/pkg/util/tarcopy" + "github.com/wavetermdev/waveterm/pkg/util/utilfn" + "github.com/wavetermdev/waveterm/pkg/wshrpc" +) + +func GetParentPath(conn *connparse.Connection) string { + hostAndPath := conn.GetPathWithHost() + return GetParentPathString(hostAndPath) +} + +func GetParentPathString(hostAndPath string) string { + if hostAndPath == "" || hostAndPath == fspath.Separator { + return fspath.Separator + } + + // Remove trailing slash if present + if strings.HasSuffix(hostAndPath, fspath.Separator) { + hostAndPath = hostAndPath[:len(hostAndPath)-1] + } + + lastSlash := strings.LastIndex(hostAndPath, fspath.Separator) + if lastSlash <= 0 { + return fspath.Separator + } + return hostAndPath[:lastSlash+1] +} + +const minURILength = 10 // Minimum length for a valid URI (e.g., "s3://bucket") + +func GetPathPrefix(conn *connparse.Connection) string { + fullUri := conn.GetFullURI() + if fullUri == "" { + return "" + } + pathPrefix := fullUri + lastSlash := strings.LastIndex(fullUri, fspath.Separator) + if lastSlash > minURILength && lastSlash < len(fullUri)-1 { + pathPrefix = fullUri[:lastSlash+1] + } + return pathPrefix +} + +func PrefixCopyInternal(ctx context.Context, srcConn, destConn *connparse.Connection, c fstype.FileShareClient, opts *wshrpc.FileCopyOpts, listEntriesPrefix func(ctx context.Context, host string, path string) ([]string, error), copyFunc func(ctx context.Context, host string, path string) error) error { + log.Printf("PrefixCopyInternal: %v -> %v", srcConn.GetFullURI(), destConn.GetFullURI()) + merge := opts != nil && opts.Merge + overwrite := opts != nil && opts.Overwrite + if overwrite && merge { + return fmt.Errorf("cannot specify both overwrite and merge") + } + srcHasSlash := strings.HasSuffix(srcConn.Path, fspath.Separator) + srcPath, err := CleanPathPrefix(srcConn.Path) + if err != nil { + return fmt.Errorf("error cleaning source path: %w", err) + } + destHasSlash := strings.HasSuffix(destConn.Path, fspath.Separator) + destPath, err := CleanPathPrefix(destConn.Path) + if err != nil { + return fmt.Errorf("error cleaning destination path: %w", err) + } + if !srcHasSlash { + if !destHasSlash { + destPath += fspath.Separator + } + destPath += fspath.Base(srcPath) + } + destConn.Path = destPath + destInfo, err := c.Stat(ctx, destConn) + destExists := err == nil && !destInfo.NotFound + if err != nil && !errors.Is(err, fs.ErrNotExist) { + return fmt.Errorf("error getting destination file info: %w", err) + } + + srcInfo, err := c.Stat(ctx, srcConn) + if err != nil { + return fmt.Errorf("error getting source file info: %w", err) + } + if destExists { + if overwrite { + err = c.Delete(ctx, destConn, true) + if err != nil { + return fmt.Errorf("error deleting conflicting destination file: %w", err) + } + } else if destInfo.IsDir && srcInfo.IsDir { + if !merge { + return fmt.Errorf("destination and source are both directories, neither merge nor overwrite specified: %v", destConn.GetFullURI()) + } + } else { + return fmt.Errorf("destination already exists, overwrite not specified: %v", destConn.GetFullURI()) + } + } + if srcInfo.IsDir { + if !srcHasSlash { + srcPath += fspath.Separator + } + destPath += fspath.Separator + log.Printf("Copying directory: %v -> %v", srcPath, destPath) + entries, err := listEntriesPrefix(ctx, srcConn.Host, srcPath) + if err != nil { + return fmt.Errorf("error listing source directory: %w", err) + } + + tree := pathtree.NewTree(srcPath, fspath.Separator) + for _, entry := range entries { + tree.Add(entry) + } + + /* tree.Walk will return the full path in the source bucket for each item. + prefixToRemove specifies how much of that path we want in the destination subtree. + If the source path has a trailing slash, we don't want to include the source directory itself in the destination subtree.*/ + prefixToRemove := srcPath + if !srcHasSlash { + prefixToRemove = fspath.Dir(srcPath) + fspath.Separator + } + return tree.Walk(func(path string, numChildren int) error { + // since this is a prefix filesystem, we only care about leafs + if numChildren > 0 { + return nil + } + destFilePath := destPath + strings.TrimPrefix(path, prefixToRemove) + return copyFunc(ctx, path, destFilePath) + }) + } else { + return copyFunc(ctx, srcPath, destPath) + } +} + +func PrefixCopyRemote(ctx context.Context, srcConn, destConn *connparse.Connection, srcClient, destClient fstype.FileShareClient, destPutFile func(host string, path string, size int64, reader io.Reader) error, opts *wshrpc.FileCopyOpts) error { + merge := opts != nil && opts.Merge + overwrite := opts != nil && opts.Overwrite + if overwrite && merge { + return fmt.Errorf("cannot specify both overwrite and merge") + } + srcHasSlash := strings.HasSuffix(srcConn.Path, fspath.Separator) + destHasSlash := strings.HasSuffix(destConn.Path, fspath.Separator) + destPath, err := CleanPathPrefix(destConn.Path) + if err != nil { + return fmt.Errorf("error cleaning destination path: %w", err) + } + if !srcHasSlash { + if !destHasSlash { + destPath += fspath.Separator + } + destPath += fspath.Base(srcConn.Path) + } + destConn.Path = destPath + destInfo, err := destClient.Stat(ctx, destConn) + destExists := err == nil && !destInfo.NotFound + if err != nil && !errors.Is(err, fs.ErrNotExist) { + return fmt.Errorf("error getting destination file info: %w", err) + } + + srcInfo, err := srcClient.Stat(ctx, srcConn) + if err != nil { + return fmt.Errorf("error getting source file info: %w", err) + } + if destExists { + if overwrite { + err = destClient.Delete(ctx, destConn, true) + if err != nil { + return fmt.Errorf("error deleting conflicting destination file: %w", err) + } + } else if destInfo.IsDir && srcInfo.IsDir { + if !merge { + return fmt.Errorf("destination and source are both directories, neither merge nor overwrite specified: %v", destConn.GetFullURI()) + } + } else { + return fmt.Errorf("destination already exists, overwrite not specified: %v", destConn.GetFullURI()) + } + } + if err != nil { + if !errors.Is(err, fs.ErrNotExist) { + return err + } + } + log.Printf("Copying: %v -> %v", srcConn.GetFullURI(), destConn.GetFullURI()) + readCtx, cancel := context.WithCancelCause(ctx) + defer cancel(nil) + ioch := srcClient.ReadTarStream(readCtx, srcConn, opts) + err = tarcopy.TarCopyDest(readCtx, cancel, ioch, func(next *tar.Header, reader *tar.Reader, singleFile bool) error { + if next.Typeflag == tar.TypeDir { + return nil + } + if singleFile && srcInfo.IsDir { + return fmt.Errorf("protocol error: source is a directory, but only a single file is being copied") + } + fileName, err := CleanPathPrefix(fspath.Join(destPath, next.Name)) + if singleFile && !destHasSlash { + fileName, err = CleanPathPrefix(destConn.Path) + } + if err != nil { + return fmt.Errorf("error cleaning path: %w", err) + } + log.Printf("CopyRemote: writing file: %s; size: %d\n", fileName, next.Size) + return destPutFile(destConn.Host, fileName, next.Size, reader) + }) + if err != nil { + cancel(err) + return err + } + return nil +} + +// CleanPathPrefix corrects paths for prefix filesystems (i.e. ones that don't have directories) +func CleanPathPrefix(path string) (string, error) { + if path == "" { + return "", fmt.Errorf("path is empty") + } + if strings.HasPrefix(path, fspath.Separator) { + path = path[1:] + } + if strings.HasPrefix(path, "~") || strings.HasPrefix(path, ".") || strings.HasPrefix(path, "..") { + return "", fmt.Errorf("path cannot start with ~, ., or ..") + } + var newParts []string + for _, part := range strings.Split(path, fspath.Separator) { + if part == ".." { + if len(newParts) > 0 { + newParts = newParts[:len(newParts)-1] + } + } else if part != "." { + newParts = append(newParts, part) + } + } + return fspath.Join(newParts...), nil +} + +func ReadFileStream(ctx context.Context, readCh <-chan wshrpc.RespOrErrorUnion[wshrpc.FileData], fileInfoCallback func(finfo wshrpc.FileInfo), dirCallback func(entries []*wshrpc.FileInfo) error, fileCallback func(data io.Reader) error) error { + var fileData *wshrpc.FileData + firstPk := true + isDir := false + drain := true + defer func() { + if drain { + utilfn.DrainChannelSafe(readCh, "ReadFileStream") + } + }() + + for { + select { + case <-ctx.Done(): + return fmt.Errorf("context cancelled: %v", context.Cause(ctx)) + case respUnion, ok := <-readCh: + if !ok { + drain = false + return nil + } + if respUnion.Error != nil { + return respUnion.Error + } + resp := respUnion.Response + if firstPk { + firstPk = false + // first packet has the fileinfo + if resp.Info == nil { + return fmt.Errorf("stream file protocol error, first pk fileinfo is empty") + } + fileData = &resp + if fileData.Info.IsDir { + isDir = true + } + fileInfoCallback(*fileData.Info) + continue + } + if isDir { + if len(resp.Entries) == 0 { + continue + } + if resp.Data64 != "" { + return fmt.Errorf("stream file protocol error, directory entry has data") + } + if err := dirCallback(resp.Entries); err != nil { + return err + } + } else { + if resp.Data64 == "" { + continue + } + decoder := base64.NewDecoder(base64.StdEncoding, bytes.NewReader([]byte(resp.Data64))) + if err := fileCallback(decoder); err != nil { + return err + } + } + } + } +} + +func ReadStreamToFileData(ctx context.Context, readCh <-chan wshrpc.RespOrErrorUnion[wshrpc.FileData]) (*wshrpc.FileData, error) { + var fileData *wshrpc.FileData + var dataBuf bytes.Buffer + var entries []*wshrpc.FileInfo + err := ReadFileStream(ctx, readCh, func(finfo wshrpc.FileInfo) { + fileData = &wshrpc.FileData{ + Info: &finfo, + } + }, func(fileEntries []*wshrpc.FileInfo) error { + entries = append(entries, fileEntries...) + return nil + }, func(data io.Reader) error { + if _, err := io.Copy(&dataBuf, data); err != nil { + return err + } + return nil + }) + if err != nil { + return nil, err + } + if fileData == nil { + return nil, fmt.Errorf("stream file protocol error, no file info") + } + if !fileData.Info.IsDir { + fileData.Data64 = base64.StdEncoding.EncodeToString(dataBuf.Bytes()) + } else { + fileData.Entries = entries + } + return fileData, nil +} + +func ReadFileStreamToWriter(ctx context.Context, readCh <-chan wshrpc.RespOrErrorUnion[wshrpc.FileData], writer io.Writer) error { + return ReadFileStream(ctx, readCh, func(finfo wshrpc.FileInfo) { + }, func(entries []*wshrpc.FileInfo) error { + return nil + }, func(data io.Reader) error { + _, err := io.Copy(writer, data) + return err + }) +} diff --git a/pkg/remote/fileshare/pathtree/pathtree.go b/pkg/remote/fileshare/pathtree/pathtree.go new file mode 100644 index 0000000000..5d4918fbae --- /dev/null +++ b/pkg/remote/fileshare/pathtree/pathtree.go @@ -0,0 +1,128 @@ +package pathtree + +import ( + "log" + "strings" +) + +type WalkFunc func(path string, numChildren int) error + +type Tree struct { + Root *Node + RootPath string + nodes map[string]*Node + delimiter string +} + +type Node struct { + Children map[string]*Node +} + +func (n *Node) Walk(curPath string, walkFunc WalkFunc, delimiter string) error { + if err := walkFunc(curPath, len(n.Children)); err != nil { + return err + } + for name, child := range n.Children { + if err := child.Walk(curPath+delimiter+name, walkFunc, delimiter); err != nil { + return err + } + } + return nil +} + +func NewTree(path string, delimiter string) *Tree { + if len(delimiter) > 1 { + log.Printf("Warning: multi-character delimiter '%s' may cause unexpected behavior", delimiter) + } + if path != "" && !strings.HasSuffix(path, delimiter) { + path += delimiter + } + return &Tree{ + Root: &Node{ + Children: make(map[string]*Node), + }, + nodes: make(map[string]*Node), + RootPath: path, + delimiter: delimiter, + } +} + +func (t *Tree) Add(path string) { + log.Printf("tree.Add: path: %s", path) + // Validate input + if path == "" { + return + } + var relativePath string + if t.RootPath == "" { + relativePath = path + } else { + relativePath = strings.TrimPrefix(path, t.RootPath) + + // If the path is not a child of the root path, ignore it + if relativePath == path { + return + } + + } + + // If the path is already in the tree, ignore it + if t.nodes[relativePath] != nil { + return + } + + components := strings.Split(relativePath, t.delimiter) + // Validate path components + for _, component := range components { + if component == "" || component == "." || component == ".." { + return // Skip invalid paths + } + } + + // Quick check to see if the parent path is already in the tree, in which case we can skip the loop + if parent := t.tryAddToExistingParent(components); parent { + return + } + + t.addNewPath(components) +} + +func (t *Tree) tryAddToExistingParent(components []string) bool { + if len(components) <= 1 { + return false + } + parentPath := strings.Join(components[:len(components)-1], t.delimiter) + if t.nodes[parentPath] == nil { + return false + } + lastPathComponent := components[len(components)-1] + t.nodes[parentPath].Children[lastPathComponent] = &Node{ + Children: make(map[string]*Node), + } + t.nodes[strings.Join(components, t.delimiter)] = t.nodes[parentPath].Children[lastPathComponent] + return true +} + +func (t *Tree) addNewPath(components []string) { + currentNode := t.Root + for i, component := range components { + if _, ok := currentNode.Children[component]; !ok { + currentNode.Children[component] = &Node{ + Children: make(map[string]*Node), + } + curPath := strings.Join(components[:i+1], t.delimiter) + t.nodes[curPath] = currentNode.Children[component] + } + currentNode = currentNode.Children[component] + } +} + +func (t *Tree) Walk(walkFunc WalkFunc) error { + log.Printf("RootPath: %s", t.RootPath) + for key, child := range t.Root.Children { + if err := child.Walk(t.RootPath+key, walkFunc, t.delimiter); err != nil { + return err + } + } + return nil +} diff --git a/pkg/remote/fileshare/pathtree/pathtree_test.go b/pkg/remote/fileshare/pathtree/pathtree_test.go new file mode 100644 index 0000000000..efaa25578e --- /dev/null +++ b/pkg/remote/fileshare/pathtree/pathtree_test.go @@ -0,0 +1,112 @@ +package pathtree_test + +import ( + "errors" + "log" + "testing" + + "github.com/wavetermdev/waveterm/pkg/remote/fileshare/pathtree" +) + +func TestAdd(t *testing.T) { + t.Parallel() + + tree := initializeTree() + + // Check that the tree has the expected structure + if len(tree.Root.Children) != 3 { + t.Errorf("expected 3 children, got %d", len(tree.Root.Children)) + } + + if len(tree.Root.Children["a"].Children) != 3 { + t.Errorf("expected 3 children, got %d", len(tree.Root.Children["a"].Children)) + } + + if len(tree.Root.Children["b"].Children) != 1 { + t.Errorf("expected 1 child, got %d", len(tree.Root.Children["b"].Children)) + } + + if len(tree.Root.Children["b"].Children["g"].Children) != 1 { + t.Errorf("expected 1 child, got %d", len(tree.Root.Children["b"].Children["g"].Children)) + } + + if len(tree.Root.Children["b"].Children["g"].Children["h"].Children) != 0 { + t.Errorf("expected 0 children, got %d", len(tree.Root.Children["b"].Children["g"].Children["h"].Children)) + } + + if len(tree.Root.Children["c"].Children) != 0 { + t.Errorf("expected 0 children, got %d", len(tree.Root.Children["c"].Children)) + } + + // Check that adding the same path again does not change the tree + tree.Add("root/a/d") + if len(tree.Root.Children["a"].Children) != 3 { + t.Errorf("expected 3 children, got %d", len(tree.Root.Children["a"].Children)) + } + + // Check that adding a path that is not a child of the root path does not change the tree + tree.Add("etc/passwd") + if len(tree.Root.Children) != 3 { + t.Errorf("expected 3 children, got %d", len(tree.Root.Children)) + } +} + +func TestWalk(t *testing.T) { + t.Parallel() + + tree := initializeTree() + + // Check that the tree traverses all nodes and identifies leaf nodes correctly + pathMap := make(map[string]int) + err := tree.Walk(func(path string, numChildren int) error { + pathMap[path] = numChildren + return nil + }) + + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + expectedPathMap := map[string]int{ + "root/a": 3, + "root/a/d": 0, + "root/a/e": 0, + "root/a/f": 0, + "root/b": 1, + "root/b/g": 1, + "root/b/g/h": 0, + "root/c": 0, + } + + log.Printf("pathMap: %v", pathMap) + + for path, numChildren := range expectedPathMap { + if pathMap[path] != numChildren { + t.Errorf("expected %d children for path %s, got %d", numChildren, path, pathMap[path]) + } + } + + expectedError := errors.New("test error") + + // Check that the walk function returns an error if it is returned by the walk function + err = tree.Walk(func(path string, numChildren int) error { + return expectedError + }) + if err != expectedError { + t.Errorf("expected error %v, got %v", expectedError, err) + } +} + +func initializeTree() *pathtree.Tree { + tree := pathtree.NewTree("root/", "/") + tree.Add("root/a") + tree.Add("root/b") + tree.Add("root/c") + tree.Add("root/a/d") + tree.Add("root/a/e") + tree.Add("root/a/f") + tree.Add("root/b/g") + tree.Add("root/b/g/h") + log.Printf("tree: %v", tree) + return tree +} diff --git a/pkg/remote/fileshare/s3fs/s3fs.go b/pkg/remote/fileshare/s3fs/s3fs.go index b406615d4d..6e720d139a 100644 --- a/pkg/remote/fileshare/s3fs/s3fs.go +++ b/pkg/remote/fileshare/s3fs/s3fs.go @@ -4,16 +4,31 @@ package s3fs import ( + "bytes" "context" + "encoding/base64" "errors" + "fmt" + "io" "log" + "strings" + "sync" + "time" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/s3" + "github.com/aws/aws-sdk-go-v2/service/s3/types" + "github.com/aws/smithy-go" "github.com/wavetermdev/waveterm/pkg/remote/awsconn" "github.com/wavetermdev/waveterm/pkg/remote/connparse" + "github.com/wavetermdev/waveterm/pkg/remote/fileshare/fspath" "github.com/wavetermdev/waveterm/pkg/remote/fileshare/fstype" + "github.com/wavetermdev/waveterm/pkg/remote/fileshare/fsutil" + "github.com/wavetermdev/waveterm/pkg/remote/fileshare/pathtree" + "github.com/wavetermdev/waveterm/pkg/util/fileutil" "github.com/wavetermdev/waveterm/pkg/util/iochan/iochantypes" + "github.com/wavetermdev/waveterm/pkg/util/tarcopy" + "github.com/wavetermdev/waveterm/pkg/util/utilfn" "github.com/wavetermdev/waveterm/pkg/wshrpc" "github.com/wavetermdev/waveterm/pkg/wshutil" ) @@ -31,94 +46,758 @@ func NewS3Client(config *aws.Config) *S3Client { } func (c S3Client) Read(ctx context.Context, conn *connparse.Connection, data wshrpc.FileData) (*wshrpc.FileData, error) { - return nil, errors.ErrUnsupported + rtnCh := c.ReadStream(ctx, conn, data) + return fsutil.ReadStreamToFileData(ctx, rtnCh) } func (c S3Client) ReadStream(ctx context.Context, conn *connparse.Connection, data wshrpc.FileData) <-chan wshrpc.RespOrErrorUnion[wshrpc.FileData] { - return wshutil.SendErrCh[wshrpc.FileData](errors.ErrUnsupported) + bucket := conn.Host + objectKey := conn.Path + log.Printf("s3fs.ReadStream: %v", conn.GetFullURI()) + rtn := make(chan wshrpc.RespOrErrorUnion[wshrpc.FileData], 16) + go func() { + defer close(rtn) + finfo, err := c.Stat(ctx, conn) + if err != nil { + rtn <- wshutil.RespErr[wshrpc.FileData](err) + return + } + rtn <- wshrpc.RespOrErrorUnion[wshrpc.FileData]{Response: wshrpc.FileData{Info: finfo}} + if finfo.IsDir { + listEntriesCh := c.ListEntriesStream(ctx, conn, nil) + defer func() { + utilfn.DrainChannelSafe(listEntriesCh, "s3fs.ReadStream") + }() + for respUnion := range listEntriesCh { + if respUnion.Error != nil { + rtn <- wshutil.RespErr[wshrpc.FileData](respUnion.Error) + return + } + resp := respUnion.Response + if len(resp.FileInfo) > 0 { + rtn <- wshrpc.RespOrErrorUnion[wshrpc.FileData]{Response: wshrpc.FileData{Entries: resp.FileInfo}} + } + } + } else { + var result *s3.GetObjectOutput + var err error + if data.At != nil { + log.Printf("reading %v with offset %d and size %d", conn.GetFullURI(), data.At.Offset, data.At.Size) + result, err = c.client.GetObject(ctx, &s3.GetObjectInput{ + Bucket: aws.String(bucket), + Key: aws.String(objectKey), + Range: aws.String(fmt.Sprintf("bytes=%d-%d", data.At.Offset, data.At.Offset+int64(data.At.Size)-1)), + }) + } else { + log.Printf("reading %v", conn.GetFullURI()) + result, err = c.client.GetObject(ctx, &s3.GetObjectInput{ + Bucket: aws.String(bucket), + Key: aws.String(objectKey), + }) + } + if err != nil { + log.Printf("error getting object %v:%v: %v", bucket, objectKey, err) + var noKey *types.NoSuchKey + if errors.As(err, &noKey) { + err = noKey + } + rtn <- wshutil.RespErr[wshrpc.FileData](err) + return + } + size := int64(0) + if result.ContentLength != nil { + size = *result.ContentLength + } + finfo := &wshrpc.FileInfo{ + Name: objectKey, + IsDir: false, + Size: size, + ModTime: result.LastModified.UnixMilli(), + Path: conn.GetFullURI(), + Dir: fsutil.GetParentPath(conn), + } + fileutil.AddMimeTypeToFileInfo(finfo.Path, finfo) + log.Printf("file info: %v", finfo) + rtn <- wshrpc.RespOrErrorUnion[wshrpc.FileData]{Response: wshrpc.FileData{Info: finfo}} + if size == 0 { + log.Printf("no data to read") + return + } + defer utilfn.GracefulClose(result.Body, "s3fs", conn.GetFullURI()) + bytesRemaining := size + for { + log.Printf("bytes remaining: %d", bytesRemaining) + select { + case <-ctx.Done(): + log.Printf("context done") + rtn <- wshutil.RespErr[wshrpc.FileData](context.Cause(ctx)) + return + default: + buf := make([]byte, min(bytesRemaining, wshrpc.FileChunkSize)) + n, err := result.Body.Read(buf) + if err != nil && !errors.Is(err, io.EOF) { + rtn <- wshutil.RespErr[wshrpc.FileData](err) + return + } + log.Printf("read %d bytes", n) + if n == 0 { + break + } + bytesRemaining -= int64(n) + rtn <- wshrpc.RespOrErrorUnion[wshrpc.FileData]{Response: wshrpc.FileData{Data64: base64.StdEncoding.EncodeToString(buf[:n])}} + if bytesRemaining == 0 || errors.Is(err, io.EOF) { + return + } + } + } + } + }() + return rtn } func (c S3Client) ReadTarStream(ctx context.Context, conn *connparse.Connection, opts *wshrpc.FileCopyOpts) <-chan wshrpc.RespOrErrorUnion[iochantypes.Packet] { - return wshutil.SendErrCh[iochantypes.Packet](errors.ErrUnsupported) -} + bucket := conn.Host + if bucket == "" || bucket == "/" { + return wshutil.SendErrCh[iochantypes.Packet](fmt.Errorf("bucket must be specified")) + } -func (c S3Client) ListEntriesStream(ctx context.Context, conn *connparse.Connection, opts *wshrpc.FileListOpts) <-chan wshrpc.RespOrErrorUnion[wshrpc.CommandRemoteListEntriesRtnData] { - ch := make(chan wshrpc.RespOrErrorUnion[wshrpc.CommandRemoteListEntriesRtnData], 16) - go func() { - defer close(ch) - list, err := c.ListEntries(ctx, conn, opts) + // whether the operation is on the whole bucket + wholeBucket := conn.Path == "" || conn.Path == fspath.Separator + + // get the object if it's a single file operation + var singleFileResult *s3.GetObjectOutput + // this ensures we don't leak the object if we error out before copying it + closeSingleFileResult := true + defer func() { + // in case we error out before the object gets copied, make sure to close it + if singleFileResult != nil && closeSingleFileResult { + utilfn.GracefulClose(singleFileResult.Body, "s3fs", conn.Path) + } + }() + var err error + if !wholeBucket { + singleFileResult, err = c.client.GetObject(ctx, &s3.GetObjectInput{ + Bucket: aws.String(bucket), + Key: aws.String(conn.Path), // does not care if the path has a prefixed slash + }) if err != nil { - ch <- wshutil.RespErr[wshrpc.CommandRemoteListEntriesRtnData](err) - return + // if the object doesn't exist, we can assume the prefix is a directory and continue + var noKey *types.NoSuchKey + var notFound *types.NotFound + if !errors.As(err, &noKey) && !errors.As(err, ¬Found) { + return wshutil.SendErrCh[iochantypes.Packet](err) + } } - if list == nil { - ch <- wshrpc.RespOrErrorUnion[wshrpc.CommandRemoteListEntriesRtnData]{Response: wshrpc.CommandRemoteListEntriesRtnData{}} - return + } + + // whether the operation is on a single file + singleFile := singleFileResult != nil + + // whether to include the directory itself in the tar + includeDir := (wholeBucket && conn.Path == "") || (singleFileResult == nil && conn.Path != "" && !strings.HasSuffix(conn.Path, fspath.Separator)) + + timeout := fstype.DefaultTimeout + if opts.Timeout > 0 { + timeout = time.Duration(opts.Timeout) * time.Millisecond + } + readerCtx, cancel := context.WithTimeout(context.Background(), timeout) + + // the prefix that should be removed from the tar paths + tarPathPrefix := conn.Path + + if wholeBucket { + // we treat the bucket name as the root directory. If we're not including the directory itself, we need to remove the bucket name from the tar paths + if includeDir { + tarPathPrefix = "" + } else { + tarPathPrefix = bucket } - for i := 0; i < len(list); i += wshrpc.DirChunkSize { - ch <- wshrpc.RespOrErrorUnion[wshrpc.CommandRemoteListEntriesRtnData]{Response: wshrpc.CommandRemoteListEntriesRtnData{FileInfo: list[i:min(i+wshrpc.DirChunkSize, len(list))]}} + } else if singleFile || includeDir { + // if we're including the directory itself, we need to remove the last part of the path + tarPathPrefix = fsutil.GetParentPathString(tarPathPrefix) + } + + rtn, writeHeader, fileWriter, tarClose := tarcopy.TarCopySrc(readerCtx, tarPathPrefix) + go func() { + defer func() { + tarClose() + cancel() + }() + + // below we get the objects concurrently so we need to store the results in a map + objMap := make(map[string]*s3.GetObjectOutput) + // close the objects when we're done + defer func() { + for key, obj := range objMap { + log.Printf("closing object %v", key) + utilfn.GracefulClose(obj.Body, "s3fs", key) + } + }() + + // tree to keep track of the paths we've added and insert fake directories for subpaths + tree := pathtree.NewTree(tarPathPrefix, "/") + + if singleFile { + objMap[conn.Path] = singleFileResult + tree.Add(conn.Path) + } else { + // list the objects in the bucket and add them to a tree that we can then walk to write the tar entries + var input *s3.ListObjectsV2Input + if wholeBucket { + // get all the objects in the bucket + input = &s3.ListObjectsV2Input{ + Bucket: aws.String(bucket), + } + } else { + objectPrefix := conn.Path + if !strings.HasSuffix(objectPrefix, fspath.Separator) { + objectPrefix = objectPrefix + fspath.Separator + } + input = &s3.ListObjectsV2Input{ + Bucket: aws.String(bucket), + Prefix: aws.String(objectPrefix), + } + } + + errs := make([]error, 0) + // mutex to protect the tree and objMap since we're fetching objects concurrently + treeMapMutex := sync.Mutex{} + // wait group to await the finished fetches + wg := sync.WaitGroup{} + getObjectAndFileInfo := func(obj *types.Object) { + defer wg.Done() + result, err := c.client.GetObject(ctx, &s3.GetObjectInput{ + Bucket: aws.String(bucket), + Key: obj.Key, + }) + if err != nil { + errs = append(errs, err) + return + } + path := *obj.Key + if wholeBucket { + path = fspath.Join(bucket, path) + } + treeMapMutex.Lock() + defer treeMapMutex.Unlock() + objMap[path] = result + tree.Add(path) + } + + if err := c.listFilesPrefix(ctx, input, func(obj *types.Object) (bool, error) { + wg.Add(1) + go getObjectAndFileInfo(obj) + return true, nil + }); err != nil { + rtn <- wshutil.RespErr[iochantypes.Packet](err) + return + } + wg.Wait() + if len(errs) > 0 { + rtn <- wshutil.RespErr[iochantypes.Packet](errors.Join(errs...)) + return + } + } + + // Walk the tree and write the tar entries + if err := tree.Walk(func(path string, numChildren int) error { + mapEntry, isFile := objMap[path] + + // default vals assume entry is dir, since mapEntry might not exist + modTime := int64(time.Now().Unix()) + mode := fstype.DirMode + size := int64(numChildren) + + if isFile { + mode = fstype.FileMode + size = *mapEntry.ContentLength + if mapEntry.LastModified != nil { + modTime = mapEntry.LastModified.UnixMilli() + } + } + + finfo := &wshrpc.FileInfo{ + Name: path, + IsDir: !isFile, + Size: size, + ModTime: modTime, + Mode: mode, + } + if err := writeHeader(fileutil.ToFsFileInfo(finfo), path, singleFile); err != nil { + return err + } + if isFile { + if n, err := io.Copy(fileWriter, mapEntry.Body); err != nil { + return err + } else if n != size { + return fmt.Errorf("error copying %v; expected to read %d bytes, but read %d", path, size, n) + } + } + return nil + }); err != nil { + log.Printf("error walking tree: %v", err) + rtn <- wshutil.RespErr[iochantypes.Packet](err) + return } }() - return ch + // we've handed singleFileResult off to the tar writer, so we don't want to close it + closeSingleFileResult = false + return rtn } func (c S3Client) ListEntries(ctx context.Context, conn *connparse.Connection, opts *wshrpc.FileListOpts) ([]*wshrpc.FileInfo, error) { - if conn.Path == "" || conn.Path == "/" { + var entries []*wshrpc.FileInfo + rtnCh := c.ListEntriesStream(ctx, conn, opts) + for respUnion := range rtnCh { + if respUnion.Error != nil { + return nil, respUnion.Error + } + resp := respUnion.Response + entries = append(entries, resp.FileInfo...) + } + return entries, nil +} + +func (c S3Client) ListEntriesStream(ctx context.Context, conn *connparse.Connection, opts *wshrpc.FileListOpts) <-chan wshrpc.RespOrErrorUnion[wshrpc.CommandRemoteListEntriesRtnData] { + bucket := conn.Host + objectKeyPrefix := conn.Path + if objectKeyPrefix != "" && !strings.HasSuffix(objectKeyPrefix, fspath.Separator) { + objectKeyPrefix = objectKeyPrefix + "/" + } + numToFetch := wshrpc.MaxDirSize + if opts != nil && opts.Limit > 0 { + numToFetch = min(opts.Limit, wshrpc.MaxDirSize) + } + numFetched := 0 + if bucket == "" || bucket == fspath.Separator { buckets, err := awsconn.ListBuckets(ctx, c.client) if err != nil { - return nil, err + return wshutil.SendErrCh[wshrpc.CommandRemoteListEntriesRtnData](err) } var entries []*wshrpc.FileInfo for _, bucket := range buckets { - log.Printf("bucket: %v", *bucket.Name) + if numFetched >= numToFetch { + break + } if bucket.Name != nil { entries = append(entries, &wshrpc.FileInfo{ - Path: *bucket.Name, - IsDir: true, + Path: *bucket.Name, + Name: *bucket.Name, + Dir: fspath.Separator, + ModTime: bucket.CreationDate.UnixMilli(), + IsDir: true, + MimeType: "directory", }) + numFetched++ } } - return entries, nil + rtn := make(chan wshrpc.RespOrErrorUnion[wshrpc.CommandRemoteListEntriesRtnData], 1) + defer close(rtn) + rtn <- wshrpc.RespOrErrorUnion[wshrpc.CommandRemoteListEntriesRtnData]{Response: wshrpc.CommandRemoteListEntriesRtnData{FileInfo: entries}} + return rtn + } else { + rtn := make(chan wshrpc.RespOrErrorUnion[wshrpc.CommandRemoteListEntriesRtnData], 16) + // keep track of "directories" that have been used to avoid duplicates between pages + prevUsedDirKeys := make(map[string]any) + go func() { + defer close(rtn) + entryMap := make(map[string]*wshrpc.FileInfo) + if err := c.listFilesPrefix(ctx, &s3.ListObjectsV2Input{ + Bucket: aws.String(bucket), + Prefix: aws.String(objectKeyPrefix), + }, func(obj *types.Object) (bool, error) { + if numFetched >= numToFetch { + return false, nil + } + lastModTime := int64(0) + if obj.LastModified != nil { + lastModTime = obj.LastModified.UnixMilli() + } + // get the first level directory name or file name + name, isDir := fspath.FirstLevelDir(strings.TrimPrefix(*obj.Key, objectKeyPrefix)) + path := fspath.Join(conn.GetPathWithHost(), name) + if isDir { + if entryMap[name] == nil { + if _, ok := prevUsedDirKeys[name]; !ok { + entryMap[name] = &wshrpc.FileInfo{ + Path: path, + Name: name, + IsDir: true, + Dir: objectKeyPrefix, + ModTime: lastModTime, + Size: 0, + } + fileutil.AddMimeTypeToFileInfo(path, entryMap[name]) + + prevUsedDirKeys[name] = struct{}{} + numFetched++ + } + } else if entryMap[name].ModTime < lastModTime { + entryMap[name].ModTime = lastModTime + } + return true, nil + } + + size := int64(0) + if obj.Size != nil { + size = *obj.Size + } + entryMap[name] = &wshrpc.FileInfo{ + Name: name, + IsDir: false, + Dir: objectKeyPrefix, + Path: path, + ModTime: lastModTime, + Size: size, + } + fileutil.AddMimeTypeToFileInfo(path, entryMap[name]) + numFetched++ + return true, nil + }); err != nil { + rtn <- wshutil.RespErr[wshrpc.CommandRemoteListEntriesRtnData](err) + return + } + parentPath := fsutil.GetParentPath(conn) + if parentPath != "" { + rtn <- wshrpc.RespOrErrorUnion[wshrpc.CommandRemoteListEntriesRtnData]{Response: wshrpc.CommandRemoteListEntriesRtnData{FileInfo: []*wshrpc.FileInfo{ + { + Path: parentPath, + Dir: fsutil.GetParentPathString(parentPath), + Name: "..", + IsDir: true, + Size: 0, + ModTime: time.Now().Unix(), + MimeType: "directory", + }, + }}} + } + entries := make([]*wshrpc.FileInfo, 0, wshrpc.DirChunkSize) + for _, entry := range entryMap { + entries = append(entries, entry) + if len(entries) == wshrpc.DirChunkSize { + rtn <- wshrpc.RespOrErrorUnion[wshrpc.CommandRemoteListEntriesRtnData]{Response: wshrpc.CommandRemoteListEntriesRtnData{FileInfo: entries}} + entries = make([]*wshrpc.FileInfo, 0, wshrpc.DirChunkSize) + } + } + if len(entries) > 0 { + rtn <- wshrpc.RespOrErrorUnion[wshrpc.CommandRemoteListEntriesRtnData]{Response: wshrpc.CommandRemoteListEntriesRtnData{FileInfo: entries}} + } + }() + return rtn } - return nil, nil } func (c S3Client) Stat(ctx context.Context, conn *connparse.Connection) (*wshrpc.FileInfo, error) { - return nil, errors.ErrUnsupported + log.Printf("Stat: %v", conn.GetFullURI()) + bucketName := conn.Host + objectKey := conn.Path + if bucketName == "" || bucketName == fspath.Separator { + // root, refers to list all buckets + return &wshrpc.FileInfo{ + Name: fspath.Separator, + IsDir: true, + Size: 0, + ModTime: 0, + Path: fspath.Separator, + Dir: fspath.Separator, + MimeType: "directory", + }, nil + } + if objectKey == "" || objectKey == fspath.Separator { + _, err := c.client.HeadBucket(ctx, &s3.HeadBucketInput{ + Bucket: aws.String(bucketName), + }) + exists := true + if err != nil { + var apiError smithy.APIError + if errors.As(err, &apiError) { + switch apiError.(type) { + case *types.NotFound: + exists = false + default: + } + } + } + + if exists { + return &wshrpc.FileInfo{ + Name: bucketName, + Path: bucketName, + Dir: fspath.Separator, + IsDir: true, + Size: 0, + ModTime: 0, + MimeType: "directory", + }, nil + } else { + return &wshrpc.FileInfo{ + Name: bucketName, + Path: bucketName, + Dir: fspath.Separator, + NotFound: true, + }, nil + } + } + result, err := c.client.GetObjectAttributes(ctx, &s3.GetObjectAttributesInput{ + Bucket: aws.String(bucketName), + Key: aws.String(objectKey), + ObjectAttributes: []types.ObjectAttributes{ + types.ObjectAttributesObjectSize, + }, + }) + if err != nil { + var noKey *types.NoSuchKey + var notFound *types.NotFound + if errors.As(err, &noKey) || errors.As(err, ¬Found) { + // try to list a single object to see if the prefix exists + if !strings.HasSuffix(objectKey, fspath.Separator) { + objectKey += fspath.Separator + } + entries, err := c.client.ListObjectsV2(ctx, &s3.ListObjectsV2Input{ + Bucket: aws.String(bucketName), + Prefix: aws.String(objectKey), + MaxKeys: aws.Int32(1), + }) + if err == nil { + if entries.Contents != nil && len(entries.Contents) > 0 { + return &wshrpc.FileInfo{ + Name: objectKey, + Path: conn.GetPathWithHost(), + Dir: fsutil.GetParentPath(conn), + IsDir: true, + Size: 0, + Mode: fstype.DirMode, + MimeType: "directory", + }, nil + } + } else if !errors.As(err, &noKey) && !errors.As(err, ¬Found) { + return nil, err + } + + return &wshrpc.FileInfo{ + Name: objectKey, + Path: conn.GetPathWithHost(), + Dir: fsutil.GetParentPath(conn), + NotFound: true, + }, nil + } + return nil, err + } + size := int64(0) + if result.ObjectSize != nil { + size = *result.ObjectSize + } + lastModified := int64(0) + if result.LastModified != nil { + lastModified = result.LastModified.UnixMilli() + } + rtn := &wshrpc.FileInfo{ + Name: objectKey, + Path: conn.GetPathWithHost(), + Dir: fsutil.GetParentPath(conn), + IsDir: false, + Size: size, + ModTime: lastModified, + } + fileutil.AddMimeTypeToFileInfo(rtn.Path, rtn) + return rtn, nil } func (c S3Client) PutFile(ctx context.Context, conn *connparse.Connection, data wshrpc.FileData) error { - return errors.ErrUnsupported + log.Printf("PutFile: %v", conn.GetFullURI()) + if data.At != nil { + log.Printf("PutFile: offset %d and size %d", data.At.Offset, data.At.Size) + return errors.Join(errors.ErrUnsupported, fmt.Errorf("file data offset and size not supported")) + } + bucket := conn.Host + objectKey := conn.Path + if bucket == "" || bucket == "/" || objectKey == "" || objectKey == "/" { + log.Printf("PutFile: bucket and object key must be specified") + return errors.Join(errors.ErrUnsupported, fmt.Errorf("bucket and object key must be specified")) + } + contentMaxLength := base64.StdEncoding.DecodedLen(len(data.Data64)) + var decodedBody []byte + var contentLength int + var err error + if contentMaxLength > 0 { + decodedBody = make([]byte, contentMaxLength) + contentLength, err = base64.StdEncoding.Decode(decodedBody, []byte(data.Data64)) + if err != nil { + log.Printf("PutFile: error decoding data: %v", err) + return err + } + } else { + decodedBody = []byte("\n") + contentLength = 1 + } + bodyReaderSeeker := bytes.NewReader(decodedBody[:contentLength]) + _, err = c.client.PutObject(ctx, &s3.PutObjectInput{ + Bucket: aws.String(bucket), + Key: aws.String(objectKey), + Body: bodyReaderSeeker, + ContentLength: aws.Int64(int64(contentLength)), + }) + if err != nil { + log.Printf("PutFile: error putting object %v:%v: %v", bucket, objectKey, err) + } + return err } func (c S3Client) AppendFile(ctx context.Context, conn *connparse.Connection, data wshrpc.FileData) error { - return errors.ErrUnsupported + return errors.Join(errors.ErrUnsupported, fmt.Errorf("append file not supported")) } func (c S3Client) Mkdir(ctx context.Context, conn *connparse.Connection) error { - return errors.ErrUnsupported + return errors.Join(errors.ErrUnsupported, fmt.Errorf("mkdir not supported")) } func (c S3Client) MoveInternal(ctx context.Context, srcConn, destConn *connparse.Connection, opts *wshrpc.FileCopyOpts) error { - return errors.ErrUnsupported + err := c.CopyInternal(ctx, srcConn, destConn, opts) + if err != nil { + return err + } + return c.Delete(ctx, srcConn, true) } func (c S3Client) CopyRemote(ctx context.Context, srcConn, destConn *connparse.Connection, srcClient fstype.FileShareClient, opts *wshrpc.FileCopyOpts) error { - return errors.ErrUnsupported + if srcConn.Scheme == connparse.ConnectionTypeS3 && destConn.Scheme == connparse.ConnectionTypeS3 { + return c.CopyInternal(ctx, srcConn, destConn, opts) + } + destBucket := destConn.Host + if destBucket == "" || destBucket == fspath.Separator { + return fmt.Errorf("destination bucket must be specified") + } + return fsutil.PrefixCopyRemote(ctx, srcConn, destConn, srcClient, c, func(bucket, path string, size int64, reader io.Reader) error { + _, err := c.client.PutObject(ctx, &s3.PutObjectInput{ + Bucket: aws.String(bucket), + Key: aws.String(path), + Body: reader, + ContentLength: aws.Int64(size), + }) + return err + }, opts) } func (c S3Client) CopyInternal(ctx context.Context, srcConn, destConn *connparse.Connection, opts *wshrpc.FileCopyOpts) error { - return errors.ErrUnsupported + srcBucket := srcConn.Host + destBucket := destConn.Host + if srcBucket == "" || srcBucket == fspath.Separator || destBucket == "" || destBucket == fspath.Separator { + return fmt.Errorf("source and destination bucket must be specified") + } + return fsutil.PrefixCopyInternal(ctx, srcConn, destConn, c, opts, func(ctx context.Context, bucket, prefix string) ([]string, error) { + var entries []string + err := c.listFilesPrefix(ctx, &s3.ListObjectsV2Input{ + Bucket: aws.String(bucket), + Prefix: aws.String(prefix), + }, func(obj *types.Object) (bool, error) { + entries = append(entries, *obj.Key) + return true, nil + }) + return entries, err + }, func(ctx context.Context, srcPath, destPath string) error { + log.Printf("Copying file %v -> %v", srcBucket+"/"+srcPath, destBucket+"/"+destPath) + _, err := c.client.CopyObject(ctx, &s3.CopyObjectInput{ + Bucket: aws.String(destBucket), + Key: aws.String(destPath), + CopySource: aws.String(fspath.Join(srcBucket, srcPath)), + }) + return err + }) +} + +func (c S3Client) listFilesPrefix(ctx context.Context, input *s3.ListObjectsV2Input, fileCallback func(*types.Object) (bool, error)) error { + var err error + var output *s3.ListObjectsV2Output + objectPaginator := s3.NewListObjectsV2Paginator(c.client, input) + for objectPaginator.HasMorePages() { + output, err = objectPaginator.NextPage(ctx) + if err != nil { + var noBucket *types.NoSuchBucket + if !awsconn.CheckAccessDeniedErr(&err) && errors.As(err, &noBucket) { + err = noBucket + } + return err + } else { + for _, obj := range output.Contents { + if cont, err := fileCallback(&obj); err != nil { + return err + } else if !cont { + return nil + } + } + } + } + return nil } func (c S3Client) Delete(ctx context.Context, conn *connparse.Connection, recursive bool) error { - return errors.ErrUnsupported + bucket := conn.Host + objectKey := conn.Path + if bucket == "" || bucket == fspath.Separator { + return errors.Join(errors.ErrUnsupported, fmt.Errorf("bucket must be specified")) + } + if objectKey == "" || objectKey == fspath.Separator { + return errors.Join(errors.ErrUnsupported, fmt.Errorf("object key must be specified")) + } + if recursive { + if !strings.HasSuffix(objectKey, fspath.Separator) { + objectKey = objectKey + fspath.Separator + } + entries, err := c.client.ListObjectsV2(ctx, &s3.ListObjectsV2Input{ + Bucket: aws.String(bucket), + Prefix: aws.String(objectKey), + }) + if err != nil { + return err + } + if len(entries.Contents) == 0 { + return nil + } + objects := make([]types.ObjectIdentifier, 0, len(entries.Contents)) + for _, obj := range entries.Contents { + objects = append(objects, types.ObjectIdentifier{Key: obj.Key}) + } + _, err = c.client.DeleteObjects(ctx, &s3.DeleteObjectsInput{ + Bucket: aws.String(bucket), + Delete: &types.Delete{ + Objects: objects, + }, + }) + return err + } + _, err := c.client.DeleteObject(ctx, &s3.DeleteObjectInput{ + Bucket: aws.String(bucket), + Key: aws.String(objectKey), + }) + return err } -func (c S3Client) Join(ctx context.Context, conn *connparse.Connection, parts ...string) (string, error) { - return "", errors.ErrUnsupported +func (c S3Client) Join(ctx context.Context, conn *connparse.Connection, parts ...string) (*wshrpc.FileInfo, error) { + var joinParts []string + if conn.Host == "" || conn.Host == fspath.Separator { + if conn.Path == "" || conn.Path == fspath.Separator { + joinParts = parts + } else { + joinParts = append([]string{conn.Path}, parts...) + } + } else if conn.Path == "" || conn.Path == "/" { + joinParts = append([]string{conn.Host}, parts...) + } else { + joinParts = append([]string{conn.Host, conn.Path}, parts...) + } + + conn.Path = fspath.Join(joinParts...) + + return c.Stat(ctx, conn) } func (c S3Client) GetConnectionType() string { return connparse.ConnectionTypeS3 } + +func (c S3Client) GetCapability() wshrpc.FileShareCapability { + return wshrpc.FileShareCapability{ + CanAppend: false, + CanMkdir: false, + } +} diff --git a/pkg/remote/fileshare/wavefs/wavefs.go b/pkg/remote/fileshare/wavefs/wavefs.go index 63cbe36a1d..b30c4bad39 100644 --- a/pkg/remote/fileshare/wavefs/wavefs.go +++ b/pkg/remote/fileshare/wavefs/wavefs.go @@ -4,7 +4,6 @@ package wavefs import ( - "archive/tar" "context" "encoding/base64" "errors" @@ -12,13 +11,16 @@ import ( "io" "io/fs" "log" - "path" + "os" + "path/filepath" "strings" "time" "github.com/wavetermdev/waveterm/pkg/filestore" "github.com/wavetermdev/waveterm/pkg/remote/connparse" + "github.com/wavetermdev/waveterm/pkg/remote/fileshare/fspath" "github.com/wavetermdev/waveterm/pkg/remote/fileshare/fstype" + "github.com/wavetermdev/waveterm/pkg/remote/fileshare/fsutil" "github.com/wavetermdev/waveterm/pkg/util/fileutil" "github.com/wavetermdev/waveterm/pkg/util/iochan/iochantypes" "github.com/wavetermdev/waveterm/pkg/util/tarcopy" @@ -30,7 +32,7 @@ import ( ) const ( - DefaultTimeout = 30 * time.Second + DirMode os.FileMode = 0755 | os.ModeDir ) type WaveClient struct{} @@ -54,7 +56,7 @@ func (c WaveClient) ReadStream(ctx context.Context, conn *connparse.Connection, if !rtnData.Info.IsDir { for i := 0; i < dataLen; i += wshrpc.FileChunkSize { if ctx.Err() != nil { - ch <- wshutil.RespErr[wshrpc.FileData](ctx.Err()) + ch <- wshutil.RespErr[wshrpc.FileData](context.Cause(ctx)) return } dataEnd := min(i+wshrpc.FileChunkSize, dataLen) @@ -63,7 +65,7 @@ func (c WaveClient) ReadStream(ctx context.Context, conn *connparse.Connection, } else { for i := 0; i < len(rtnData.Entries); i += wshrpc.DirChunkSize { if ctx.Err() != nil { - ch <- wshutil.RespErr[wshrpc.FileData](ctx.Err()) + ch <- wshutil.RespErr[wshrpc.FileData](context.Cause(ctx)) return } ch <- wshrpc.RespOrErrorUnion[wshrpc.FileData]{Response: wshrpc.FileData{Entries: rtnData.Entries[i:min(i+wshrpc.DirChunkSize, len(rtnData.Entries))], Info: rtnData.Info}} @@ -108,15 +110,42 @@ func (c WaveClient) Read(ctx context.Context, conn *connparse.Connection, data w func (c WaveClient) ReadTarStream(ctx context.Context, conn *connparse.Connection, opts *wshrpc.FileCopyOpts) <-chan wshrpc.RespOrErrorUnion[iochantypes.Packet] { log.Printf("ReadTarStream: conn: %v, opts: %v\n", conn, opts) - list, err := c.ListEntries(ctx, conn, nil) + path := conn.Path + srcHasSlash := strings.HasSuffix(path, "/") + cleanedPath, err := cleanPath(path) + if err != nil { + return wshutil.SendErrCh[iochantypes.Packet](fmt.Errorf("error cleaning path: %w", err)) + } + + finfo, err := c.Stat(ctx, conn) + exists := err == nil && !finfo.NotFound if err != nil { - return wshutil.SendErrCh[iochantypes.Packet](fmt.Errorf("error listing blockfiles: %w", err)) + return wshutil.SendErrCh[iochantypes.Packet](fmt.Errorf("error getting file info: %w", err)) + } + if !exists { + return wshutil.SendErrCh[iochantypes.Packet](fmt.Errorf("file not found: %s", conn.GetFullURI())) + } + singleFile := finfo != nil && !finfo.IsDir + var pathPrefix string + if !singleFile && srcHasSlash { + pathPrefix = cleanedPath + } else { + pathPrefix = filepath.Dir(cleanedPath) } - pathPrefix := getPathPrefix(conn) schemeAndHost := conn.GetSchemeAndHost() + "/" - timeout := DefaultTimeout + var entries []*wshrpc.FileInfo + if singleFile { + entries = []*wshrpc.FileInfo{finfo} + } else { + entries, err = c.ListEntries(ctx, conn, nil) + if err != nil { + return wshutil.SendErrCh[iochantypes.Packet](fmt.Errorf("error listing blockfiles: %w", err)) + } + } + + timeout := fstype.DefaultTimeout if opts.Timeout > 0 { timeout = time.Duration(opts.Timeout) * time.Millisecond } @@ -128,14 +157,14 @@ func (c WaveClient) ReadTarStream(ctx context.Context, conn *connparse.Connectio tarClose() cancel() }() - for _, file := range list { + for _, file := range entries { if readerCtx.Err() != nil { - rtn <- wshutil.RespErr[iochantypes.Packet](readerCtx.Err()) + rtn <- wshutil.RespErr[iochantypes.Packet](context.Cause(readerCtx)) return } file.Mode = 0644 - if err = writeHeader(fileutil.ToFsFileInfo(file), file.Path); err != nil { + if err = writeHeader(fileutil.ToFsFileInfo(file), file.Path, singleFile); err != nil { rtn <- wshutil.RespErr[iochantypes.Packet](fmt.Errorf("error writing tar header: %w", err)) return } @@ -191,50 +220,37 @@ func (c WaveClient) ListEntries(ctx context.Context, conn *connparse.Connection, if err != nil { return nil, fmt.Errorf("error cleaning path: %w", err) } - fileListOrig, err := filestore.WFS.ListFiles(ctx, zoneId) - if err != nil { - return nil, fmt.Errorf("error listing blockfiles: %w", err) - } + prefix += fspath.Separator var fileList []*wshrpc.FileInfo - for _, wf := range fileListOrig { - fileList = append(fileList, wavefileutil.WaveFileToFileInfo(wf)) - } - if prefix != "" { - var filteredList []*wshrpc.FileInfo - for _, file := range fileList { - if strings.HasPrefix(file.Name, prefix) { - filteredList = append(filteredList, file) - } - } - fileList = filteredList - } - if !opts.All { - var filteredList []*wshrpc.FileInfo - dirMap := make(map[string]any) // the value is max modtime - for _, file := range fileList { - // if there is an extra "/" after the prefix, don't include it - // first strip the prefix - relPath := strings.TrimPrefix(file.Name, prefix) - // then check if there is a "/" after the prefix - if strings.Contains(relPath, "/") { - dirPath := strings.Split(relPath, "/")[0] - dirMap[dirPath] = struct{}{} - continue + dirMap := make(map[string]*wshrpc.FileInfo) + if err := listFilesPrefix(ctx, zoneId, prefix, func(wf *filestore.WaveFile) error { + if !opts.All { + name, isDir := fspath.FirstLevelDir(strings.TrimPrefix(wf.Name, prefix)) + if isDir { + path := fspath.Join(conn.GetPathWithHost(), name) + if _, ok := dirMap[path]; ok { + if dirMap[path].ModTime < wf.ModTs { + dirMap[path].ModTime = wf.ModTs + } + return nil + } + dirMap[path] = &wshrpc.FileInfo{ + Path: path, + Name: name, + Dir: fspath.Dir(path), + Size: 0, + IsDir: true, + SupportsMkdir: false, + Mode: DirMode, + } + fileList = append(fileList, dirMap[path]) + return nil } - filteredList = append(filteredList, file) - } - for dir := range dirMap { - dirName := prefix + dir + "/" - filteredList = append(filteredList, &wshrpc.FileInfo{ - Path: fmt.Sprintf(wavefileutil.WaveFilePathPattern, zoneId, dirName), - Name: dirName, - Dir: dirName, - Size: 0, - IsDir: true, - SupportsMkdir: false, - }) } - fileList = filteredList + fileList = append(fileList, wavefileutil.WaveFileToFileInfo(wf)) + return nil + }); err != nil { + return nil, fmt.Errorf("error listing entries: %w", err) } if opts.Offset > 0 { if opts.Offset >= len(fileList) { @@ -256,14 +272,34 @@ func (c WaveClient) Stat(ctx context.Context, conn *connparse.Connection) (*wshr if zoneId == "" { return nil, fmt.Errorf("zoneid not found in connection") } - fileName, err := cleanPath(conn.Path) + fileName, err := fsutil.CleanPathPrefix(conn.Path) if err != nil { return nil, fmt.Errorf("error cleaning path: %w", err) } fileInfo, err := filestore.WFS.Stat(ctx, zoneId, fileName) if err != nil { if errors.Is(err, fs.ErrNotExist) { - return nil, fmt.Errorf("NOTFOUND: %w", err) + // attempt to list the directory + entries, err := c.ListEntries(ctx, conn, nil) + if err != nil { + return nil, fmt.Errorf("error listing entries: %w", err) + } + if len(entries) > 0 { + return &wshrpc.FileInfo{ + Path: conn.GetPathWithHost(), + Name: fileName, + Dir: fsutil.GetParentPathString(fileName), + Size: 0, + IsDir: true, + Mode: DirMode, + }, nil + } else { + return &wshrpc.FileInfo{ + Path: conn.GetPathWithHost(), + Name: fileName, + Dir: fsutil.GetParentPathString(fileName), + NotFound: true}, nil + } } return nil, fmt.Errorf("error getting file info: %w", err) } @@ -283,8 +319,7 @@ func (c WaveClient) PutFile(ctx context.Context, conn *connparse.Connection, dat if err != nil { return fmt.Errorf("error cleaning path: %w", err) } - _, err = filestore.WFS.Stat(ctx, zoneId, fileName) - if err != nil { + if _, err := filestore.WFS.Stat(ctx, zoneId, fileName); err != nil { if !errors.Is(err, fs.ErrNotExist) { return fmt.Errorf("error getting blockfile info: %w", err) } @@ -298,25 +333,20 @@ func (c WaveClient) PutFile(ctx context.Context, conn *connparse.Connection, dat meta = *data.Info.Meta } } - err := filestore.WFS.MakeFile(ctx, zoneId, fileName, meta, opts) - if err != nil { + if err := filestore.WFS.MakeFile(ctx, zoneId, fileName, meta, opts); err != nil { return fmt.Errorf("error making blockfile: %w", err) } } if data.At != nil && data.At.Offset >= 0 { - err = filestore.WFS.WriteAt(ctx, zoneId, fileName, data.At.Offset, dataBuf) - if errors.Is(err, fs.ErrNotExist) { + if err := filestore.WFS.WriteAt(ctx, zoneId, fileName, data.At.Offset, dataBuf); errors.Is(err, fs.ErrNotExist) { return fmt.Errorf("NOTFOUND: %w", err) - } - if err != nil { + } else if err != nil { return fmt.Errorf("error writing to blockfile: %w", err) } } else { - err = filestore.WFS.WriteFile(ctx, zoneId, fileName, dataBuf) - if errors.Is(err, fs.ErrNotExist) { + if err := filestore.WFS.WriteFile(ctx, zoneId, fileName, dataBuf); errors.Is(err, fs.ErrNotExist) { return fmt.Errorf("NOTFOUND: %w", err) - } - if err != nil { + } else if err != nil { return fmt.Errorf("error writing to blockfile: %w", err) } } @@ -360,8 +390,7 @@ func (c WaveClient) AppendFile(ctx context.Context, conn *connparse.Connection, meta = *data.Info.Meta } } - err := filestore.WFS.MakeFile(ctx, zoneId, fileName, meta, opts) - if err != nil { + if err := filestore.WFS.MakeFile(ctx, zoneId, fileName, meta, opts); err != nil { return fmt.Errorf("error making blockfile: %w", err) } } @@ -393,93 +422,76 @@ func (c WaveClient) MoveInternal(ctx context.Context, srcConn, destConn *connpar if srcConn.Host != destConn.Host { return fmt.Errorf("move internal, src and dest hosts do not match") } - err := c.CopyInternal(ctx, srcConn, destConn, opts) - if err != nil { + if err := c.CopyInternal(ctx, srcConn, destConn, opts); err != nil { return fmt.Errorf("error copying blockfile: %w", err) } - err = c.Delete(ctx, srcConn, opts.Recursive) - if err != nil { + if err := c.Delete(ctx, srcConn, opts.Recursive); err != nil { return fmt.Errorf("error deleting blockfile: %w", err) } return nil } func (c WaveClient) CopyInternal(ctx context.Context, srcConn, destConn *connparse.Connection, opts *wshrpc.FileCopyOpts) error { - if srcConn.Host == destConn.Host { - host := srcConn.Host - srcFileName, err := cleanPath(srcConn.Path) - if err != nil { - return fmt.Errorf("error cleaning source path: %w", err) - } - destFileName, err := cleanPath(destConn.Path) - if err != nil { - return fmt.Errorf("error cleaning destination path: %w", err) - } - err = filestore.WFS.MakeFile(ctx, host, destFileName, wshrpc.FileMeta{}, wshrpc.FileOpts{}) - if err != nil { - return fmt.Errorf("error making source blockfile: %w", err) - } - _, dataBuf, err := filestore.WFS.ReadFile(ctx, host, srcFileName) + return fsutil.PrefixCopyInternal(ctx, srcConn, destConn, c, opts, func(ctx context.Context, zoneId, prefix string) ([]string, error) { + entryList := make([]string, 0) + if err := listFilesPrefix(ctx, zoneId, prefix, func(wf *filestore.WaveFile) error { + entryList = append(entryList, wf.Name) + return nil + }); err != nil { + return nil, err + } + return entryList, nil + }, func(ctx context.Context, srcPath, destPath string) error { + srcHost := srcConn.Host + srcFileName := strings.TrimPrefix(srcPath, srcHost+fspath.Separator) + destHost := destConn.Host + destFileName := strings.TrimPrefix(destPath, destHost+fspath.Separator) + _, dataBuf, err := filestore.WFS.ReadFile(ctx, srcHost, srcFileName) if err != nil { return fmt.Errorf("error reading source blockfile: %w", err) } - err = filestore.WFS.WriteFile(ctx, host, destFileName, dataBuf) - if err != nil { + if err := filestore.WFS.WriteFile(ctx, destHost, destFileName, dataBuf); err != nil { return fmt.Errorf("error writing to destination blockfile: %w", err) } wps.Broker.Publish(wps.WaveEvent{ Event: wps.Event_BlockFile, - Scopes: []string{waveobj.MakeORef(waveobj.OType_Block, host).String()}, + Scopes: []string{waveobj.MakeORef(waveobj.OType_Block, destHost).String()}, Data: &wps.WSFileEventData{ - ZoneId: host, + ZoneId: destHost, FileName: destFileName, FileOp: wps.FileOp_Invalidate, }, }) return nil - } else { - return fmt.Errorf("copy between different hosts not supported") - } + }) } func (c WaveClient) CopyRemote(ctx context.Context, srcConn, destConn *connparse.Connection, srcClient fstype.FileShareClient, opts *wshrpc.FileCopyOpts) error { + if srcConn.Scheme == connparse.ConnectionTypeWave && destConn.Scheme == connparse.ConnectionTypeWave { + return c.CopyInternal(ctx, srcConn, destConn, opts) + } zoneId := destConn.Host if zoneId == "" { return fmt.Errorf("zoneid not found in connection") } - destPrefix := getPathPrefix(destConn) - destPrefix = strings.TrimPrefix(destPrefix, destConn.GetSchemeAndHost()+"/") - log.Printf("CopyRemote: srcConn: %v, destConn: %v, destPrefix: %s\n", srcConn, destConn, destPrefix) - readCtx, cancel := context.WithCancelCause(ctx) - ioch := srcClient.ReadTarStream(readCtx, srcConn, opts) - err := tarcopy.TarCopyDest(readCtx, cancel, ioch, func(next *tar.Header, reader *tar.Reader) error { - if next.Typeflag == tar.TypeDir { - return nil - } - fileName, err := cleanPath(path.Join(destPrefix, next.Name)) - if err != nil { - return fmt.Errorf("error cleaning path: %w", err) + return fsutil.PrefixCopyRemote(ctx, srcConn, destConn, srcClient, c, func(zoneId, path string, size int64, reader io.Reader) error { + dataBuf := make([]byte, size) + if _, err := reader.Read(dataBuf); err != nil { + if !errors.Is(err, io.EOF) { + return fmt.Errorf("error reading tar data: %w", err) + } } - _, err = filestore.WFS.Stat(ctx, zoneId, fileName) - if err != nil { + if _, err := filestore.WFS.Stat(ctx, zoneId, path); err != nil { if !errors.Is(err, fs.ErrNotExist) { return fmt.Errorf("error getting blockfile info: %w", err) - } - err := filestore.WFS.MakeFile(ctx, zoneId, fileName, nil, wshrpc.FileOpts{}) - if err != nil { - return fmt.Errorf("error making blockfile: %w", err) - } - } - log.Printf("CopyRemote: writing file: %s; size: %d\n", fileName, next.Size) - dataBuf := make([]byte, next.Size) - _, err = reader.Read(dataBuf) - if err != nil { - if !errors.Is(err, io.EOF) { - return fmt.Errorf("error reading tar data: %w", err) + } else { + if err := filestore.WFS.MakeFile(ctx, zoneId, path, wshrpc.FileMeta{}, wshrpc.FileOpts{}); err != nil { + return fmt.Errorf("error making blockfile: %w", err) + } } } - err = filestore.WFS.WriteFile(ctx, zoneId, fileName, dataBuf) - if err != nil { + + if err := filestore.WFS.WriteFile(ctx, zoneId, path, dataBuf); err != nil { return fmt.Errorf("error writing to blockfile: %w", err) } wps.Broker.Publish(wps.WaveEvent{ @@ -487,16 +499,12 @@ func (c WaveClient) CopyRemote(ctx context.Context, srcConn, destConn *connparse Scopes: []string{waveobj.MakeORef(waveobj.OType_Block, zoneId).String()}, Data: &wps.WSFileEventData{ ZoneId: zoneId, - FileName: fileName, + FileName: path, FileOp: wps.FileOp_Invalidate, }, }) return nil - }) - if err != nil { - return fmt.Errorf("error copying tar stream: %w", err) - } - return nil + }, opts) } func (c WaveClient) Delete(ctx context.Context, conn *connparse.Connection, recursive bool) error { @@ -504,22 +512,40 @@ func (c WaveClient) Delete(ctx context.Context, conn *connparse.Connection, recu if zoneId == "" { return fmt.Errorf("zoneid not found in connection") } - schemeAndHost := conn.GetSchemeAndHost() + "/" + prefix := conn.Path - entries, err := c.ListEntries(ctx, conn, nil) - if err != nil { - return fmt.Errorf("error listing blockfiles: %w", err) + finfo, err := c.Stat(ctx, conn) + exists := err == nil && !finfo.NotFound + if err != nil && !errors.Is(err, fs.ErrNotExist) { + return fmt.Errorf("error getting file info: %w", err) } - if len(entries) > 0 { + if !exists { + return nil + } + + pathsToDelete := make([]string, 0) + + if finfo.IsDir { if !recursive { - return fmt.Errorf("more than one entry, use recursive flag to delete") + return fmt.Errorf("%v is not empty, use recursive flag to delete", prefix) + } + if !strings.HasSuffix(prefix, fspath.Separator) { + prefix += fspath.Separator + } + if err := listFilesPrefix(ctx, zoneId, prefix, func(wf *filestore.WaveFile) error { + pathsToDelete = append(pathsToDelete, wf.Name) + return nil + }); err != nil { + return fmt.Errorf("error listing blockfiles: %w", err) } + } else { + pathsToDelete = append(pathsToDelete, prefix) + } + if len(pathsToDelete) > 0 { errs := make([]error, 0) - for _, entry := range entries { - fileName := strings.TrimPrefix(entry.Path, schemeAndHost) - err = filestore.WFS.DeleteFile(ctx, zoneId, fileName) - if err != nil { - errs = append(errs, fmt.Errorf("error deleting blockfile %s/%s: %w", zoneId, fileName, err)) + for _, entry := range pathsToDelete { + if err := filestore.WFS.DeleteFile(ctx, zoneId, entry); err != nil { + errs = append(errs, fmt.Errorf("error deleting blockfile %s/%s: %w", zoneId, entry, err)) continue } wps.Broker.Publish(wps.WaveEvent{ @@ -527,7 +553,7 @@ func (c WaveClient) Delete(ctx context.Context, conn *connparse.Connection, recu Scopes: []string{waveobj.MakeORef(waveobj.OType_Block, zoneId).String()}, Data: &wps.WSFileEventData{ ZoneId: zoneId, - FileName: fileName, + FileName: entry, FileOp: wps.FileOp_Delete, }, }) @@ -539,27 +565,51 @@ func (c WaveClient) Delete(ctx context.Context, conn *connparse.Connection, recu return nil } -func (c WaveClient) Join(ctx context.Context, conn *connparse.Connection, parts ...string) (string, error) { - newPath := path.Join(append([]string{conn.Path}, parts...)...) +func listFilesPrefix(ctx context.Context, zoneId, prefix string, entryCallback func(*filestore.WaveFile) error) error { + if zoneId == "" { + return fmt.Errorf("zoneid not found in connection") + } + fileListOrig, err := filestore.WFS.ListFiles(ctx, zoneId) + if err != nil { + return fmt.Errorf("error listing blockfiles: %w", err) + } + for _, wf := range fileListOrig { + if prefix == "" || strings.HasPrefix(wf.Name, prefix) { + entryCallback(wf) + } + } + return nil +} + +func (c WaveClient) Join(ctx context.Context, conn *connparse.Connection, parts ...string) (*wshrpc.FileInfo, error) { + newPath := fspath.Join(append([]string{conn.Path}, parts...)...) newPath, err := cleanPath(newPath) if err != nil { - return "", fmt.Errorf("error cleaning path: %w", err) + return nil, fmt.Errorf("error cleaning path: %w", err) + } + conn.Path = newPath + return c.Stat(ctx, conn) +} + +func (c WaveClient) GetCapability() wshrpc.FileShareCapability { + return wshrpc.FileShareCapability{ + CanAppend: true, + CanMkdir: false, } - return newPath, nil } func cleanPath(path string) (string, error) { - if path == "" { - return "", fmt.Errorf("path is empty") + if path == "" || path == fspath.Separator { + return "", nil } - if strings.HasPrefix(path, "/") { + if strings.HasPrefix(path, fspath.Separator) { path = path[1:] } if strings.HasPrefix(path, "~") || strings.HasPrefix(path, ".") || strings.HasPrefix(path, "..") { return "", fmt.Errorf("wavefile path cannot start with ~, ., or ..") } var newParts []string - for _, part := range strings.Split(path, "/") { + for _, part := range strings.Split(path, fspath.Separator) { if part == ".." { if len(newParts) > 0 { newParts = newParts[:len(newParts)-1] @@ -568,19 +618,9 @@ func cleanPath(path string) (string, error) { newParts = append(newParts, part) } } - return strings.Join(newParts, "/"), nil + return fspath.Join(newParts...), nil } func (c WaveClient) GetConnectionType() string { return connparse.ConnectionTypeWave } - -func getPathPrefix(conn *connparse.Connection) string { - fullUri := conn.GetFullURI() - pathPrefix := fullUri - lastSlash := strings.LastIndex(fullUri, "/") - if lastSlash > 10 && lastSlash < len(fullUri)-1 { - pathPrefix = fullUri[:lastSlash+1] - } - return pathPrefix -} diff --git a/pkg/remote/fileshare/wshfs/wshfs.go b/pkg/remote/fileshare/wshfs/wshfs.go index 61816ea576..ae0930e864 100644 --- a/pkg/remote/fileshare/wshfs/wshfs.go +++ b/pkg/remote/fileshare/wshfs/wshfs.go @@ -4,24 +4,18 @@ package wshfs import ( - "bytes" "context" - "encoding/base64" "fmt" - "io" "github.com/wavetermdev/waveterm/pkg/remote/connparse" "github.com/wavetermdev/waveterm/pkg/remote/fileshare/fstype" + "github.com/wavetermdev/waveterm/pkg/remote/fileshare/fsutil" "github.com/wavetermdev/waveterm/pkg/util/iochan/iochantypes" "github.com/wavetermdev/waveterm/pkg/wshrpc" "github.com/wavetermdev/waveterm/pkg/wshrpc/wshclient" "github.com/wavetermdev/waveterm/pkg/wshutil" ) -const ( - ThirtySeconds = 30 * 1000 -) - // This needs to be set by whoever initializes the client, either main-server or wshcmd-connserver var RpcClient *wshutil.WshRpc @@ -35,47 +29,7 @@ func NewWshClient() *WshClient { func (c WshClient) Read(ctx context.Context, conn *connparse.Connection, data wshrpc.FileData) (*wshrpc.FileData, error) { rtnCh := c.ReadStream(ctx, conn, data) - var fileData *wshrpc.FileData - firstPk := true - isDir := false - var fileBuf bytes.Buffer - for respUnion := range rtnCh { - if respUnion.Error != nil { - return nil, respUnion.Error - } - resp := respUnion.Response - if firstPk { - firstPk = false - // first packet has the fileinfo - if resp.Info == nil { - return nil, fmt.Errorf("stream file protocol error, first pk fileinfo is empty") - } - fileData = &resp - if fileData.Info.IsDir { - isDir = true - } - continue - } - if isDir { - if len(resp.Entries) == 0 { - continue - } - fileData.Entries = append(fileData.Entries, resp.Entries...) - } else { - if resp.Data64 == "" { - continue - } - decoder := base64.NewDecoder(base64.StdEncoding, bytes.NewReader([]byte(resp.Data64))) - _, err := io.Copy(&fileBuf, decoder) - if err != nil { - return nil, fmt.Errorf("stream file, failed to decode base64 data: %w", err) - } - } - } - if !isDir { - fileData.Data64 = base64.StdEncoding.EncodeToString(fileBuf.Bytes()) - } - return fileData, nil + return fsutil.ReadStreamToFileData(ctx, rtnCh) } func (c WshClient) ReadStream(ctx context.Context, conn *connparse.Connection, data wshrpc.FileData) <-chan wshrpc.RespOrErrorUnion[wshrpc.FileData] { @@ -90,7 +44,7 @@ func (c WshClient) ReadStream(ctx context.Context, conn *connparse.Connection, d func (c WshClient) ReadTarStream(ctx context.Context, conn *connparse.Connection, opts *wshrpc.FileCopyOpts) <-chan wshrpc.RespOrErrorUnion[iochantypes.Packet] { timeout := opts.Timeout if timeout == 0 { - timeout = ThirtySeconds + timeout = fstype.DefaultTimeout.Milliseconds() } return wshclient.RemoteTarStreamCommand(RpcClient, wshrpc.CommandRemoteStreamTarData{Path: conn.Path, Opts: opts}, &wshrpc.RpcOpts{Route: wshutil.MakeConnectionRouteId(conn.Host), Timeout: timeout}) } @@ -155,9 +109,9 @@ func (c WshClient) MoveInternal(ctx context.Context, srcConn, destConn *connpars } timeout := opts.Timeout if timeout == 0 { - timeout = ThirtySeconds + timeout = fstype.DefaultTimeout.Milliseconds() } - return wshclient.RemoteFileMoveCommand(RpcClient, wshrpc.CommandRemoteFileCopyData{SrcUri: srcConn.GetFullURI(), DestUri: destConn.GetFullURI(), Opts: opts}, &wshrpc.RpcOpts{Route: wshutil.MakeConnectionRouteId(destConn.Host), Timeout: timeout}) + return wshclient.RemoteFileMoveCommand(RpcClient, wshrpc.CommandFileCopyData{SrcUri: srcConn.GetFullURI(), DestUri: destConn.GetFullURI(), Opts: opts}, &wshrpc.RpcOpts{Route: wshutil.MakeConnectionRouteId(destConn.Host), Timeout: timeout}) } func (c WshClient) CopyRemote(ctx context.Context, srcConn, destConn *connparse.Connection, _ fstype.FileShareClient, opts *wshrpc.FileCopyOpts) error { @@ -170,23 +124,23 @@ func (c WshClient) CopyInternal(ctx context.Context, srcConn, destConn *connpars } timeout := opts.Timeout if timeout == 0 { - timeout = ThirtySeconds + timeout = fstype.DefaultTimeout.Milliseconds() } - return wshclient.RemoteFileCopyCommand(RpcClient, wshrpc.CommandRemoteFileCopyData{SrcUri: srcConn.GetFullURI(), DestUri: destConn.GetFullURI(), Opts: opts}, &wshrpc.RpcOpts{Route: wshutil.MakeConnectionRouteId(destConn.Host), Timeout: timeout}) + return wshclient.RemoteFileCopyCommand(RpcClient, wshrpc.CommandFileCopyData{SrcUri: srcConn.GetFullURI(), DestUri: destConn.GetFullURI(), Opts: opts}, &wshrpc.RpcOpts{Route: wshutil.MakeConnectionRouteId(destConn.Host), Timeout: timeout}) } func (c WshClient) Delete(ctx context.Context, conn *connparse.Connection, recursive bool) error { return wshclient.RemoteFileDeleteCommand(RpcClient, wshrpc.CommandDeleteFileData{Path: conn.Path, Recursive: recursive}, &wshrpc.RpcOpts{Route: wshutil.MakeConnectionRouteId(conn.Host)}) } -func (c WshClient) Join(ctx context.Context, conn *connparse.Connection, parts ...string) (string, error) { - finfo, err := wshclient.RemoteFileJoinCommand(RpcClient, append([]string{conn.Path}, parts...), &wshrpc.RpcOpts{Route: wshutil.MakeConnectionRouteId(conn.Host)}) - if err != nil { - return "", err - } - return finfo.Path, nil +func (c WshClient) Join(ctx context.Context, conn *connparse.Connection, parts ...string) (*wshrpc.FileInfo, error) { + return wshclient.RemoteFileJoinCommand(RpcClient, append([]string{conn.Path}, parts...), &wshrpc.RpcOpts{Route: wshutil.MakeConnectionRouteId(conn.Host)}) } func (c WshClient) GetConnectionType() string { return connparse.ConnectionTypeWsh } + +func (c WshClient) GetCapability() wshrpc.FileShareCapability { + return wshrpc.FileShareCapability{CanAppend: true, CanMkdir: true} +} diff --git a/pkg/util/fileutil/fileutil.go b/pkg/util/fileutil/fileutil.go index 4c894f190c..426fe1154e 100644 --- a/pkg/util/fileutil/fileutil.go +++ b/pkg/util/fileutil/fileutil.go @@ -19,6 +19,7 @@ import ( ) func FixPath(path string) (string, error) { + origPath := path var err error if strings.HasPrefix(path, "~") { path = filepath.Join(wavebase.GetHomeDir(), path[1:]) @@ -28,6 +29,9 @@ func FixPath(path string) (string, error) { return "", err } } + if strings.HasSuffix(origPath, "/") && !strings.HasSuffix(path, "/") { + path += "/" + } return path, nil } @@ -61,7 +65,6 @@ func WinSymlinkDir(path string, bits os.FileMode) bool { // does not return "application/octet-stream" as this is considered a detection failure // can pass an existing fileInfo to avoid re-statting the file // falls back to text/plain for 0 byte files - func DetectMimeType(path string, fileInfo fs.FileInfo, extended bool) string { if fileInfo == nil { statRtn, err := os.Stat(path) @@ -140,6 +143,15 @@ func DetectMimeTypeWithDirEnt(path string, dirEnt fs.DirEntry) string { return "" } +func AddMimeTypeToFileInfo(path string, fileInfo *wshrpc.FileInfo) { + if fileInfo == nil { + return + } + if fileInfo.MimeType == "" { + fileInfo.MimeType = DetectMimeType(path, ToFsFileInfo(fileInfo), false) + } +} + var ( systemBinDirs = []string{ "/bin/", diff --git a/pkg/util/iochan/iochan.go b/pkg/util/iochan/iochan.go index 98fb94a196..4bb5292cf4 100644 --- a/pkg/util/iochan/iochan.go +++ b/pkg/util/iochan/iochan.go @@ -11,8 +11,10 @@ import ( "errors" "fmt" "io" + "log" "github.com/wavetermdev/waveterm/pkg/util/iochan/iochantypes" + "github.com/wavetermdev/waveterm/pkg/util/utilfn" "github.com/wavetermdev/waveterm/pkg/wshrpc" "github.com/wavetermdev/waveterm/pkg/wshutil" ) @@ -22,6 +24,7 @@ func ReaderChan(ctx context.Context, r io.Reader, chunkSize int64, callback func ch := make(chan wshrpc.RespOrErrorUnion[iochantypes.Packet], 32) go func() { defer func() { + log.Printf("Closing ReaderChan\n") close(ch) callback() }() @@ -60,7 +63,7 @@ func WriterChan(ctx context.Context, w io.Writer, ch <-chan wshrpc.RespOrErrorUn go func() { defer func() { if ctx.Err() != nil { - drainChannel(ch) + utilfn.DrainChannelSafe(ch, "WriterChan") } callback() }() @@ -97,10 +100,3 @@ func WriterChan(ctx context.Context, w io.Writer, ch <-chan wshrpc.RespOrErrorUn } }() } - -func drainChannel(ch <-chan wshrpc.RespOrErrorUnion[iochantypes.Packet]) { - go func() { - for range ch { - } - }() -} diff --git a/pkg/util/tarcopy/tarcopy.go b/pkg/util/tarcopy/tarcopy.go index 06e008811c..d8888719de 100644 --- a/pkg/util/tarcopy/tarcopy.go +++ b/pkg/util/tarcopy/tarcopy.go @@ -14,78 +14,98 @@ import ( "log" "path/filepath" "strings" - "time" "github.com/wavetermdev/waveterm/pkg/util/iochan" "github.com/wavetermdev/waveterm/pkg/util/iochan/iochantypes" + "github.com/wavetermdev/waveterm/pkg/util/utilfn" "github.com/wavetermdev/waveterm/pkg/wshrpc" ) const ( - maxRetries = 5 - retryDelay = 10 * time.Millisecond tarCopySrcName = "TarCopySrc" tarCopyDestName = "TarCopyDest" pipeReaderName = "pipe reader" pipeWriterName = "pipe writer" tarWriterName = "tar writer" + + // custom flag to indicate that the source is a single file + SingleFile = "singlefile" ) // TarCopySrc creates a tar stream writer and returns a channel to send the tar stream to. -// writeHeader is a function that writes the tar header for the file. +// writeHeader is a function that writes the tar header for the file. If only a single file is being written, the singleFile flag should be set to true. // writer is the tar writer to write the file data to. // close is a function that closes the tar writer and internal pipe writer. -func TarCopySrc(ctx context.Context, pathPrefix string) (outputChan chan wshrpc.RespOrErrorUnion[iochantypes.Packet], writeHeader func(fi fs.FileInfo, file string) error, writer io.Writer, close func()) { +func TarCopySrc(ctx context.Context, pathPrefix string) (outputChan chan wshrpc.RespOrErrorUnion[iochantypes.Packet], writeHeader func(fi fs.FileInfo, file string, singleFile bool) error, writer io.Writer, close func()) { pipeReader, pipeWriter := io.Pipe() tarWriter := tar.NewWriter(pipeWriter) rtnChan := iochan.ReaderChan(ctx, pipeReader, wshrpc.FileChunkSize, func() { - gracefulClose(pipeReader, tarCopySrcName, pipeReaderName) + log.Printf("Closing pipe reader\n") + utilfn.GracefulClose(pipeReader, tarCopySrcName, pipeReaderName) }) - return rtnChan, func(fi fs.FileInfo, file string) error { + singleFileFlagSet := false + + return rtnChan, func(fi fs.FileInfo, path string, singleFile bool) error { // generate tar header - header, err := tar.FileInfoHeader(fi, file) + header, err := tar.FileInfoHeader(fi, path) if err != nil { return err } - header.Name = filepath.Clean(strings.TrimPrefix(file, pathPrefix)) - if err := validatePath(header.Name); err != nil { + if singleFile { + if singleFileFlagSet { + return errors.New("attempting to write multiple files to a single file tar stream") + } + + header.PAXRecords = map[string]string{SingleFile: "true"} + singleFileFlagSet = true + } + + path, err = fixPath(path, pathPrefix) + if err != nil { return err } + // skip if path is empty, which means the file is the root directory + if path == "" { + return nil + } + header.Name = path + + log.Printf("TarCopySrc: header name: %v\n", header.Name) + // write header if err := tarWriter.WriteHeader(header); err != nil { return err } return nil }, tarWriter, func() { - gracefulClose(tarWriter, tarCopySrcName, tarWriterName) - gracefulClose(pipeWriter, tarCopySrcName, pipeWriterName) + log.Printf("Closing tar writer\n") + utilfn.GracefulClose(tarWriter, tarCopySrcName, tarWriterName) + utilfn.GracefulClose(pipeWriter, tarCopySrcName, pipeWriterName) } } -func validatePath(path string) error { +func fixPath(path, prefix string) (string, error) { + path = strings.TrimPrefix(strings.TrimPrefix(filepath.Clean(strings.TrimPrefix(path, prefix)), "/"), "\\") if strings.Contains(path, "..") { - return fmt.Errorf("invalid tar path containing directory traversal: %s", path) - } - if strings.HasPrefix(path, "/") { - return fmt.Errorf("invalid tar path starting with /: %s", path) + return "", fmt.Errorf("invalid tar path containing directory traversal: %s", path) } - return nil + return path, nil } // TarCopyDest reads a tar stream from a channel and writes the files to the destination. -// readNext is a function that is called for each file in the tar stream to read the file data. It should return an error if the file cannot be read. +// readNext is a function that is called for each file in the tar stream to read the file data. If only a single file is being written from the tar src, the singleFile flag will be set in this callback. It should return an error if the file cannot be read. // The function returns an error if the tar stream cannot be read. -func TarCopyDest(ctx context.Context, cancel context.CancelCauseFunc, ch <-chan wshrpc.RespOrErrorUnion[iochantypes.Packet], readNext func(next *tar.Header, reader *tar.Reader) error) error { +func TarCopyDest(ctx context.Context, cancel context.CancelCauseFunc, ch <-chan wshrpc.RespOrErrorUnion[iochantypes.Packet], readNext func(next *tar.Header, reader *tar.Reader, singleFile bool) error) error { pipeReader, pipeWriter := io.Pipe() iochan.WriterChan(ctx, pipeWriter, ch, func() { - gracefulClose(pipeWriter, tarCopyDestName, pipeWriterName) + utilfn.GracefulClose(pipeWriter, tarCopyDestName, pipeWriterName) }, cancel) tarReader := tar.NewReader(pipeReader) defer func() { - if !gracefulClose(pipeReader, tarCopyDestName, pipeReaderName) { + if !utilfn.GracefulClose(pipeReader, tarCopyDestName, pipeReaderName) { // If the pipe reader cannot be closed, cancel the context. This should kill the writer goroutine. cancel(nil) } @@ -110,27 +130,15 @@ func TarCopyDest(ctx context.Context, cancel context.CancelCauseFunc, ch <-chan return err } } - err = readNext(next, tarReader) + + // Check for directory traversal + if strings.Contains(next.Name, "..") { + return fmt.Errorf("invalid tar path containing directory traversal: %s", next.Name) + } + err = readNext(next, tarReader, next.PAXRecords != nil && next.PAXRecords[SingleFile] == "true") if err != nil { return err } } } } - -func gracefulClose(closer io.Closer, debugName string, closerName string) bool { - closed := false - for retries := 0; retries < maxRetries; retries++ { - if err := closer.Close(); err != nil { - log.Printf("%s: error closing %s: %v, trying again in %dms\n", debugName, closerName, err, retryDelay.Milliseconds()) - time.Sleep(retryDelay) - continue - } - closed = true - break - } - if !closed { - log.Printf("%s: unable to close %s after %d retries\n", debugName, closerName, maxRetries) - } - return closed -} diff --git a/pkg/util/utilfn/utilfn.go b/pkg/util/utilfn/utilfn.go index 49a8133f59..f5765e0078 100644 --- a/pkg/util/utilfn/utilfn.go +++ b/pkg/util/utilfn/utilfn.go @@ -15,6 +15,7 @@ import ( "fmt" "hash/fnv" "io" + "log" "math" mathrand "math/rand" "os" @@ -1032,3 +1033,44 @@ func SendWithCtxCheck[T any](ctx context.Context, ch chan<- T, val T) bool { return true } } + +const ( + maxRetries = 5 + retryDelay = 10 * time.Millisecond +) + +func GracefulClose(closer io.Closer, debugName, closerName string) bool { + closed := false + for retries := 0; retries < maxRetries; retries++ { + if err := closer.Close(); err != nil { + log.Printf("%s: error closing %s: %v, trying again in %dms\n", debugName, closerName, err, retryDelay.Milliseconds()) + time.Sleep(retryDelay) + continue + } + closed = true + break + } + if !closed { + log.Printf("%s: unable to close %s after %d retries\n", debugName, closerName, maxRetries) + } + return closed +} + +// DrainChannelSafe will drain a channel until it is empty or until a timeout is reached. +// WARNING: This function will panic if the channel is not drained within the timeout. +func DrainChannelSafe[T any](ch <-chan T, debugName string) { + drainTimeoutCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + go func() { + defer cancel() + for { + select { + case <-drainTimeoutCtx.Done(): + panic(debugName + ": timeout draining channel") + case _, ok := <-ch: + if !ok { + return + } + } + } + }() +} diff --git a/pkg/util/wavefileutil/wavefileutil.go b/pkg/util/wavefileutil/wavefileutil.go index 81b09cf288..7334bce7aa 100644 --- a/pkg/util/wavefileutil/wavefileutil.go +++ b/pkg/util/wavefileutil/wavefileutil.go @@ -4,6 +4,8 @@ import ( "fmt" "github.com/wavetermdev/waveterm/pkg/filestore" + "github.com/wavetermdev/waveterm/pkg/remote/fileshare/fsutil" + "github.com/wavetermdev/waveterm/pkg/util/fileutil" "github.com/wavetermdev/waveterm/pkg/wshrpc" ) @@ -13,14 +15,17 @@ const ( func WaveFileToFileInfo(wf *filestore.WaveFile) *wshrpc.FileInfo { path := fmt.Sprintf(WaveFilePathPattern, wf.ZoneId, wf.Name) - return &wshrpc.FileInfo{ + rtn := &wshrpc.FileInfo{ Path: path, + Dir: fsutil.GetParentPathString(path), Name: wf.Name, Opts: &wf.Opts, Size: wf.Size, Meta: &wf.Meta, SupportsMkdir: false, } + fileutil.AddMimeTypeToFileInfo(path, rtn) + return rtn } func WaveFileListToFileInfoList(wfList []*filestore.WaveFile) []*wshrpc.FileInfo { diff --git a/pkg/wavebase/wavebase.go b/pkg/wavebase/wavebase.go index 2d2c30064b..52b365124a 100644 --- a/pkg/wavebase/wavebase.go +++ b/pkg/wavebase/wavebase.go @@ -148,17 +148,6 @@ func ExpandHomeDirSafe(pathStr string) string { return path } -func ReplaceHomeDir(pathStr string) string { - homeDir := GetHomeDir() - if pathStr == homeDir { - return "~" - } - if strings.HasPrefix(pathStr, homeDir+"/") { - return "~" + pathStr[len(homeDir):] - } - return pathStr -} - func GetDomainSocketName() string { return filepath.Join(GetWaveDataDir(), DomainSocketBaseName) } diff --git a/pkg/web/web.go b/pkg/web/web.go index 7450d6cb5a..1e89f4bca9 100644 --- a/pkg/web/web.go +++ b/pkg/web/web.go @@ -27,6 +27,7 @@ import ( "github.com/wavetermdev/waveterm/pkg/panichandler" "github.com/wavetermdev/waveterm/pkg/schema" "github.com/wavetermdev/waveterm/pkg/service" + "github.com/wavetermdev/waveterm/pkg/util/utilfn" "github.com/wavetermdev/waveterm/pkg/wavebase" "github.com/wavetermdev/waveterm/pkg/wshrpc" "github.com/wavetermdev/waveterm/pkg/wshrpc/wshclient" @@ -258,10 +259,7 @@ func handleRemoteStreamFile(w http.ResponseWriter, req *http.Request, conn strin return } // if loop didn't finish naturally clear it out - go func() { - for range rtnCh { - } - }() + utilfn.DrainChannelSafe(rtnCh, "handleRemoteStreamFile") }() ctx := req.Context() for { diff --git a/pkg/wshrpc/wshclient/wshclient.go b/pkg/wshrpc/wshclient/wshclient.go index 034365eecb..9e1a97af22 100644 --- a/pkg/wshrpc/wshclient/wshclient.go +++ b/pkg/wshrpc/wshclient/wshclient.go @@ -70,6 +70,12 @@ func ConnListCommand(w *wshutil.WshRpc, opts *wshrpc.RpcOpts) ([]string, error) return resp, err } +// command "connlistaws", wshserver.ConnListAWSCommand +func ConnListAWSCommand(w *wshutil.WshRpc, opts *wshrpc.RpcOpts) ([]string, error) { + resp, err := sendRpcRequestCallHelper[[]string](w, "connlistaws", nil, opts) + return resp, err +} + // command "connreinstallwsh", wshserver.ConnReinstallWshCommand func ConnReinstallWshCommand(w *wshutil.WshRpc, data wshrpc.ConnExtData, opts *wshrpc.RpcOpts) error { _, err := sendRpcRequestCallHelper[any](w, "connreinstallwsh", data, opts) @@ -226,6 +232,12 @@ func FileInfoCommand(w *wshutil.WshRpc, data wshrpc.FileData, opts *wshrpc.RpcOp return resp, err } +// command "filejoin", wshserver.FileJoinCommand +func FileJoinCommand(w *wshutil.WshRpc, data []string, opts *wshrpc.RpcOpts) (*wshrpc.FileInfo, error) { + resp, err := sendRpcRequestCallHelper[*wshrpc.FileInfo](w, "filejoin", data, opts) + return resp, err +} + // command "filelist", wshserver.FileListCommand func FileListCommand(w *wshutil.WshRpc, data wshrpc.FileListData, opts *wshrpc.RpcOpts) ([]*wshrpc.FileInfo, error) { resp, err := sendRpcRequestCallHelper[[]*wshrpc.FileInfo](w, "filelist", data, opts) @@ -255,6 +267,17 @@ func FileReadCommand(w *wshutil.WshRpc, data wshrpc.FileData, opts *wshrpc.RpcOp return resp, err } +// command "filereadstream", wshserver.FileReadStreamCommand +func FileReadStreamCommand(w *wshutil.WshRpc, data wshrpc.FileData, opts *wshrpc.RpcOpts) chan wshrpc.RespOrErrorUnion[wshrpc.FileData] { + return sendRpcRequestResponseStreamHelper[wshrpc.FileData](w, "filereadstream", data, opts) +} + +// command "filesharecapability", wshserver.FileShareCapabilityCommand +func FileShareCapabilityCommand(w *wshutil.WshRpc, data string, opts *wshrpc.RpcOpts) (wshrpc.FileShareCapability, error) { + resp, err := sendRpcRequestCallHelper[wshrpc.FileShareCapability](w, "filesharecapability", data, opts) + return resp, err +} + // command "filestreamtar", wshserver.FileStreamTarCommand func FileStreamTarCommand(w *wshutil.WshRpc, data wshrpc.CommandRemoteStreamTarData, opts *wshrpc.RpcOpts) chan wshrpc.RespOrErrorUnion[iochantypes.Packet] { return sendRpcRequestResponseStreamHelper[iochantypes.Packet](w, "filestreamtar", data, opts) @@ -327,7 +350,7 @@ func RecordTEventCommand(w *wshutil.WshRpc, data telemetrydata.TEvent, opts *wsh } // command "remotefilecopy", wshserver.RemoteFileCopyCommand -func RemoteFileCopyCommand(w *wshutil.WshRpc, data wshrpc.CommandRemoteFileCopyData, opts *wshrpc.RpcOpts) error { +func RemoteFileCopyCommand(w *wshutil.WshRpc, data wshrpc.CommandFileCopyData, opts *wshrpc.RpcOpts) error { _, err := sendRpcRequestCallHelper[any](w, "remotefilecopy", data, opts) return err } @@ -351,7 +374,7 @@ func RemoteFileJoinCommand(w *wshutil.WshRpc, data []string, opts *wshrpc.RpcOpt } // command "remotefilemove", wshserver.RemoteFileMoveCommand -func RemoteFileMoveCommand(w *wshutil.WshRpc, data wshrpc.CommandRemoteFileCopyData, opts *wshrpc.RpcOpts) error { +func RemoteFileMoveCommand(w *wshutil.WshRpc, data wshrpc.CommandFileCopyData, opts *wshrpc.RpcOpts) error { _, err := sendRpcRequestCallHelper[any](w, "remotefilemove", data, opts) return err } diff --git a/pkg/wshrpc/wshremote/wshremote.go b/pkg/wshrpc/wshremote/wshremote.go index 711de2e26e..1a6221b674 100644 --- a/pkg/wshrpc/wshremote/wshremote.go +++ b/pkg/wshrpc/wshremote/wshremote.go @@ -18,6 +18,7 @@ import ( "time" "github.com/wavetermdev/waveterm/pkg/remote/connparse" + "github.com/wavetermdev/waveterm/pkg/remote/fileshare/fstype" "github.com/wavetermdev/waveterm/pkg/remote/fileshare/wshfs" "github.com/wavetermdev/waveterm/pkg/suggestion" "github.com/wavetermdev/waveterm/pkg/util/fileutil" @@ -30,10 +31,6 @@ import ( "github.com/wavetermdev/waveterm/pkg/wshutil" ) -const ( - DefaultTimeout = 30 * time.Second -) - type ServerImpl struct { LogWriter io.Writer } @@ -146,7 +143,7 @@ func (impl *ServerImpl) remoteStreamFileRegular(ctx context.Context, path string if err != nil { return fmt.Errorf("cannot open file %q: %w", path, err) } - defer fd.Close() + defer utilfn.GracefulClose(fd, "remoteStreamFileRegular", path) var filePos int64 if !byteRange.All && byteRange.Start > 0 { _, err := fd.Seek(byteRange.Start, io.SeekStart) @@ -240,8 +237,8 @@ func (impl *ServerImpl) RemoteTarStreamCommand(ctx context.Context, data wshrpc. if opts == nil { opts = &wshrpc.FileCopyOpts{} } - recursive := opts.Recursive - logPrintfDev("RemoteTarStreamCommand: path=%s\n", path) + log.Printf("RemoteTarStreamCommand: path=%s\n", path) + srcHasSlash := strings.HasSuffix(path, "/") path, err := wavebase.ExpandHomeDir(path) if err != nil { return wshutil.SendErrCh[iochantypes.Packet](fmt.Errorf("cannot expand path %q: %w", path, err)) @@ -253,18 +250,14 @@ func (impl *ServerImpl) RemoteTarStreamCommand(ctx context.Context, data wshrpc. } var pathPrefix string - if finfo.IsDir() && strings.HasSuffix(cleanedPath, "/") { + singleFile := !finfo.IsDir() + if !singleFile && srcHasSlash { pathPrefix = cleanedPath } else { - pathPrefix = filepath.Dir(cleanedPath) + "/" - } - if finfo.IsDir() { - if !recursive { - return wshutil.SendErrCh[iochantypes.Packet](fmt.Errorf("cannot create tar stream for %q: %w", path, errors.New("directory copy requires recursive option"))) - } + pathPrefix = filepath.Dir(cleanedPath) } - timeout := DefaultTimeout + timeout := fstype.DefaultTimeout if opts.Timeout > 0 { timeout = time.Duration(opts.Timeout) * time.Millisecond } @@ -283,7 +276,7 @@ func (impl *ServerImpl) RemoteTarStreamCommand(ctx context.Context, data wshrpc. if err != nil { return err } - if err = writeHeader(info, path); err != nil { + if err = writeHeader(info, path, singleFile); err != nil { return err } // if not a dir, write file content @@ -292,6 +285,7 @@ func (impl *ServerImpl) RemoteTarStreamCommand(ctx context.Context, data wshrpc. if err != nil { return err } + defer utilfn.GracefulClose(data, "RemoteTarStreamCommand", path) if _, err := io.Copy(fileWriter, data); err != nil { return err } @@ -300,10 +294,10 @@ func (impl *ServerImpl) RemoteTarStreamCommand(ctx context.Context, data wshrpc. } log.Printf("RemoteTarStreamCommand: starting\n") err = nil - if finfo.IsDir() { - err = filepath.Walk(path, walkFunc) + if singleFile { + err = walkFunc(cleanedPath, finfo, nil) } else { - err = walkFunc(path, finfo, nil) + err = filepath.Walk(cleanedPath, walkFunc) } if err != nil { rtn <- wshutil.RespErr[iochantypes.Packet](err) @@ -314,7 +308,7 @@ func (impl *ServerImpl) RemoteTarStreamCommand(ctx context.Context, data wshrpc. return rtn } -func (impl *ServerImpl) RemoteFileCopyCommand(ctx context.Context, data wshrpc.CommandRemoteFileCopyData) error { +func (impl *ServerImpl) RemoteFileCopyCommand(ctx context.Context, data wshrpc.CommandFileCopyData) error { log.Printf("RemoteFileCopyCommand: src=%s, dest=%s\n", data.SrcUri, data.DestUri) opts := data.Opts if opts == nil { @@ -331,19 +325,25 @@ func (impl *ServerImpl) RemoteFileCopyCommand(ctx context.Context, data wshrpc.C } destPathCleaned := filepath.Clean(wavebase.ExpandHomeDirSafe(destConn.Path)) destinfo, err := os.Stat(destPathCleaned) - if err == nil { - if !destinfo.IsDir() { - if !overwrite { - return fmt.Errorf("destination %q already exists, use overwrite option", destPathCleaned) - } else { - err := os.Remove(destPathCleaned) - if err != nil { - return fmt.Errorf("cannot remove file %q: %w", destPathCleaned, err) - } + if err != nil { + if !errors.Is(err, fs.ErrNotExist) { + return fmt.Errorf("cannot stat destination %q: %w", destPathCleaned, err) + } + } + + destExists := destinfo != nil + destIsDir := destExists && destinfo.IsDir() + destHasSlash := strings.HasSuffix(destUri, "/") + + if destExists && !destIsDir { + if !overwrite { + return fmt.Errorf("file already exists at destination %q, use overwrite option", destPathCleaned) + } else { + err := os.Remove(destPathCleaned) + if err != nil { + return fmt.Errorf("cannot remove file %q: %w", destPathCleaned, err) } } - } else if !errors.Is(err, fs.ErrNotExist) { - return fmt.Errorf("cannot stat destination %q: %w", destPathCleaned, err) } srcConn, err := connparse.ParseURIAndReplaceCurrentHost(ctx, srcUri) if err != nil { @@ -351,13 +351,13 @@ func (impl *ServerImpl) RemoteFileCopyCommand(ctx context.Context, data wshrpc.C } copyFileFunc := func(path string, finfo fs.FileInfo, srcFile io.Reader) (int64, error) { - destinfo, err = os.Stat(path) + nextinfo, err := os.Stat(path) if err != nil && !errors.Is(err, fs.ErrNotExist) { return 0, fmt.Errorf("cannot stat file %q: %w", path, err) } - if destinfo != nil { - if destinfo.IsDir() { + if nextinfo != nil { + if nextinfo.IsDir() { if !finfo.IsDir() { // try to create file in directory path = filepath.Join(path, filepath.Base(finfo.Name())) @@ -393,10 +393,12 @@ func (impl *ServerImpl) RemoteFileCopyCommand(ctx context.Context, data wshrpc.C } if finfo.IsDir() { + log.Printf("RemoteFileCopyCommand: making dirs %s\n", path) err := os.MkdirAll(path, finfo.Mode()) if err != nil { return 0, fmt.Errorf("cannot create directory %q: %w", path, err) } + return 0, nil } else { err := os.MkdirAll(filepath.Dir(path), 0755) if err != nil { @@ -408,7 +410,7 @@ func (impl *ServerImpl) RemoteFileCopyCommand(ctx context.Context, data wshrpc.C if err != nil { return 0, fmt.Errorf("cannot create new file %q: %w", path, err) } - defer file.Close() + defer utilfn.GracefulClose(file, "RemoteFileCopyCommand", path) _, err = io.Copy(file, srcFile) if err != nil { return 0, fmt.Errorf("cannot write file %q: %w", path, err) @@ -426,19 +428,25 @@ func (impl *ServerImpl) RemoteFileCopyCommand(ctx context.Context, data wshrpc.C } if srcFileStat.IsDir() { + var srcPathPrefix string + if destIsDir { + srcPathPrefix = filepath.Dir(srcPathCleaned) + } else { + srcPathPrefix = srcPathCleaned + } err = filepath.Walk(srcPathCleaned, func(path string, info fs.FileInfo, err error) error { if err != nil { return err } srcFilePath := path - destFilePath := filepath.Join(destPathCleaned, strings.TrimPrefix(path, srcPathCleaned)) + destFilePath := filepath.Join(destPathCleaned, strings.TrimPrefix(path, srcPathPrefix)) var file *os.File if !info.IsDir() { file, err = os.Open(srcFilePath) if err != nil { return fmt.Errorf("cannot open file %q: %w", srcFilePath, err) } - defer file.Close() + defer utilfn.GracefulClose(file, "RemoteFileCopyCommand", srcFilePath) } _, err = copyFileFunc(destFilePath, info, file) return err @@ -451,14 +459,20 @@ func (impl *ServerImpl) RemoteFileCopyCommand(ctx context.Context, data wshrpc.C if err != nil { return fmt.Errorf("cannot open file %q: %w", srcPathCleaned, err) } - defer file.Close() - _, err = copyFileFunc(destPathCleaned, srcFileStat, file) + defer utilfn.GracefulClose(file, "RemoteFileCopyCommand", srcPathCleaned) + var destFilePath string + if destHasSlash { + destFilePath = filepath.Join(destPathCleaned, filepath.Base(srcPathCleaned)) + } else { + destFilePath = destPathCleaned + } + _, err = copyFileFunc(destFilePath, srcFileStat, file) if err != nil { return fmt.Errorf("cannot copy %q to %q: %w", srcUri, destUri, err) } } } else { - timeout := DefaultTimeout + timeout := fstype.DefaultTimeout if opts.Timeout > 0 { timeout = time.Duration(opts.Timeout) * time.Millisecond } @@ -470,16 +484,17 @@ func (impl *ServerImpl) RemoteFileCopyCommand(ctx context.Context, data wshrpc.C numFiles := 0 numSkipped := 0 totalBytes := int64(0) - err := tarcopy.TarCopyDest(readCtx, cancel, ioch, func(next *tar.Header, reader *tar.Reader) error { - // Check for directory traversal - if strings.Contains(next.Name, "..") { - log.Printf("skipping file with unsafe path: %q\n", next.Name) - numSkipped++ - return nil - } + + err := tarcopy.TarCopyDest(readCtx, cancel, ioch, func(next *tar.Header, reader *tar.Reader, singleFile bool) error { numFiles++ + nextpath := filepath.Join(destPathCleaned, next.Name) + log.Printf("RemoteFileCopyCommand: copying %q to %q\n", next.Name, nextpath) + if singleFile && !destHasSlash { + // custom flag to indicate that the source is a single file, not a directory the contents of a directory + nextpath = destPathCleaned + } finfo := next.FileInfo() - n, err := copyFileFunc(filepath.Join(destPathCleaned, next.Name), finfo, reader) + n, err := copyFileFunc(nextpath, finfo, reader) if err != nil { return fmt.Errorf("cannot copy file %q: %w", next.Name, err) } @@ -571,8 +586,8 @@ func (impl *ServerImpl) RemoteListEntriesCommand(ctx context.Context, data wshrp func statToFileInfo(fullPath string, finfo fs.FileInfo, extended bool) *wshrpc.FileInfo { mimeType := fileutil.DetectMimeType(fullPath, finfo, extended) rtn := &wshrpc.FileInfo{ - Path: wavebase.ReplaceHomeDir(fullPath), - Dir: computeDirPart(fullPath, finfo.IsDir()), + Path: fullPath, + Dir: computeDirPart(fullPath), Name: finfo.Name(), Size: finfo.Size(), Mode: finfo.Mode(), @@ -602,7 +617,7 @@ func checkIsReadOnly(path string, fileInfo fs.FileInfo, exists bool) bool { if err != nil { return true } - fd.Close() + utilfn.GracefulClose(fd, "checkIsReadOnly", tmpFileName) os.Remove(tmpFileName) return false } @@ -611,20 +626,16 @@ func checkIsReadOnly(path string, fileInfo fs.FileInfo, exists bool) bool { if err != nil { return true } - file.Close() + utilfn.GracefulClose(file, "checkIsReadOnly", path) return false } -func computeDirPart(path string, isDir bool) string { +func computeDirPart(path string) string { path = filepath.Clean(wavebase.ExpandHomeDirSafe(path)) path = filepath.ToSlash(path) if path == "/" { return "/" } - path = strings.TrimSuffix(path, "/") - if isDir { - return path - } return filepath.Dir(path) } @@ -633,8 +644,8 @@ func (*ServerImpl) fileInfoInternal(path string, extended bool) (*wshrpc.FileInf finfo, err := os.Stat(cleanedPath) if os.IsNotExist(err) { return &wshrpc.FileInfo{ - Path: wavebase.ReplaceHomeDir(path), - Dir: computeDirPart(path, false), + Path: path, + Dir: computeDirPart(path), NotFound: true, ReadOnly: checkIsReadOnly(cleanedPath, finfo, false), SupportsMkdir: true, @@ -689,12 +700,13 @@ func (impl *ServerImpl) RemoteFileTouchCommand(ctx context.Context, path string) return nil } -func (impl *ServerImpl) RemoteFileMoveCommand(ctx context.Context, data wshrpc.CommandRemoteFileCopyData) error { +func (impl *ServerImpl) RemoteFileMoveCommand(ctx context.Context, data wshrpc.CommandFileCopyData) error { logPrintfDev("RemoteFileCopyCommand: src=%s, dest=%s\n", data.SrcUri, data.DestUri) opts := data.Opts destUri := data.DestUri srcUri := data.SrcUri overwrite := opts != nil && opts.Overwrite + recursive := opts != nil && opts.Recursive destConn, err := connparse.ParseURIAndReplaceCurrentHost(ctx, destUri) if err != nil { @@ -722,7 +734,14 @@ func (impl *ServerImpl) RemoteFileMoveCommand(ctx context.Context, data wshrpc.C } if srcConn.Host == destConn.Host { srcPathCleaned := filepath.Clean(wavebase.ExpandHomeDirSafe(srcConn.Path)) - err := os.Rename(srcPathCleaned, destPathCleaned) + finfo, err := os.Stat(srcPathCleaned) + if err != nil { + return fmt.Errorf("cannot stat file %q: %w", srcPathCleaned, err) + } + if finfo.IsDir() && !recursive { + return fmt.Errorf("cannot move directory %q, recursive option not specified", srcUri) + } + err = os.Rename(srcPathCleaned, destPathCleaned) if err != nil { return fmt.Errorf("cannot move file %q to %q: %w", srcPathCleaned, destPathCleaned, err) } @@ -799,7 +818,7 @@ func (*ServerImpl) RemoteWriteFileCommand(ctx context.Context, data wshrpc.FileD if err != nil { return fmt.Errorf("cannot open file %q: %w", path, err) } - defer file.Close() + defer utilfn.GracefulClose(file, "RemoteWriteFileCommand", path) if atOffset > 0 && !append { n, err = file.WriteAt(dataBytes[:n], atOffset) } else { diff --git a/pkg/wshrpc/wshrpctypes.go b/pkg/wshrpc/wshrpctypes.go index 8dc06a894b..5d5e3ec52e 100644 --- a/pkg/wshrpc/wshrpctypes.go +++ b/pkg/wshrpc/wshrpctypes.go @@ -51,31 +51,36 @@ const ( // TODO generate these constants from the interface const ( - Command_Authenticate = "authenticate" // special - Command_AuthenticateToken = "authenticatetoken" // special - Command_Dispose = "dispose" // special (disposes of the route, for multiproxy only) - Command_RouteAnnounce = "routeannounce" // special (for routing) - Command_RouteUnannounce = "routeunannounce" // special (for routing) - Command_Message = "message" - Command_GetMeta = "getmeta" - Command_SetMeta = "setmeta" - Command_SetView = "setview" - Command_ControllerInput = "controllerinput" - Command_ControllerRestart = "controllerrestart" - Command_ControllerStop = "controllerstop" - Command_ControllerResync = "controllerresync" - Command_FileAppend = "fileappend" - Command_FileAppendIJson = "fileappendijson" - Command_Mkdir = "mkdir" - Command_ResolveIds = "resolveids" - Command_BlockInfo = "blockinfo" - Command_CreateBlock = "createblock" - Command_DeleteBlock = "deleteblock" - Command_FileWrite = "filewrite" - Command_FileRead = "fileread" - Command_FileMove = "filemove" - Command_FileCopy = "filecopy" - Command_FileStreamTar = "filestreamtar" + Command_Authenticate = "authenticate" // special + Command_AuthenticateToken = "authenticatetoken" // special + Command_Dispose = "dispose" // special (disposes of the route, for multiproxy only) + Command_RouteAnnounce = "routeannounce" // special (for routing) + Command_RouteUnannounce = "routeunannounce" // special (for routing) + Command_Message = "message" + Command_GetMeta = "getmeta" + Command_SetMeta = "setmeta" + Command_SetView = "setview" + Command_ControllerInput = "controllerinput" + Command_ControllerRestart = "controllerrestart" + Command_ControllerStop = "controllerstop" + Command_ControllerResync = "controllerresync" + Command_Mkdir = "mkdir" + Command_ResolveIds = "resolveids" + Command_BlockInfo = "blockinfo" + Command_CreateBlock = "createblock" + Command_DeleteBlock = "deleteblock" + + Command_FileWrite = "filewrite" + Command_FileRead = "fileread" + Command_FileReadStream = "filereadstream" + Command_FileMove = "filemove" + Command_FileCopy = "filecopy" + Command_FileStreamTar = "filestreamtar" + Command_FileAppend = "fileappend" + Command_FileAppendIJson = "fileappendijson" + Command_FileJoin = "filejoin" + Command_FileShareCapability = "filesharecapability" + Command_EventPublish = "eventpublish" Command_EventRecv = "eventrecv" Command_EventSub = "eventsub" @@ -113,6 +118,7 @@ const ( Command_ConnConnect = "connconnect" Command_ConnDisconnect = "conndisconnect" Command_ConnList = "connlist" + Command_ConnListAWS = "connlistaws" Command_WslList = "wsllist" Command_WslDefaultDistro = "wsldefaultdistro" Command_DismissWshFail = "dismisswshfail" @@ -159,6 +165,7 @@ type WshRpcInterface interface { DeleteBlockCommand(ctx context.Context, data CommandDeleteBlockData) error DeleteSubBlockCommand(ctx context.Context, data CommandDeleteBlockData) error WaitForRouteCommand(ctx context.Context, data CommandWaitForRouteData) (bool, error) + FileMkdirCommand(ctx context.Context, data FileData) error FileCreateCommand(ctx context.Context, data FileData) error FileDeleteCommand(ctx context.Context, data CommandDeleteFileData) error @@ -166,12 +173,16 @@ type WshRpcInterface interface { FileAppendIJsonCommand(ctx context.Context, data CommandAppendIJsonData) error FileWriteCommand(ctx context.Context, data FileData) error FileReadCommand(ctx context.Context, data FileData) (*FileData, error) + FileReadStreamCommand(ctx context.Context, data FileData) <-chan RespOrErrorUnion[FileData] FileStreamTarCommand(ctx context.Context, data CommandRemoteStreamTarData) <-chan RespOrErrorUnion[iochantypes.Packet] FileMoveCommand(ctx context.Context, data CommandFileCopyData) error FileCopyCommand(ctx context.Context, data CommandFileCopyData) error FileInfoCommand(ctx context.Context, data FileData) (*FileInfo, error) FileListCommand(ctx context.Context, data FileListData) ([]*FileInfo, error) + FileJoinCommand(ctx context.Context, paths []string) (*FileInfo, error) FileListStreamCommand(ctx context.Context, data FileListData) <-chan RespOrErrorUnion[CommandRemoteListEntriesRtnData] + + FileShareCapabilityCommand(ctx context.Context, path string) (FileShareCapability, error) EventPublishCommand(ctx context.Context, data wps.WaveEvent) error EventSubCommand(ctx context.Context, data wps.SubscriptionRequest) error EventUnsubCommand(ctx context.Context, data string) error @@ -204,6 +215,7 @@ type WshRpcInterface interface { ConnConnectCommand(ctx context.Context, connRequest ConnRequest) error ConnDisconnectCommand(ctx context.Context, connName string) error ConnListCommand(ctx context.Context) ([]string, error) + ConnListAWSCommand(ctx context.Context) ([]string, error) WslListCommand(ctx context.Context) ([]string, error) WslDefaultDistroCommand(ctx context.Context) (string, error) DismissWshFailCommand(ctx context.Context, connName string) error @@ -215,11 +227,11 @@ type WshRpcInterface interface { // remotes RemoteStreamFileCommand(ctx context.Context, data CommandRemoteStreamFileData) chan RespOrErrorUnion[FileData] RemoteTarStreamCommand(ctx context.Context, data CommandRemoteStreamTarData) <-chan RespOrErrorUnion[iochantypes.Packet] - RemoteFileCopyCommand(ctx context.Context, data CommandRemoteFileCopyData) error + RemoteFileCopyCommand(ctx context.Context, data CommandFileCopyData) error RemoteListEntriesCommand(ctx context.Context, data CommandRemoteListEntriesData) chan RespOrErrorUnion[CommandRemoteListEntriesRtnData] RemoteFileInfoCommand(ctx context.Context, path string) (*FileInfo, error) RemoteFileTouchCommand(ctx context.Context, path string) error - RemoteFileMoveCommand(ctx context.Context, data CommandRemoteFileCopyData) error + RemoteFileMoveCommand(ctx context.Context, data CommandFileCopyData) error RemoteFileDeleteCommand(ctx context.Context, data CommandDeleteFileData) error RemoteWriteFileCommand(ctx context.Context, data FileData) error RemoteFileJoinCommand(ctx context.Context, paths []string) (*FileInfo, error) @@ -526,12 +538,6 @@ type CommandFileCopyData struct { Opts *FileCopyOpts `json:"opts,omitempty"` } -type CommandRemoteFileCopyData struct { - SrcUri string `json:"srcuri"` - DestUri string `json:"desturi"` - Opts *FileCopyOpts `json:"opts,omitempty"` -} - type CommandRemoteStreamTarData struct { Path string `json:"path"` Opts *FileCopyOpts `json:"opts,omitempty"` @@ -539,8 +545,8 @@ type CommandRemoteStreamTarData struct { type FileCopyOpts struct { Overwrite bool `json:"overwrite,omitempty"` - Recursive bool `json:"recursive,omitempty"` - Merge bool `json:"merge,omitempty"` + Recursive bool `json:"recursive,omitempty"` // only used for move, always true for copy + Merge bool `json:"merge,omitempty"` // only used for copy, always false for move Timeout int64 `json:"timeout,omitempty"` } @@ -764,3 +770,11 @@ type SuggestionType struct { FileName string `json:"file:name,omitempty"` UrlUrl string `json:"url:url,omitempty"` } + +// FileShareCapability represents the capabilities of a file share +type FileShareCapability struct { + // CanAppend indicates whether the file share supports appending to files + CanAppend bool `json:"canappend"` + // CanMkdir indicates whether the file share supports creating directories + CanMkdir bool `json:"canmkdir"` +} diff --git a/pkg/wshrpc/wshserver/wshserver.go b/pkg/wshrpc/wshserver/wshserver.go index 84af61fd2b..3a2e8768f4 100644 --- a/pkg/wshrpc/wshserver/wshserver.go +++ b/pkg/wshrpc/wshserver/wshserver.go @@ -24,6 +24,7 @@ import ( "github.com/wavetermdev/waveterm/pkg/genconn" "github.com/wavetermdev/waveterm/pkg/panichandler" "github.com/wavetermdev/waveterm/pkg/remote" + "github.com/wavetermdev/waveterm/pkg/remote/awsconn" "github.com/wavetermdev/waveterm/pkg/remote/conncontroller" "github.com/wavetermdev/waveterm/pkg/remote/fileshare" "github.com/wavetermdev/waveterm/pkg/suggestion" @@ -31,6 +32,7 @@ import ( "github.com/wavetermdev/waveterm/pkg/telemetry/telemetrydata" "github.com/wavetermdev/waveterm/pkg/util/envutil" "github.com/wavetermdev/waveterm/pkg/util/iochan/iochantypes" + "github.com/wavetermdev/waveterm/pkg/util/iterfn" "github.com/wavetermdev/waveterm/pkg/util/shellutil" "github.com/wavetermdev/waveterm/pkg/util/utilfn" "github.com/wavetermdev/waveterm/pkg/util/wavefileutil" @@ -383,6 +385,10 @@ func (ws *WshServer) FileReadCommand(ctx context.Context, data wshrpc.FileData) return fileshare.Read(ctx, data) } +func (ws *WshServer) FileReadStreamCommand(ctx context.Context, data wshrpc.FileData) <-chan wshrpc.RespOrErrorUnion[wshrpc.FileData] { + return fileshare.ReadStream(ctx, data) +} + func (ws *WshServer) FileCopyCommand(ctx context.Context, data wshrpc.CommandFileCopyData) error { return fileshare.Copy(ctx, data) } @@ -424,6 +430,20 @@ func (ws *WshServer) FileAppendIJsonCommand(ctx context.Context, data wshrpc.Com return nil } +func (ws *WshServer) FileJoinCommand(ctx context.Context, paths []string) (*wshrpc.FileInfo, error) { + if len(paths) < 2 { + if len(paths) == 0 { + return nil, fmt.Errorf("no paths provided") + } + return fileshare.Stat(ctx, paths[0]) + } + return fileshare.Join(ctx, paths[0], paths[1:]...) +} + +func (ws *WshServer) FileShareCapabilityCommand(ctx context.Context, path string) (wshrpc.FileShareCapability, error) { + return fileshare.GetCapability(ctx, path) +} + func (ws *WshServer) DeleteSubBlockCommand(ctx context.Context, data wshrpc.CommandDeleteBlockData) error { err := wcore.DeleteBlock(ctx, data.BlockId, false) if err != nil { @@ -550,6 +570,15 @@ func termCtxWithLogBlockId(ctx context.Context, logBlockId string) context.Conte } func (ws *WshServer) ConnEnsureCommand(ctx context.Context, data wshrpc.ConnExtData) error { + // TODO: if we add proper wsh connections via aws, we'll need to handle that here + if strings.HasPrefix(data.ConnName, "aws:") { + profiles := awsconn.ParseProfiles() + for profile := range profiles { + if strings.HasPrefix(data.ConnName, profile) { + return nil + } + } + } ctx = genconn.ContextWithConnData(ctx, data.LogBlockId) ctx = termCtxWithLogBlockId(ctx, data.LogBlockId) if strings.HasPrefix(data.ConnName, "wsl://") { @@ -560,6 +589,10 @@ func (ws *WshServer) ConnEnsureCommand(ctx context.Context, data wshrpc.ConnExtD } func (ws *WshServer) ConnDisconnectCommand(ctx context.Context, connName string) error { + // TODO: if we add proper wsh connections via aws, we'll need to handle that here + if strings.HasPrefix(connName, "aws:") { + return nil + } if strings.HasPrefix(connName, "wsl://") { distroName := strings.TrimPrefix(connName, "wsl://") conn := wslconn.GetWslConn(distroName) @@ -580,6 +613,10 @@ func (ws *WshServer) ConnDisconnectCommand(ctx context.Context, connName string) } func (ws *WshServer) ConnConnectCommand(ctx context.Context, connRequest wshrpc.ConnRequest) error { + // TODO: if we add proper wsh connections via aws, we'll need to handle that here + if strings.HasPrefix(connRequest.Host, "aws:") { + return nil + } ctx = genconn.ContextWithConnData(ctx, connRequest.LogBlockId) ctx = termCtxWithLogBlockId(ctx, connRequest.LogBlockId) connName := connRequest.Host @@ -603,6 +640,10 @@ func (ws *WshServer) ConnConnectCommand(ctx context.Context, connRequest wshrpc. } func (ws *WshServer) ConnReinstallWshCommand(ctx context.Context, data wshrpc.ConnExtData) error { + // TODO: if we add proper wsh connections via aws, we'll need to handle that here + if strings.HasPrefix(data.ConnName, "aws:") { + return nil + } ctx = genconn.ContextWithConnData(ctx, data.LogBlockId) ctx = termCtxWithLogBlockId(ctx, data.LogBlockId) connName := data.ConnName @@ -672,6 +713,11 @@ func (ws *WshServer) ConnListCommand(ctx context.Context) ([]string, error) { return conncontroller.GetConnectionsList() } +func (ws *WshServer) ConnListAWSCommand(ctx context.Context) ([]string, error) { + profilesMap := awsconn.ParseProfiles() + return iterfn.MapKeysToSorted(profilesMap), nil +} + func (ws *WshServer) WslListCommand(ctx context.Context) ([]string, error) { distros, err := wsl.RegisteredDistros(ctx) if err != nil { diff --git a/pkg/wshutil/wshrpc.go b/pkg/wshutil/wshrpc.go index 400d0070ce..128987137e 100644 --- a/pkg/wshutil/wshrpc.go +++ b/pkg/wshutil/wshrpc.go @@ -730,9 +730,7 @@ func (w *WshRpc) setServerDone() { defer w.Lock.Unlock() w.ServerDone = true close(w.CtxDoneCh) - for range w.CtxDoneCh { - // drain channel - } + utilfn.DrainChannelSafe(w.InputCh, "wshrpc.setServerDone") } func (w *WshRpc) retrySendTimeout(resId string) { diff --git a/pkg/wshutil/wshrpcio.go b/pkg/wshutil/wshrpcio.go index 9aa5f1609b..7db864626b 100644 --- a/pkg/wshutil/wshrpcio.go +++ b/pkg/wshutil/wshrpcio.go @@ -25,10 +25,7 @@ func AdaptOutputChToStream(outputCh chan []byte, output io.Writer) error { drain := false defer func() { if drain { - go func() { - for range outputCh { - } - }() + utilfn.DrainChannelSafe(outputCh, "AdaptOutputChToStream") } }() for msg := range outputCh { diff --git a/tests/copytests/cases/test026.sh b/tests/copytests/cases/test026.sh index e6bfcb3617..26655cce70 100755 --- a/tests/copytests/cases/test026.sh +++ b/tests/copytests/cases/test026.sh @@ -7,9 +7,9 @@ cd "$HOME/testcp" touch foo.txt # this is different from cp behavior -wsh file copy foo.txt baz/ >/dev/null 2>&1 && echo "command should have failed" && exit 1 +wsh file copy foo.txt baz/ -if [ -f baz/foo.txt ]; then - echo "baz/foo.txt should not exist" +if [ ! -f baz/foo.txt ]; then + echo "baz/foo.txt does not exist" exit 1 fi diff --git a/tests/copytests/cases/test048.sh b/tests/copytests/cases/test048.sh new file mode 100755 index 0000000000..9f86932d28 --- /dev/null +++ b/tests/copytests/cases/test048.sh @@ -0,0 +1,19 @@ +# copy the current directory into an existing directory +# ensure the copy succeeds and the output exists + +set -e +cd "$HOME/testcp" +mkdir foo +touch foo/bar.txt +mkdir baz +cd foo + + +wsh file copy . ../baz +cd .. + + +if [ ! -f baz/bar.txt ]; then + echo "baz/bar.txt does not exist" + exit 1 +fi diff --git a/tests/copytests/cases/test049.sh b/tests/copytests/cases/test049.sh index 51d309a959..3008c14653 100755 --- a/tests/copytests/cases/test049.sh +++ b/tests/copytests/cases/test049.sh @@ -1,11 +1,10 @@ -# copy the current directory into an existing directory +# copy the current directory into a non-existing directory # ensure the copy succeeds and the output exists set -e cd "$HOME/testcp" mkdir foo touch foo/bar.txt -mkdir baz cd foo wsh file copy . ../baz