Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions pkg/blockcontroller/blockcontroller.go
Original file line number Diff line number Diff line change
Expand Up @@ -388,9 +388,23 @@ func (bc *BlockController) setupAndStartShellProcess(logCtx context.Context, rc
swapToken.Env[wshutil.WaveJwtTokenVarName] = jwtStr
cmdOpts.Env[wshutil.WaveJwtTokenVarName] = jwtStr
}
shellProc, err = shellexec.StartWslShellProc(ctx, rc.TermSize, cmdStr, cmdOpts, wslConn)
if err != nil {
return nil, err
if !wslConn.WshEnabled.Load() {
shellProc, err = shellexec.StartWslShellProcNoWsh(ctx, rc.TermSize, cmdStr, cmdOpts, wslConn)
if err != nil {
return nil, err
}
} else {
shellProc, err = shellexec.StartWslShellProc(ctx, rc.TermSize, cmdStr, cmdOpts, wslConn)
if err != nil {
wslConn.SetWshError(err)
wslConn.WshEnabled.Store(false)
log.Printf("error starting wsl shell proc with wsh: %v", err)
log.Print("attempting install without wsh")
shellProc, err = shellexec.StartWslShellProcNoWsh(ctx, rc.TermSize, cmdStr, cmdOpts, wslConn)
if err != nil {
return nil, err
}
}
}
} else if remoteName != "" {
credentialCtx, cancelFunc := context.WithTimeout(context.Background(), 60*time.Second)
Expand Down
21 changes: 21 additions & 0 deletions pkg/shellexec/shellexec.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,27 @@ func (pp *PipePty) WriteString(s string) (n int, err error) {
return pp.Write([]byte(s))
}

func StartWslShellProcNoWsh(ctx context.Context, termSize waveobj.TermSize, cmdStr string, cmdOpts CommandOptsType, conn *wslconn.WslConn) (*ShellProc, error) {
client := conn.GetClient()
conn.Infof(ctx, "WSL-NEWSESSION (StartWslShellProcNoWsh)")

ecmd := exec.Command("wsl.exe", "~", "-d", client.Name())

if termSize.Rows == 0 || termSize.Cols == 0 {
termSize.Rows = shellutil.DefaultTermRows
termSize.Cols = shellutil.DefaultTermCols
}
if termSize.Rows <= 0 || termSize.Cols <= 0 {
return nil, fmt.Errorf("invalid term size: %v", termSize)
}
cmdPty, err := pty.StartWithSize(ecmd, &pty.Winsize{Rows: uint16(termSize.Rows), Cols: uint16(termSize.Cols)})
if err != nil {
return nil, err
}
cmdWrap := MakeCmdWrap(ecmd, cmdPty)
return &ShellProc{Cmd: cmdWrap, ConnName: conn.GetName(), CloseOnce: &sync.Once{}, DoneCh: make(chan any)}, nil
}

func StartWslShellProc(ctx context.Context, termSize waveobj.TermSize, cmdStr string, cmdOpts CommandOptsType, conn *wslconn.WslConn) (*ShellProc, error) {
client := conn.GetClient()
conn.Infof(ctx, "WSL-NEWSESSION (StartWslShellProc)")
Expand Down
10 changes: 10 additions & 0 deletions pkg/wshrpc/wshserver/wshserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -783,6 +783,16 @@ func (ws *WshServer) WslDefaultDistroCommand(ctx context.Context) (string, error
* Dismisses the WshFail Command in runtime memory on the backend
*/
func (ws *WshServer) DismissWshFailCommand(ctx context.Context, connName string) error {
if strings.HasPrefix(connName, "wsl://") {
distroName := strings.TrimPrefix(connName, "wsl://")
conn := wslconn.GetWslConn(ctx, distroName, false)
if conn == nil {
return fmt.Errorf("connection not found: %s", connName)
}
conn.ClearWshError()
conn.FireConnChangeEvent()
return nil
}
opts, err := remote.ParseOpts(connName)
if err != nil {
return err
Expand Down
8 changes: 7 additions & 1 deletion pkg/wslconn/wslconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,11 +101,14 @@ func (conn *WslConn) DeriveConnStatus() wshrpc.ConnStatus {
return wshrpc.ConnStatus{
Status: conn.Status,
Connected: conn.Status == Status_Connected,
WshEnabled: true, // always use wsh for wsl connections (temporary)
WshEnabled: conn.WshEnabled.Load(),
Connection: conn.GetName(),
HasConnected: (conn.LastConnectTime > 0),
ActiveConnNum: conn.ActiveConnNum,
Error: conn.Error,
WshError: conn.WshError,
NoWshReason: conn.NoWshReason,
WshVersion: conn.WshVersion,
}
}

Expand Down Expand Up @@ -702,6 +705,9 @@ func (conn *WslConn) waitForDisconnect() {
log.Printf("wait for disconnect in %+#v", conn)
defer conn.FireConnChangeEvent()
defer conn.HasWaiter.Store(false)
if conn.ConnController == nil {
return
}
err := conn.ConnController.Wait()
conn.WithLock(func() {
// disconnects happen for a variety of reasons (like network, etc. and are typically transient)
Expand Down
Loading