From 5b28824235b378dc7cb675f3547efddc500bc531 Mon Sep 17 00:00:00 2001 From: Natalie Gaston Date: Tue, 3 Mar 2026 12:39:02 -0800 Subject: [PATCH] fix(kvm): close browser websocket immediately on AMT disconnect --- internal/usecase/devices/interceptor.go | 10 ++- .../devices/interceptor_private_test.go | 88 +++++++++++++++++++ 2 files changed, 96 insertions(+), 2 deletions(-) diff --git a/internal/usecase/devices/interceptor.go b/internal/usecase/devices/interceptor.go index faa387ddb..74c12d485 100644 --- a/internal/usecase/devices/interceptor.go +++ b/internal/usecase/devices/interceptor.go @@ -215,7 +215,13 @@ func (uc *UseCase) ListenToDevice(deviceConnection *DeviceConnection) { conn := deviceConnection.Conn defer func() { - // Clean up on exit + // Notify the browser immediately so the UI updates without waiting for + // ListenToBrowser to unblock on its ReadMessage call. + _ = deviceConnection.Conn.WriteMessage( + websocket.CloseMessage, + websocket.FormatCloseMessage(websocket.CloseNormalClosure, "AMT session ended"), + ) + _ = deviceConnection.Conn.Close() deviceConnection.cancel() }() @@ -650,7 +656,7 @@ func writeLength(buf *bytes.Buffer, challenge *client.AuthChallenge, response st return ErrLengthLimit // If total length is too large, throws an error and stops here } - length := uint32(totalLength) //nolint:gosec // Ignore potential integer overflow here as overflow is validated earlier in code + length := uint32(totalLength) // overflow validated above return binary.Write(buf, binary.LittleEndian, length) } diff --git a/internal/usecase/devices/interceptor_private_test.go b/internal/usecase/devices/interceptor_private_test.go index eded32989..166fe4baa 100644 --- a/internal/usecase/devices/interceptor_private_test.go +++ b/internal/usecase/devices/interceptor_private_test.go @@ -2,15 +2,22 @@ package devices import ( "bytes" + "context" "errors" "io" "math" "strings" "testing" + "time" + "github.com/gorilla/websocket" "github.com/stretchr/testify/require" + "github.com/device-management-toolkit/go-wsman-messages/v2/pkg/wsman" "github.com/device-management-toolkit/go-wsman-messages/v2/pkg/wsman/client" + + "github.com/device-management-toolkit/console/internal/entity" + "github.com/device-management-toolkit/console/pkg/logger" ) func TestProcessBrowserData(t *testing.T) { @@ -834,3 +841,84 @@ func TestRandomValueHexErrorCase(t *testing.T) { require.NoError(t, err) require.Empty(t, result) } + +// spyRedirection is a minimal Redirection stub whose RedirectListen returns +// an error to simulate AMT dropping the TCP connection. +type spyRedirection struct { + listenErr error +} + +func (s *spyRedirection) SetupWsmanClient(_ entity.Device, _, _ bool) (wsman.Messages, error) { + return wsman.Messages{}, nil +} + +func (s *spyRedirection) RedirectConnect(_ context.Context, _ *DeviceConnection) error { return nil } +func (s *spyRedirection) RedirectClose(_ context.Context, _ *DeviceConnection) error { return nil } +func (s *spyRedirection) RedirectSend(_ context.Context, _ *DeviceConnection, _ []byte) error { + return nil +} + +func (s *spyRedirection) RedirectListen(_ context.Context, _ *DeviceConnection) ([]byte, error) { + return nil, s.listenErr +} + +// spyWebSocketConn is a minimal spy that records WriteMessage and Close calls +// without importing the mocks package (which would create an import cycle for +// the internal test package). +type spyWebSocketConn struct { + writeMessageCalled bool + writeMessageType int + closeCalled bool +} + +func (s *spyWebSocketConn) WriteMessage(messageType int, _ []byte) error { + s.writeMessageCalled = true + s.writeMessageType = messageType + + return nil +} + +func (s *spyWebSocketConn) ReadMessage() (messageType int, p []byte, err error) { + return 0, nil, errors.New("spy: not reading") +} + +func (s *spyWebSocketConn) Close() error { + s.closeCalled = true + + return nil +} + +func TestListenToDeviceClosesWebSocketOnAMTDisconnect(t *testing.T) { + t.Parallel() + + spy := &spyWebSocketConn{} + ctx, cancel := context.WithCancel(context.Background()) + + deviceConnection := &DeviceConnection{ + Conn: spy, + Mode: "kvm", + Device: entityDevice(), + ctx: ctx, + cancel: cancel, + healthTicker: time.NewTicker(HeartbeatInterval), + } + + uc := &UseCase{ + redirection: &spyRedirection{listenErr: errors.New("connection reset by peer")}, + redirConnections: make(map[string]*DeviceConnection), + log: logger.New("silent"), + } + + uc.ListenToDevice(deviceConnection) + + require.True(t, spy.writeMessageCalled, "expected WriteMessage(CloseMessage) to be called on browser WebSocket") + require.Equal(t, websocket.CloseMessage, spy.writeMessageType) + require.True(t, spy.closeCalled, "expected Close() to be called on browser WebSocket") +} + +func entityDevice() entity.Device { + return entity.Device{ + GUID: "test-guid", + Username: "admin", + } +}