@@ -29,21 +29,26 @@ func AcceptSubprotocols(protocols ...string) AcceptOption {
2929 return acceptSubprotocols (protocols )
3030}
3131
32- type acceptOrigins [] string
32+ type acceptInsecureOrigin struct {}
3333
34- func (o acceptOrigins ) acceptOption () {}
34+ func (o acceptInsecureOrigin ) acceptOption () {}
3535
36- // AcceptOrigins lists the origins that Accept will accept.
37- // Accept will always accept r.Host as the origin. Use this
38- // option when you want to accept an origin with a different domain
39- // than the one the WebSocket server is running on.
36+ // AcceptInsecureOrigin disables Accept's origin verification
37+ // behaviour. By default Accept only allows the handshake to
38+ // succeed if the javascript that is initiating the handshake
39+ // is on the same domain as the server. This is to prevent CSRF
40+ // when secure data is stored in cookies.
4041//
41- // Use this option with caution to avoid exposing your WebSocket
42- // server to a CSRF attack.
4342// See https://stackoverflow.com/a/37837709/4283659
44- // TODO remove in favour of AcceptInsecureOrigin
45- func AcceptOrigins (origins ... string ) AcceptOption {
46- return acceptOrigins (origins )
43+ //
44+ // Use this if you want a WebSocket server any javascript can
45+ // connect to or you want to perform Origin verification yourself
46+ // and allow some whitelist of domains.
47+ //
48+ // Ensure you understand exactly what the above means before you use
49+ // this option in conjugation with cookies containing secure data.
50+ func AcceptInsecureOrigin () AcceptOption {
51+ return acceptInsecureOrigin {}
4752}
4853
4954func verifyClientRequest (w http.ResponseWriter , r * http.Request ) error {
@@ -87,11 +92,11 @@ func verifyClientRequest(w http.ResponseWriter, r *http.Request) error {
8792// Accept uses w to write the handshake response so the timeouts on the http.Server apply.
8893func Accept (w http.ResponseWriter , r * http.Request , opts ... AcceptOption ) (* Conn , error ) {
8994 var subprotocols []string
90- origins := [] string { r . Host }
95+ verifyOrigin := true
9196 for _ , opt := range opts {
9297 switch opt := opt .(type ) {
93- case acceptOrigins :
94- origins = [] string ( opt )
98+ case acceptInsecureOrigin :
99+ verifyOrigin = false
95100 case acceptSubprotocols :
96101 subprotocols = []string (opt )
97102 }
@@ -102,12 +107,12 @@ func Accept(w http.ResponseWriter, r *http.Request, opts ...AcceptOption) (*Conn
102107 return nil , err
103108 }
104109
105- origins = append ( origins , r . Host )
106-
107- err = authenticateOrigin ( r , origins )
108- if err != nil {
109- http . Error ( w , err . Error (), http . StatusForbidden )
110- return nil , err
110+ if verifyOrigin {
111+ err = authenticateOrigin ( r )
112+ if err != nil {
113+ http . Error ( w , err . Error (), http . StatusForbidden )
114+ return nil , err
115+ }
111116 }
112117
113118 hj , ok := w .(http.Hijacker )
@@ -173,7 +178,7 @@ func handleKey(w http.ResponseWriter, r *http.Request) {
173178 w .Header ().Set ("Sec-WebSocket-Accept" , responseKey )
174179}
175180
176- func authenticateOrigin (r * http.Request , origins [] string ) error {
181+ func authenticateOrigin (r * http.Request ) error {
177182 origin := r .Header .Get ("Origin" )
178183 if origin == "" {
179184 return nil
@@ -182,10 +187,8 @@ func authenticateOrigin(r *http.Request, origins []string) error {
182187 if err != nil {
183188 return xerrors .Errorf ("failed to parse Origin header %q: %w" , origin , err )
184189 }
185- for _ , o := range origins {
186- if strings .EqualFold (u .Host , o ) {
187- return nil
188- }
190+ if strings .EqualFold (u .Host , r .Host ) {
191+ return nil
189192 }
190- return xerrors .Errorf ("request origin %q is not authorized" , r . Header . Get ( "Origin" ) )
193+ return xerrors .Errorf ("request origin %q is not authorized" , origin )
191194}
0 commit comments