Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 39 additions & 1 deletion cmd/wsh/cmd/wshcmd-connserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"os"
"path/filepath"
"sync/atomic"
"syscall"
"time"

"github.com/spf13/cobra"
Expand All @@ -33,9 +34,11 @@ var serverCmd = &cobra.Command{
}

var connServerRouter bool
var singleServerRouter bool

func init() {
serverCmd.Flags().BoolVar(&connServerRouter, "router", false, "run in local router mode")
serverCmd.Flags().BoolVar(&singleServerRouter, "single", false, "run in local single mode")
rootCmd.AddCommand(serverCmd)
}

Expand Down Expand Up @@ -186,6 +189,39 @@ func serverRunRouter() error {
select {}
}

func checkForUpdate() error {
remoteInfo := wshutil.GetInfo(RpcContext)
needsRestartRaw, err := RpcClient.SendRpcRequest(wshrpc.Command_ConnUpdateWsh, remoteInfo, &wshrpc.RpcOpts{Timeout: 60000})
if err != nil {
return fmt.Errorf("could not update: %w", err)
}
needsRestart, ok := needsRestartRaw.(bool)
if !ok {
return fmt.Errorf("wrong return type from update")
}
if needsRestart {
// run the restart command here
// how to get the correct path?
return syscall.Exec("~/.waveterm/bin/wsh", []string{"wsh", "connserver", "--single"}, []string{})
}
return nil
}

func serverRunSingle() error {
err := setupRpcClient(&wshremote.ServerImpl{LogWriter: os.Stdout})
if err != nil {
return err
}
WriteStdout("running wsh connserver (%s)\n", RpcContext.Conn)
err = checkForUpdate()
if err != nil {
return err
}

go wshremote.RunSysInfoLoop(RpcClient, RpcContext.Conn)
select {} // run forever
}

