Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 19 additions & 19 deletions accept.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,12 @@ type AcceptOptions struct {
// reject it, close the connection when c.Subprotocol() == "".
Subprotocols []string

// Protocol selects which HTTP version to accept. Zero value defaults to
// ProtocolHTTP1. ProtocolAcceptAny allows accepting either HTTP/1.1 or
// HTTPProtocol selects which HTTP version to accept. Zero value defaults to
// HTTPProtocol1. HTTPProtocolAny allows accepting either HTTP/1.1 or
// HTTP/2.
//
// Experimental: This feature is experimental and may change in the future.
Protocol Protocol
HTTPProtocol HTTPProtocol

// InsecureSkipVerify is used to disable Accept's origin verification behavior.
//
Expand Down Expand Up @@ -95,10 +95,10 @@ func (opts *AcceptOptions) cloneWithDefaults() (*AcceptOptions, error) {
}

// Defaults to HTTP/1.1 only to preserve existing behavior (zero value).
switch o.Protocol {
case ProtocolAcceptAny, ProtocolHTTP1, ProtocolHTTP2:
switch o.HTTPProtocol {
case HTTPProtocolAny, HTTPProtocol1, HTTPProtocol2:
default:
return nil, fmt.Errorf("websocket: invalid protocol for accept options: %s", o.Protocol)
return nil, fmt.Errorf("websocket: invalid protocol for accept options: %s", o.HTTPProtocol)
}

return &o, nil
Expand Down Expand Up @@ -148,7 +148,7 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con
}

switch proto {
case ProtocolHTTP2:
case HTTPProtocol2:
// Prepare response headers for H2 (no Connection/Upgrade).
w.Header().Set("Sec-WebSocket-Accept", secWebSocketAccept(key))

Expand Down Expand Up @@ -184,7 +184,7 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con
bw: getBufioWriter(stream),
}), nil

case ProtocolHTTP1:
case HTTPProtocol1:
hj, ok := hijacker(w)
if !ok {
err = errors.New("http.ResponseWriter does not implement http.Hijacker")
Expand Down Expand Up @@ -243,32 +243,32 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con
}
}

