@@ -49,13 +49,6 @@ type Conn struct {
4949 readDone chan int
5050}
5151
52- func (c * Conn ) getCloseErr () error {
53- if c .closeErr != nil {
54- return c .closeErr
55- }
56- return nil
57- }
58-
5952func (c * Conn ) close (err error ) {
6053 if err != nil {
6154 err = xerrors .Errorf ("websocket: connection broken: %w" , err )
@@ -160,8 +153,12 @@ messageLoop:
160153 masked : c .client ,
161154 }
162155 c .writeFrame (h , control .payload )
163- c .writeDone <- struct {}{}
164- continue
156+ select {
157+ case <- c .closed :
158+ return
159+ case c .writeDone <- struct {}{}:
160+ continue
161+ }
165162 case b , ok := <- c .writeBytes :
166163 h := header {
167164 fin : ! ok ,
@@ -349,14 +346,14 @@ func (c *Conn) Close(code StatusCode, reason string) error {
349346 p , _ = closePayload (StatusInternalError , fmt .Sprintf ("websocket: application tried to send code %v but code or reason was invalid" , code ))
350347 }
351348
352- err2 := c .writeClose (p , CloseError {
349+ cerr := c .writeClose (p , CloseError {
353350 Code : code ,
354351 Reason : reason ,
355352 })
356353 if err != nil {
357354 return err
358355 }
359- return err2
356+ return cerr
360357}
361358
362359func (c * Conn ) writeClose (p []byte , cerr CloseError ) error {
@@ -381,19 +378,19 @@ func (c *Conn) writeClose(p []byte, cerr CloseError) error {
381378func (c * Conn ) writeControl (ctx context.Context , opcode opcode , p []byte ) error {
382379 select {
383380 case <- c .closed :
384- return c .getCloseErr ()
381+ return c .closeErr
385382 case c .control <- control {
386383 opcode : opcode ,
387384 payload : p ,
388385 }:
389386 case <- ctx .Done ():
390387 c .close (xerrors .New ("force closed: close frame write timed out" ))
391- return c .getCloseErr ()
388+ return c .closeErr
392389 }
393390
394391 select {
395392 case <- c .closed :
396- return c .getCloseErr ()
393+ return c .closeErr
397394 case <- c .writeDone :
398395 return nil
399396 case <- ctx .Done ():
@@ -420,34 +417,25 @@ type messageWriter struct {
420417 ctx context.Context
421418 c * Conn
422419 acquiredLock bool
423- sentFirst bool
424-
425- done chan struct {}
426420}
427421
428422// Write writes the given bytes to the WebSocket connection.
429423// The frame will automatically be fragmented as appropriate
430424// with the buffers obtained from http.Hijacker.
431425// Please ensure you call Close once you have written the full message.
432426func (w * messageWriter ) Write (p []byte ) (int , error ) {
433- if ! w .acquiredLock {
434- select {
435- case <- w .c .closed :
436- return 0 , w .c .getCloseErr ()
437- case w .c .write <- w .datatype :
438- w .acquiredLock = true
439- case <- w .ctx .Done ():
440- return 0 , w .ctx .Err ()
441- }
427+ err := w .acquire ()
428+ if err != nil {
429+ return 0 , err
442430 }
443431
444432 select {
445433 case <- w .c .closed :
446- return 0 , w .c .getCloseErr ()
434+ return 0 , w .c .closeErr
447435 case w .c .writeBytes <- p :
448436 select {
449437 case <- w .c .closed :
450- return 0 , w .c .getCloseErr ()
438+ return 0 , w .c .closeErr
451439 case <- w .c .writeDone :
452440 return len (p ), nil
453441 case <- w .ctx .Done ():
@@ -458,23 +446,32 @@ func (w *messageWriter) Write(p []byte) (int, error) {
458446 }
459447}
460448
461- // Close flushes the frame to the connection.
462- // This must be called for every messageWriter.
463- func (w * messageWriter ) Close () error {
449+ func (w * messageWriter ) acquire () error {
464450 if ! w .acquiredLock {
465451 select {
466452 case <- w .c .closed :
467- return w .c .getCloseErr ()
453+ return w .c .closeErr
468454 case w .c .write <- w .datatype :
469455 w .acquiredLock = true
470456 case <- w .ctx .Done ():
471457 return w .ctx .Err ()
472458 }
473459 }
460+ return nil
461+ }
462+
463+ // Close flushes the frame to the connection.
464+ // This must be called for every messageWriter.
465+ func (w * messageWriter ) Close () error {
466+ err := w .acquire ()
467+ if err != nil {
468+ return err
469+ }
470+
474471 close (w .c .writeBytes )
475472 select {
476473 case <- w .c .closed :
477- return w .c .getCloseErr ()
474+ return w .c .closeErr
478475 case <- w .ctx .Done ():
479476 return w .ctx .Err ()
480477 case <- w .c .writeDone :
@@ -490,7 +487,7 @@ func (w *messageWriter) Close() error {
490487func (c * Conn ) Read (ctx context.Context ) (DataType , io.Reader , error ) {
491488 select {
492489 case <- c .closed :
493- return 0 , nil , xerrors .Errorf ("failed to read message: %w" , c .getCloseErr () )
490+ return 0 , nil , xerrors .Errorf ("failed to read message: %w" , c .closeErr )
494491 case opcode := <- c .read :
495492 return DataType (opcode ), & messageReader {
496493 ctx : ctx ,
@@ -507,24 +504,17 @@ type messageReader struct {
507504 c * Conn
508505}
509506
510- // SetContext bounds the read operation to the ctx.
511- // By default, the context is the one passed to conn.ReadMessage.
512- // You still almost always want a separate context for reading the message though.
513- func (r * messageReader ) SetContext (ctx context.Context ) {
514- r .ctx = ctx
515- }
516-
517507// Read reads as many bytes as possible into p.
518508func (r * messageReader ) Read (p []byte ) (n int , err error ) {
519509 select {
520510 case <- r .c .closed :
521- return 0 , r .c .getCloseErr ()
511+ return 0 , r .c .closeErr
522512 case <- r .c .readDone :
523513 return 0 , io .EOF
524514 case r .c .readBytes <- p :
525515 select {
526516 case <- r .c .closed :
527- return 0 , r .c .getCloseErr ()
517+ return 0 , r .c .closeErr
528518 case n := <- r .c .readDone :
529519 return n , nil
530520 case <- r .ctx .Done ():
0 commit comments