diff --git a/go/stream.go b/go/stream.go index ecbe347..6bccb13 100644 --- a/go/stream.go +++ b/go/stream.go @@ -81,6 +81,7 @@ type stream struct { streamCtxCancel context.CancelFunc closeError atomic.Value connStatusCallback func(isConneccted bool, host string, origin string) + connMu sync.Mutex waterMarkMu sync.Mutex waterMark map[string]time.Time @@ -131,18 +132,23 @@ func (c *client) newStream(ctx context.Context, httpClient *http.Client, feedIDs c.config.logInfo("client: failed to connect to origin %s: %s", origins[x], err) errs = append(errs, fmt.Errorf("origin %s: %w", origins[x], err)) // Retry connecting to the origin in the background - go func() { - conn, err := s.newWSconnWithRetry(origins[x]) + go func(st *stream) { + conn, err := st.newWSconnWithRetry(origins[x]) if err != nil { return } - go s.monitorConn(conn) - s.conns = append(s.conns, conn) - }() + go st.monitorConn(conn) + st.connMu.Lock() + st.conns = append(st.conns, conn) + st.connMu.Unlock() + }(s) continue + } else { + s.connMu.Lock() + s.conns = append(s.conns, conn) + s.connMu.Unlock() + go s.monitorConn(conn) } - go s.monitorConn(conn) - s.conns = append(s.conns, conn) } // Only fail if we couldn't connect to ANY origins @@ -159,7 +165,9 @@ func (c *client) newStream(ctx context.Context, httpClient *http.Client, feedIDs return nil, err } go s.monitorConn(conn) + s.connMu.Lock() s.conns = append(s.conns, conn) + s.connMu.Unlock() s.stats.configuredConnections.Add(1) } @@ -220,12 +228,11 @@ func (s *stream) monitorConn(conn *wsConn) { cancel() // `Add(^uint64(0))` will decrement activeConnections s.stats.activeConnections.Add(^uint64(0)) - if s.connStatusCallback != nil { - go s.connStatusCallback(false, conn.host, conn.origin) - } - // check for stream close conditions before reconnect attempts if ctxErr := s.streamCtx.Err(); ctxErr != nil || s.closed.Load() { + if s.connStatusCallback != nil { + s.connStatusCallback(false, conn.host, conn.origin) + } if ctxErr != nil { s.config.logInfo( "client: stream websocket %s context done: %s", @@ -236,6 +243,10 @@ func (s *stream) monitorConn(conn *wsConn) { return } + if s.connStatusCallback != nil { + go s.connStatusCallback(false, conn.host, conn.origin) + } + // reconnect protocol if s.stats.activeConnections.Load() == 0 { s.stats.fullReconnects.Add(1) @@ -345,9 +356,12 @@ func (s *stream) Close() (err error) { s.closingMutex.Lock() defer s.closingMutex.Unlock() + s.connMu.Lock() for x := 0; x < len(s.conns); x++ { _ = s.conns[x].close() } + s.connMu.Unlock() + close(s.output) // return a pending error if err, ok := s.closeError.Load().(error); ok {