44 "crypto/sha1"
55 "encoding/base64"
66 "net/http"
7+ "net/textproto"
78 "net/url"
89 "strings"
910
@@ -45,56 +46,65 @@ func AcceptOrigins(origins ...string) AcceptOption {
4546 return acceptOrigins (origins )
4647}
4748
48- // Accept accepts a WebSocket handshake from a client and upgrades the
49- // the connection to WebSocket.
50- // Accept will reject the handshake if the Origin is not the same as the Host unless
51- // InsecureAcceptOrigin is passed.
52- // Accept uses w to write the handshake response so the timeouts on the http.Server apply.
53- func Accept (w http.ResponseWriter , r * http.Request , opts ... AcceptOption ) (* Conn , error ) {
54- var subprotocols []string
55- origins := []string {r .Host }
56- for _ , opt := range opts {
57- switch opt := opt .(type ) {
58- case acceptOrigins :
59- origins = []string (opt )
60- case acceptSubprotocols :
61- subprotocols = []string (opt )
62- }
63- }
64-
65- if ! httpguts .HeaderValuesContainsToken (r .Header ["Connection" ], "Upgrade" ) {
49+ func verifyClientRequest (w http.ResponseWriter , r * http.Request ) error {
50+ if ! headerValuesContainsToken (r .Header , "Connection" , "Upgrade" ) {
6651 err := xerrors .Errorf ("websocket: protocol violation: Connection header does not contain Upgrade: %q" , r .Header .Get ("Connection" ))
6752 http .Error (w , err .Error (), http .StatusBadRequest )
68- return nil , err
53+ return err
6954 }
7055
71- if ! httpguts . HeaderValuesContainsToken (r .Header [ "Upgrade" ] , "websocket " ) {
56+ if ! headerValuesContainsToken (r .Header , "Upgrade" , "WebSocket " ) {
7257 err := xerrors .Errorf ("websocket: protocol violation: Upgrade header does not contain websocket: %q" , r .Header .Get ("Upgrade" ))
7358 http .Error (w , err .Error (), http .StatusBadRequest )
74- return nil , err
59+ return err
7560 }
7661
7762 if r .Method != "GET" {
7863 err := xerrors .Errorf ("websocket: protocol violation: handshake request method is not GET: %q" , r .Method )
7964 http .Error (w , err .Error (), http .StatusBadRequest )
80- return nil , err
65+ return err
8166 }
8267
8368 if r .Header .Get ("Sec-WebSocket-Version" ) != "13" {
8469 err := xerrors .Errorf ("websocket: unsupported protocol version: %q" , r .Header .Get ("Sec-WebSocket-Version" ))
8570 http .Error (w , err .Error (), http .StatusBadRequest )
86- return nil , err
71+ return err
8772 }
8873
8974 if r .Header .Get ("Sec-WebSocket-Key" ) == "" {
9075 err := xerrors .New ("websocket: protocol violation: missing Sec-WebSocket-Key" )
9176 http .Error (w , err .Error (), http .StatusBadRequest )
77+ return err
78+ }
79+
80+ return nil
81+ }
82+
83+ // Accept accepts a WebSocket handshake from a client and upgrades the
84+ // the connection to WebSocket.
85+ // Accept will reject the handshake if the Origin is not the same as the Host unless
86+ // InsecureAcceptOrigin is passed.
87+ // Accept uses w to write the handshake response so the timeouts on the http.Server apply.
88+ func Accept (w http.ResponseWriter , r * http.Request , opts ... AcceptOption ) (* Conn , error ) {
89+ var subprotocols []string
90+ origins := []string {r .Host }
91+ for _ , opt := range opts {
92+ switch opt := opt .(type ) {
93+ case acceptOrigins :
94+ origins = []string (opt )
95+ case acceptSubprotocols :
96+ subprotocols = []string (opt )
97+ }
98+ }
99+
100+ err := verifyClientRequest (w , r )
101+ if err != nil {
92102 return nil , err
93103 }
94104
95105 origins = append (origins , r .Host )
96106
97- err : = authenticateOrigin (r , origins )
107+ err = authenticateOrigin (r , origins )
98108 if err != nil {
99109 http .Error (w , err .Error (), http .StatusForbidden )
100110 return nil , err
@@ -112,7 +122,10 @@ func Accept(w http.ResponseWriter, r *http.Request, opts ...AcceptOption) (*Conn
112122
113123 handleKey (w , r )
114124
115- selectSubprotocol (w , r , subprotocols )
125+ subproto := selectSubprotocol (r , subprotocols )
126+ if subproto != "" {
127+ w .Header ().Set ("Sec-WebSocket-Protocol" , subproto )
128+ }
116129
117130 w .WriteHeader (http .StatusSwitchingProtocols )
118131
@@ -134,16 +147,18 @@ func Accept(w http.ResponseWriter, r *http.Request, opts ...AcceptOption) (*Conn
134147 return c , nil
135148}
136149
137- func selectSubprotocol (w http.ResponseWriter , r * http.Request , subprotocols []string ) {
138- clientSubprotocols := strings .Split (r .Header .Get ("Sec-WebSocket-Protocol" ), "," )
150+ func headerValuesContainsToken (h http.Header , key , val string ) bool {
151+ key = textproto .CanonicalMIMEHeaderKey (key )
152+ return httpguts .HeaderValuesContainsToken (h [key ], val )
153+ }
154+
155+ func selectSubprotocol (r * http.Request , subprotocols []string ) string {
139156 for _ , sp := range subprotocols {
140- for _ , cp := range clientSubprotocols {
141- if sp == strings .TrimSpace (cp ) {
142- w .Header ().Set ("Sec-WebSocket-Protocol" , sp )
143- return
144- }
157+ if headerValuesContainsToken (r .Header , "Sec-WebSocket-Protocol" , sp ) {
158+ return sp
145159 }
146160 }
161+ return ""
147162}
148163
149164var keyGUID = []byte ("258EAFA5-E914-47DA-95CA-C5AB0DC85B11" )
0 commit comments