func serverRunNormal() error {
err := setupRpcClient(&wshremote.ServerImpl{LogWriter: os.Stdout})
if err != nil {
Expand All @@ -197,7 +233,9 @@ func serverRunNormal() error {
}

func serverRun(cmd *cobra.Command, args []string) error {
if connServerRouter {
if singleServerRouter {
return serverRunSingle()
} else if connServerRouter {
return serverRunRouter()
} else {
return serverRunNormal()
Expand Down
5 changes: 5 additions & 0 deletions frontend/app/store/wshclientapi.ts
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@ class RpcApiType {
return client.wshRpcCall("connstatus", null, opts);
}

// command "connupdatewsh" [call]
ConnUpdateWshCommand(client: WshClient, data: RemoteInfo, opts?: RpcOpts): Promise<boolean> {
return client.wshRpcCall("connupdatewsh", data, opts);
}

// command "controllerappendoutput" [call]
ControllerAppendOutputCommand(client: WshClient, data: CommandControllerAppendOutputData, opts?: RpcOpts): Promise<void> {
return client.wshRpcCall("controllerappendoutput", data, opts);
Expand Down
9 changes: 9 additions & 0 deletions frontend/types/gotypes.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,15 @@ declare global {
y: number;
};

// wshrpc.RemoteInfo
type RemoteInfo = {
host: string;
clientarch: string;
clientos: string;
clientversion: string;
shell: string;
};

// wshutil.RpcMessage
type RpcMessage = {
command?: string;
Expand Down
20 changes: 18 additions & 2 deletions pkg/remote/conncontroller/conncontroller.go
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ func (conn *SSHConn) OpenDomainSocketListener(ctx context.Context) error {
// expects the output of `wsh version` which looks like `wsh v0.10.4` or "not-installed"
// returns (up-to-date, semver, error)
// if not up to date, or error, version might be ""
func isWshVersionUpToDate(wshVersionLine string) (bool, string, error) {
func IsWshVersionUpToDate(wshVersionLine string) (bool, string, error) {
wshVersionLine = strings.TrimSpace(wshVersionLine)
if wshVersionLine == "not-installed" {
return false, "", nil
Expand Down Expand Up @@ -290,7 +290,7 @@ func (conn *SSHConn) StartConnServer(ctx context.Context) (bool, string, error)
return false, "", fmt.Errorf("error reading wsh version: %w", err)
}
conn.Infof(ctx, "got connserver version: %s\n", strings.TrimSpace(versionLine))
isUpToDate, clientVersion, err := isWshVersionUpToDate(versionLine)
isUpToDate, clientVersion, err := IsWshVersionUpToDate(versionLine)
if err != nil {
sshSession.Close()
return false, "", fmt.Errorf("error checking wsh version: %w", err)
Expand Down Expand Up @@ -377,6 +377,22 @@ to ensure a seamless experience.
Would you like to install them?
`)

func (conn *SSHConn) UpdateWsh(ctx context.Context, clientDisplayName string, remoteInfo *wshrpc.RemoteInfo) error {
conn.Infof(ctx, "attempting to update wsh for connection %s (os:%s arch:%s version:%s)\n",
remoteInfo.ConnName, remoteInfo.ClientOs, remoteInfo.ClientArch, remoteInfo.ClientVersion)
client := conn.GetClient()
if client == nil {
return fmt.Errorf("cannot update wsh: ssh client is not connected")
}
err := remote.CpWshToRemote(ctx, client, remoteInfo.ClientOs, remoteInfo.ClientArch)
if err != nil {
return fmt.Errorf("error installing wsh to remote: %w", err)
}
conn.Infof(ctx, "successfully updated wsh on %s\n", conn.GetName())
return nil

}

// returns (allowed, error)
func (conn *SSHConn) getPermissionToInstallWsh(ctx context.Context, clientDisplayName string) (bool, error) {
conn.Infof(ctx, "running getPermissionToInstallWsh...\n")
Expand Down
6 changes: 6 additions & 0 deletions pkg/wshrpc/wshclient/wshclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,12 @@ func ConnStatusCommand(w *wshutil.WshRpc, opts *wshrpc.RpcOpts) ([]wshrpc.ConnSt
return resp, err
}

// command "connupdatewsh", wshserver.ConnUpdateWshCommand
func ConnUpdateWshCommand(w *wshutil.WshRpc, data wshrpc.RemoteInfo, opts *wshrpc.RpcOpts) (bool, error) {
resp, err := sendRpcRequestCallHelper[bool](w, "connupdatewsh", data, opts)
return resp, err
}

// command "controllerappendoutput", wshserver.ControllerAppendOutputCommand
func ControllerAppendOutputCommand(w *wshutil.WshRpc, data wshrpc.CommandControllerAppendOutputData, opts *wshrpc.RpcOpts) error {
_, err := sendRpcRequestCallHelper[any](w, "controllerappendoutput", data, opts)
Expand Down
10 changes: 10 additions & 0 deletions pkg/wshrpc/wshrpctypes.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ const (
Command_WslList = "wsllist"
Command_WslDefaultDistro = "wsldefaultdistro"
Command_DismissWshFail = "dismisswshfail"
Command_ConnUpdateWsh = "updatewsh"

Command_WorkspaceList = "workspacelist"

Expand Down Expand Up @@ -163,6 +164,7 @@ type WshRpcInterface interface {
WslListCommand(ctx context.Context) ([]string, error)
WslDefaultDistroCommand(ctx context.Context) (string, error)
DismissWshFailCommand(ctx context.Context, connName string) error
ConnUpdateWshCommand(ctx context.Context, remoteInfo RemoteInfo) (bool, error)

// eventrecv is special, it's handled internally by WshRpc with EventListener
EventRecvCommand(ctx context.Context, data wps.WaveEvent) error
Expand Down Expand Up @@ -500,6 +502,14 @@ type ConnRequest struct {
LogBlockId string `json:"logblockid,omitempty"`
}

type RemoteInfo struct {
ConnName string `json:"host"`
ClientArch string `json:"clientarch"`
ClientOs string `json:"clientos"`
ClientVersion string `json:"clientversion"`
Shell string `json:"shell"`
}

const (
TimeSeries_Cpu = "cpu"
)
Expand Down
39 changes: 39 additions & 0 deletions pkg/wshrpc/wshserver/wshserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -699,6 +699,45 @@ func (ws *WshServer) ConnReinstallWshCommand(ctx context.Context, data wshrpc.Co
return conn.InstallWsh(ctx)
}

func (ws *WshServer) ConnUpdateWshCommand(ctx context.Context, remoteInfo wshrpc.RemoteInfo) (bool, error) {
connName := remoteInfo.ConnName
if connName == "" {
return false, fmt.Errorf("invalid remote info: missing connection name")
}

log.Printf("checking wsh version for connection %s (current: %s)", connName, remoteInfo.ClientVersion)
upToDate, _, err := conncontroller.IsWshVersionUpToDate(remoteInfo.ClientVersion)
if err != nil {
return false, fmt.Errorf("unable to compare wsh version: %w", err)
}
Comment on lines +709 to +712
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Incorrect usage of IsWshVersionUpToDate function

The IsWshVersionUpToDate function expects a version line in the format "wsh v0.10.4" or "not-installed", but remoteInfo.ClientVersion likely contains only the version string (e.g., "v0.10.4"). Passing remoteInfo.ClientVersion directly may cause incorrect parsing and version comparison.

Apply this diff to adjust the version string before calling IsWshVersionUpToDate:

     log.Printf("checking wsh version for connection %s (current: %s)", connName, remoteInfo.ClientVersion)
-    upToDate, _, err := conncontroller.IsWshVersionUpToDate(remoteInfo.ClientVersion)
+    versionLine := fmt.Sprintf("wsh %s", remoteInfo.ClientVersion)
+    upToDate, _, err := conncontroller.IsWshVersionUpToDate(versionLine)
     if err != nil {
         return false, fmt.Errorf("unable to compare wsh version: %w", err)
     }
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
upToDate, _, err := conncontroller.IsWshVersionUpToDate(remoteInfo.ClientVersion)
if err != nil {
return false, fmt.Errorf("unable to compare wsh version: %w", err)
}
versionLine := fmt.Sprintf("wsh %s", remoteInfo.ClientVersion)
upToDate, _, err := conncontroller.IsWshVersionUpToDate(versionLine)
if err != nil {
return false, fmt.Errorf("unable to compare wsh version: %w", err)
}

if upToDate {
// no need to update
log.Printf("wsh is already up to date for connection %s", connName)
return false, nil
}

// todo: need to add user input code here for validation

if strings.HasPrefix(connName, "wsl://") {
return false, fmt.Errorf("connupdatewshcommand is not supported for wsl connections")
}
connOpts, err := remote.ParseOpts(connName)
if err != nil {
return false, fmt.Errorf("error parsing connection name: %w", err)
}
conn := conncontroller.GetConn(ctx, connOpts, false, &wshrpc.ConnKeywords{})
if conn == nil {
return false, fmt.Errorf("connection not found: %s", connName)
}
err = conn.UpdateWsh(ctx, connName, &remoteInfo)
if err != nil {
return false, fmt.Errorf("wsh update failed for connection %s: %w", connName, err)
}

// todo: need to add code for modifying configs?
return true, nil
}

func (ws *WshServer) ConnListCommand(ctx context.Context) ([]string, error) {
return conncontroller.GetConnectionsList()
}
Expand Down
12 changes: 12 additions & 0 deletions pkg/wshutil/wshutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"net"
"os"
"os/signal"
"runtime"
"sync"
"sync/atomic"
"syscall"
Expand Down Expand Up @@ -536,3 +537,14 @@ func ExtractUnverifiedSocketName(tokenStr string) (string, error) {
sockName = wavebase.ExpandHomeDirSafe(sockName)
return sockName, nil
}

func GetInfo(rpcContext wshrpc.RpcContext) wshrpc.RemoteInfo {
return wshrpc.RemoteInfo{
ConnName: rpcContext.Conn,
ClientArch: runtime.GOARCH,
ClientOs: runtime.GOOS,
ClientVersion: wavebase.WaveVersion,
Shell: os.Getenv("SHELL"),
}

}
Loading