Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 27 additions & 18 deletions v2/common/websocket/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ func (c *client) wait(timeout time.Duration) {

// handleReconnect waits for reconnect signal and starts reconnect
func (c *client) handleReconnect() {
for _ = range c.reconnectSignal {
for range c.reconnectSignal {
c.debug("reconnect: received signal")

b := &backoff.Backoff{
Expand Down Expand Up @@ -337,13 +337,17 @@ func NewConnection(
return nil, err
}

ctx, cancel := context.WithCancel(context.Background())

wsConn := &connection{
conn: underlyingWsConn,
connectionMu: sync.Mutex{},
lastResponseMu: sync.Mutex{},
initUnderlyingWsConnFn: initUnderlyingWsConnFn,
keepaliveTimeout: keepaliveTimeout,
isKeepAliveNeeded: isKeepAliveNeeded,
ctx: ctx,
cancel: cancel,
}

if isKeepAliveNeeded {
Expand All @@ -362,6 +366,8 @@ type connection struct {
initUnderlyingWsConnFn func() (*websocket.Conn, error)
keepaliveTimeout time.Duration
isKeepAliveNeeded bool
ctx context.Context
cancel context.CancelFunc
}

type Connection interface {
Expand All @@ -379,7 +385,11 @@ func (c *connection) WriteMessage(messageType int, data []byte) error {

// ReadMessage wrapper for conn.ReadMessage
func (c *connection) ReadMessage() (int, []byte, error) {
return c.conn.ReadMessage()
msgType, msg, err := c.conn.ReadMessage()
if err != nil {
c.cancel()
}
return msgType, msg, err
}

// RestoreConnection recreates ws connection with the same underlying connection callback and keepalive timeout
Expand All @@ -389,8 +399,6 @@ func (c *connection) RestoreConnection() (Connection, error) {

// keepAlive handles ping-pong for connection
func (c *connection) keepAlive(timeout time.Duration) {
ticker := time.NewTicker(timeout)

c.updateLastResponse()

c.conn.SetPongHandler(func(msg string) error {
Expand All @@ -399,17 +407,23 @@ func (c *connection) keepAlive(timeout time.Duration) {
})

go func() {
ticker := time.NewTicker(timeout)
defer ticker.Stop()
for {
err := c.ping()
if err != nil {
return
}

<-ticker.C
if c.isLastResponseOutdated(timeout) {
c.close()
for {
select {
case <-c.ctx.Done():
return
case <-ticker.C:
err := c.ping()
if err != nil {
return
}

if c.isLastResponseOutdated(timeout) {
c.close()
return
}
}
}
}()
Expand Down Expand Up @@ -442,10 +456,5 @@ func (c *connection) ping() error {
defer c.connectionMu.Unlock()

deadline := time.Now().Add(KeepAlivePingDeadline)
err := c.conn.WriteControl(websocket.PingMessage, []byte{}, deadline)
if err != nil {
return err
}

return nil
return c.conn.WriteControl(websocket.PingMessage, []byte{}, deadline)
}
18 changes: 11 additions & 7 deletions v2/delivery/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ var wsServe = func(cfg *WsConfig, handler WsHandler, errHandler ErrHandler) (don
// closed by the client.
defer close(doneC)
if WebsocketKeepalive {
keepAlive(c, WebsocketTimeout)
keepAlive(doneC, c, WebsocketTimeout)
}
// Wait for the stopC channel to be closed. We do that in a
// separate goroutine because ReadMessage is a blocking
Expand Down Expand Up @@ -83,9 +83,7 @@ var wsServe = func(cfg *WsConfig, handler WsHandler, errHandler ErrHandler) (don
return
}

func keepAlive(c *websocket.Conn, timeout time.Duration) {
ticker := time.NewTicker(timeout)

func keepAlive(done chan struct{}, c *websocket.Conn, timeout time.Duration) {
lastResponse := time.Now()

c.SetPingHandler(func(pingData string) error {
Expand All @@ -105,12 +103,18 @@ func keepAlive(c *websocket.Conn, timeout time.Duration) {
})

go func() {
ticker := time.NewTicker(timeout)
defer ticker.Stop()

for {
<-ticker.C
if time.Since(lastResponse) > timeout {
c.Close()
select {
case <-done:
return
case <-ticker.C:
if time.Since(lastResponse) > timeout {
c.Close()
return
}
}
}
}()
Expand Down
18 changes: 11 additions & 7 deletions v2/futures/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ var wsServe = func(cfg *WsConfig, handler WsHandler, errHandler ErrHandler) (don
// closed by the client.
defer close(doneC)
if WebsocketKeepalive {
keepAlive(c, WebsocketTimeout)
keepAlive(doneC, c, WebsocketTimeout)
}
// Wait for the stopC channel to be closed. We do that in a
// separate goroutine because ReadMessage is a blocking
Expand Down Expand Up @@ -83,9 +83,7 @@ var wsServe = func(cfg *WsConfig, handler WsHandler, errHandler ErrHandler) (don
return
}

func keepAlive(c *websocket.Conn, timeout time.Duration) {
ticker := time.NewTicker(timeout)

func keepAlive(done chan struct{}, c *websocket.Conn, timeout time.Duration) {
lastResponse := time.Now()

c.SetPingHandler(func(pingData string) error {
Expand All @@ -105,12 +103,18 @@ func keepAlive(c *websocket.Conn, timeout time.Duration) {
})

go func() {
ticker := time.NewTicker(timeout)
defer ticker.Stop()

for {
<-ticker.C
if time.Since(lastResponse) > timeout {
c.Close()
select {
case <-done:
return
case <-ticker.C:
if time.Since(lastResponse) > timeout {
c.Close()
return
}
}
}
}()
Expand Down
18 changes: 11 additions & 7 deletions v2/options/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ var wsServe = func(cfg *WsConfig, handler WsHandler, errHandler ErrHandler) (don
// closed by the client.
defer close(doneC)
if WebsocketKeepalive {
keepAlive(c, WebsocketTimeout)
keepAlive(doneC, c, WebsocketTimeout)
}
// Wait for the stopC channel to be closed. We do that in a
// separate goroutine because ReadMessage is a blocking
Expand Down Expand Up @@ -84,9 +84,7 @@ var wsServe = func(cfg *WsConfig, handler WsHandler, errHandler ErrHandler) (don
return
}

func keepAlive(c *websocket.Conn, timeout time.Duration) {
ticker := time.NewTicker(timeout)

func keepAlive(done chan struct{}, c *websocket.Conn, timeout time.Duration) {
lastResponse := time.Now()

c.SetPingHandler(func(pingData string) error {
Expand All @@ -106,12 +104,18 @@ func keepAlive(c *websocket.Conn, timeout time.Duration) {
})

go func() {
ticker := time.NewTicker(timeout)
defer ticker.Stop()

for {
<-ticker.C
if time.Since(lastResponse) > timeout {
c.Close()
select {
case <-done:
return
case <-ticker.C:
if time.Since(lastResponse) > timeout {
c.Close()
return
}
}
}
}()
Expand Down
18 changes: 11 additions & 7 deletions v2/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ var wsServe = func(cfg *WsConfig, handler WsHandler, errHandler ErrHandler) (don
if WebsocketKeepalive {
// This function overwrites the default ping frame handler
// sent by the websocket API server
keepAlive(c, WebsocketTimeout)
keepAlive(doneC, c, WebsocketTimeout)
}

// Wait for the stopC channel to be closed. We do that in a
Expand Down Expand Up @@ -87,9 +87,7 @@ var wsServe = func(cfg *WsConfig, handler WsHandler, errHandler ErrHandler) (don
return
}

func keepAlive(c *websocket.Conn, timeout time.Duration) {
ticker := time.NewTicker(timeout)

func keepAlive(done chan struct{}, c *websocket.Conn, timeout time.Duration) {
lastResponse := time.Now()

c.SetPingHandler(func(pingData string) error {
Expand All @@ -109,12 +107,18 @@ func keepAlive(c *websocket.Conn, timeout time.Duration) {
})

go func() {
ticker := time.NewTicker(timeout)
defer ticker.Stop()

for {
<-ticker.C
if time.Since(lastResponse) > timeout {
c.Close()
select {
case <-done:
return
case <-ticker.C:
if time.Since(lastResponse) > timeout {
c.Close()
return
}
}
}
}()
Expand Down
Loading