func verifyClientRequest(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (proto Protocol, key string, errCode int, err error) {
func verifyClientRequest(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (proto HTTPProtocol, key string, errCode int, err error) {
if r.ProtoMajor == 2 {
switch opts.Protocol {
case ProtocolHTTP1:
return ProtocolHTTP2, "", http.StatusBadRequest, errors.New("HTTP/2 extended CONNECT refused: server only accepts HTTP/1.1 Upgrade")
switch opts.HTTPProtocol {
case HTTPProtocol1:
return HTTPProtocol2, "", http.StatusBadRequest, errors.New("HTTP/2 extended CONNECT refused: server only accepts HTTP/1.1 Upgrade")
}

// HTTP/2 extended CONNECT (RFC 8441) path.
key, errCode, err = verifyClientRequestH2(w, r)
if err != nil {
return ProtocolHTTP2, "", errCode, err
return HTTPProtocol2, "", errCode, err
}
return ProtocolHTTP2, key, 0, nil
return HTTPProtocol2, key, 0, nil
}

switch opts.Protocol {
case ProtocolHTTP2:
return ProtocolHTTP1, "", http.StatusBadRequest, errors.New("HTTP/1.1 Upgrade refused: server requires HTTP/2 extended CONNECT")
switch opts.HTTPProtocol {
case HTTPProtocol2:
return HTTPProtocol1, "", http.StatusBadRequest, errors.New("HTTP/1.1 Upgrade refused: server requires HTTP/2 extended CONNECT")
}

// HTTP/1.1 GET/Upgrade handshake validation.
key, errCode, err = verifyClientRequestH1(w, r)
if err != nil {
return ProtocolHTTP1, "", errCode, err
return HTTPProtocol1, "", errCode, err
}
return ProtocolHTTP1, key, 0, nil
return HTTPProtocol1, key, 0, nil
}

// verifyClientRequestH1 validates an HTTP/1.1 WebSocket GET/Upgrade request.
Expand Down
2 changes: 1 addition & 1 deletion accept_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ func Test_verifyClientHandshake(t *testing.T) {
r.Header.Add(k, v)
}

_, _, _, err := verifyClientRequest(httptest.NewRecorder(), r, &AcceptOptions{Protocol: ProtocolHTTP1})
_, _, _, err := verifyClientRequest(httptest.NewRecorder(), r, &AcceptOptions{HTTPProtocol: HTTPProtocol1})
if tc.success {
assert.Success(t, err)
} else {
Expand Down
32 changes: 18 additions & 14 deletions dial.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,11 @@ type DialOptions struct {
// HTTPHeader specifies the HTTP headers included in the handshake request.
HTTPHeader http.Header

// Protocol selects the HTTP version for the handshake. Zero value defaults
// to ProtocolHTTP1. ProtocolAcceptAny is not supported by Dial.
// HTTPProtocol selects the HTTP version for the handshake. Zero value defaults
// to HTTPProtocol1. HTTPProtocolAny is not supported by Dial.
//
// Experimental: This feature is experimental and may change in the future.
Protocol Protocol
HTTPProtocol HTTPProtocol

// Host optionally overrides the Host HTTP header to send. If empty, the value
// of URL.Host will be used.
Expand Down Expand Up @@ -80,10 +80,10 @@ func (opts *DialOptions) cloneWithDefaults(ctx context.Context) (context.Context
}

// Defaults to HTTP/1.1 only to preserve existing behavior (zero value).
switch o.Protocol {
case ProtocolHTTP1, ProtocolHTTP2:
switch o.HTTPProtocol {
case HTTPProtocol1, HTTPProtocol2:
default:
return nil, nil, nil, fmt.Errorf("websocket: invalid protocol for dial options: %s", o.Protocol)
return nil, nil, nil, fmt.Errorf("websocket: invalid protocol for dial options: %s", o.HTTPProtocol)
}

if o.HTTPClient == nil {
Expand Down Expand Up @@ -218,13 +218,13 @@ func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, copts
return nil, fmt.Errorf("unexpected url scheme: %q", u.Scheme)
}

switch opts.Protocol {
case ProtocolHTTP2:
switch opts.HTTPProtocol {
case HTTPProtocol2:
return handshakeRequestH2(ctx, u, opts, copts, secWebSocketKey)
case ProtocolHTTP1:
case HTTPProtocol1:
return handshakeRequestH1(ctx, u, opts, copts, secWebSocketKey)
default:
return nil, fmt.Errorf("unknown protocol: %s", opts.Protocol)
return nil, fmt.Errorf("unknown protocol: %s", opts.HTTPProtocol)
}
}

Expand Down Expand Up @@ -317,17 +317,21 @@ func secWebSocketKey(rr io.Reader) (string, error) {
}

func verifyServerResponse(opts *DialOptions, copts *compressionOptions, secWebSocketKey string, resp *http.Response) (*compressionOptions, error) {
switch opts.Protocol {
case ProtocolHTTP2:
switch opts.HTTPProtocol {
case HTTPProtocol2:
return verifyServerResponseH2(opts, copts, secWebSocketKey, resp)
case ProtocolHTTP1:
case HTTPProtocol1:
return verifyServerResponseH1(opts, copts, secWebSocketKey, resp)
default:
return nil, fmt.Errorf("unknown protocol: %s", opts.Protocol)
return nil, fmt.Errorf("unknown protocol: %s", opts.HTTPProtocol)
}
}

func verifyServerResponseH1(opts *DialOptions, copts *compressionOptions, secWebSocketKey string, resp *http.Response) (*compressionOptions, error) {
if resp.ProtoMajor != 1 {
return nil, fmt.Errorf("expected HTTP/1.x response but got: %s", resp.Proto)
}

if resp.StatusCode != http.StatusSwitchingProtocols {
return nil, fmt.Errorf("expected handshake response status code %v but got %v", http.StatusSwitchingProtocols, resp.StatusCode)
}
Expand Down
2 changes: 1 addition & 1 deletion doc.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
// Some important caveats to be aware of:
//
// - Accept always errors out
// - Protocol in DialOptions and AcceptOptions is no-op
// - HTTPProtocol in DialOptions and AcceptOptions is no-op
// - HTTPClient, HTTPHeader and CompressionMode in DialOptions are no-op
// - *http.Response from Dial is &http.Response{} with a 101 status code on success
// - Conn.Ping is no-op
Expand Down
2 changes: 1 addition & 1 deletion internal/examples/http2/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ Examples:

conn, resp, err := websocket.Dial(ctx, rawURL, &websocket.DialOptions{
HTTPClient: hc,
Protocol: websocket.ProtocolHTTP2,
HTTPProtocol: websocket.HTTPProtocol2,
})
if err != nil {
if resp != nil {
Expand Down
2 changes: 1 addition & 1 deletion internal/examples/http2/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ Examples:
ctx := r.Context()

c, err := websocket.Accept(w, r, &websocket.AcceptOptions{
Protocol: websocket.ProtocolHTTP2,
HTTPProtocol: websocket.HTTPProtocol2,
})
if err != nil {
// Accept already wrote an error response.
Expand Down
52 changes: 26 additions & 26 deletions internal/thirdparty/http2/http2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,8 @@ type runTableTestCase struct {
name string
scheme string // "wss" or "ws"
client func(*testing.T) *http.Client
clientProto websocket.Protocol
serverProto websocket.Protocol
clientProto websocket.HTTPProtocol
serverProto websocket.HTTPProtocol
wantProto int // 1 or 2
wantStatus int // Wanted status code (e.g., 200 or 100).
wantErr bool // Want a Dial error.
Expand Down Expand Up @@ -228,32 +228,32 @@ var sharedTestCases = []runTableTestCase{
name: "Error TLS ClientHTTP1 RequestHTTP2 AcceptAny",
scheme: "wss",
client: func(t *testing.T) *http.Client { return newH1TLSClient() },
clientProto: websocket.ProtocolHTTP2,
serverProto: websocket.ProtocolAcceptAny,
clientProto: websocket.HTTPProtocol2,
serverProto: websocket.HTTPProtocolAny,
wantErr: true,
},
{
name: "Error H2C ClientHTTP1 RequestHTTP2 AcceptAny",
scheme: "ws",
client: func(t *testing.T) *http.Client { return newH1TLSClient() },
clientProto: websocket.ProtocolHTTP2,
serverProto: websocket.ProtocolAcceptAny,
clientProto: websocket.HTTPProtocol2,
serverProto: websocket.HTTPProtocolAny,
wantErr: true,
},
{
name: "Error TLS ClientHTTP2 RequestHTTP2 AcceptHTTP1",
scheme: "wss",
client: func(t *testing.T) *http.Client { return newH2TLSClient() },
clientProto: websocket.ProtocolHTTP2,
serverProto: websocket.ProtocolHTTP1,
clientProto: websocket.HTTPProtocol2,
serverProto: websocket.HTTPProtocol1,
wantErr: true,
},
{
name: "Error TLS ClientHTTP1 RequestHTTP1 AcceptHTTP2",
scheme: "wss",
client: func(t *testing.T) *http.Client { return newH1TLSClient() },
clientProto: websocket.ProtocolHTTP1,
serverProto: websocket.ProtocolHTTP2,
clientProto: websocket.HTTPProtocol1,
serverProto: websocket.HTTPProtocol2,
wantErr: true,
},
}
Expand All @@ -267,24 +267,24 @@ func TestHTTP2Suite_XCONNECT_Enabled(t *testing.T) {
name: "OK TLS ClientHTTP2 RequestHTTP2 AcceptHTTP2",
scheme: "wss",
client: func(t *testing.T) *http.Client { return newH2TLSClient() },
clientProto: websocket.ProtocolHTTP2,
serverProto: websocket.ProtocolHTTP2,
clientProto: websocket.HTTPProtocol2,
serverProto: websocket.HTTPProtocol2,
wantProto: 2,
},
{
name: "OK TLS ClientHTTP2 RequestHTTP2 AcceptAny",
scheme: "wss",
client: func(t *testing.T) *http.Client { return newH2TLSClient() },
clientProto: websocket.ProtocolHTTP2,
serverProto: websocket.ProtocolAcceptAny,
clientProto: websocket.HTTPProtocol2,
serverProto: websocket.HTTPProtocolAny,
wantProto: 2,
},
{
name: "OK H2C ClientHTTP2 RequestHTTP2 AcceptAny",
scheme: "ws",
client: func(t *testing.T) *http.Client { return newH2CClient() },
clientProto: websocket.ProtocolHTTP2,
serverProto: websocket.ProtocolAcceptAny,
clientProto: websocket.HTTPProtocol2,
serverProto: websocket.HTTPProtocolAny,
wantProto: 2,
},
{
Expand All @@ -295,8 +295,8 @@ func TestHTTP2Suite_XCONNECT_Enabled(t *testing.T) {
req.Header.Del(":protocol")
}, nil)
},
clientProto: websocket.ProtocolHTTP2,
serverProto: websocket.ProtocolHTTP2,
clientProto: websocket.HTTPProtocol2,
serverProto: websocket.HTTPProtocol2,
wantErr: true,
},
{
Expand All @@ -311,8 +311,8 @@ func TestHTTP2Suite_XCONNECT_Enabled(t *testing.T) {
req.Method = http.MethodGet
}, nil)
},
clientProto: websocket.ProtocolHTTP2,
serverProto: websocket.ProtocolHTTP2,
clientProto: websocket.HTTPProtocol2,
serverProto: websocket.HTTPProtocol2,
wantErr: true,
},
{
Expand All @@ -326,8 +326,8 @@ func TestHTTP2Suite_XCONNECT_Enabled(t *testing.T) {
_ = resp.Body.Close()
})
},
clientProto: websocket.ProtocolHTTP2,
serverProto: websocket.ProtocolHTTP2,
clientProto: websocket.HTTPProtocol2,
serverProto: websocket.HTTPProtocol2,
wantErr: true,
},
}...))
Expand All @@ -342,16 +342,16 @@ func TestHTTP2Suite_XCONNECT_Disabled(t *testing.T) {
name: "Error TLS ClientHTTP2 RequestHTTP2 AcceptAny NoExtendedConnect",
scheme: "wss",
client: func(t *testing.T) *http.Client { return newH2TLSClient() },
clientProto: websocket.ProtocolHTTP2,
serverProto: websocket.ProtocolAcceptAny,
clientProto: websocket.HTTPProtocol2,
serverProto: websocket.HTTPProtocolAny,
wantErr: true,
},
{
name: "Error H2C ClientHTTP2 RequestHTTP2 AcceptAny NoExtendedConnect",
scheme: "ws",
client: func(t *testing.T) *http.Client { return newH2CClient() },
clientProto: websocket.ProtocolHTTP2,
serverProto: websocket.ProtocolAcceptAny,
clientProto: websocket.HTTPProtocol2,
serverProto: websocket.HTTPProtocolAny,
wantErr: true,
},
}...))
Expand Down
Loading