Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 149 additions & 6 deletions dialtesting/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"text/template"
"time"

"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/http3"
"golang.org/x/net/http2"
)
Expand All @@ -37,8 +38,9 @@ var (
)

const (
MaxBodySize = 10 * 1024
DefaultHTTPTimeout = 60 * time.Second
MaxBodySize = 10 * 1024
DefaultHTTPTimeout = 60 * time.Second
HTTP3HandshakeTimeout = 5 * time.Second

ProtocolAuto = "auto"
ProtocolHTTP11 = "http/1.1"
Expand Down Expand Up @@ -79,6 +81,8 @@ type HTTPTask struct {
sslCertNotBefore int64
sslCertNotAfter int64
protocol string
httpTimeout time.Duration
tlsConfig *tls.Config
}

func (t *HTTPTask) clear() {
Expand All @@ -95,13 +99,15 @@ func (t *HTTPTask) clear() {
t.reqBodyBytesBuffer = nil
t.sslCertNotBefore = 0
t.sslCertNotAfter = 0
t.destIP = ""

if t.reqBody != nil {
t.reqBody.bodyType = t.reqBody.BodyType
}
}

func (t *HTTPTask) stop() {
t.closeHTTP3Transport()
if t.cli != nil {
t.cli.CloseIdleConnections()
}
Expand Down Expand Up @@ -379,8 +385,20 @@ func (t *HTTPTask) run() error {
t.req.Header.Add("User-Agent", agentInfo)
}

if t.protocol == ProtocolHTTP3 {
t.resetHTTP3Client()
}

t.reqStart = time.Now()
t.resp, err = t.cli.Do(t.req)
if t.protocol == ProtocolHTTP3 && t.resp != nil && t1.IsZero() {
// For HTTP/3, response_ttfb is response header arrival measured by http.Client.Do return.
t1 = time.Now()
t.ttfbTime = float64(time.Since(t.reqStart)) / float64(time.Microsecond)
}
if t.protocol == ProtocolHTTP3 && t.resp != nil && t.resp.TLS != nil && t.sslCertNotAfter == 0 {
t.extractSSLCertificateValidity(*t.resp.TLS)
}
if t.resp != nil {
defer t.resp.Body.Close() //nolint:errcheck
}
Expand Down Expand Up @@ -629,14 +647,14 @@ func (t *HTTPTask) init() error {

protocol := strings.ToLower(opt.getProtocol())
t.protocol = protocol
t.httpTimeout = httpTimeout
t.tlsConfig = tlsConfig.Clone()

switch protocol {
case ProtocolHTTP3:
t.cli = &http.Client{
Timeout: httpTimeout,
Transport: &http3.RoundTripper{
TLSClientConfig: tlsConfig,
},
Timeout: httpTimeout,
Transport: t.newHTTP3RoundTripper(tlsConfig, httpTimeout),
}
case ProtocolHTTP2Only:
if isPlainHTTP(t.URL) {
Expand Down Expand Up @@ -774,6 +792,131 @@ func isPlainHTTP(rawURL string) bool {
return err == nil && strings.EqualFold(u.Scheme, "http")
}

func (t *HTTPTask) resetHTTP3Client() {
t.closeHTTP3Transport()
t.cli = &http.Client{
Timeout: t.httpTimeout,
Transport: t.newHTTP3RoundTripper(t.tlsConfig, t.httpTimeout),
}
if t.AdvanceOptions != nil && t.AdvanceOptions.RequestOptions != nil && !t.AdvanceOptions.RequestOptions.FollowRedirect {
t.cli.CheckRedirect = func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
}
}
}

func (t *HTTPTask) closeHTTP3Transport() {
if t.protocol != ProtocolHTTP3 || t.cli == nil || t.cli.Transport == nil {
return
}
if closer, ok := t.cli.Transport.(interface{ Close() error }); ok {
_ = closer.Close()
}
}

func (t *HTTPTask) newHTTP3RoundTripper(tlsConfig *tls.Config, httpTimeout time.Duration) http.RoundTripper {
return &http3.RoundTripper{
TLSClientConfig: tlsConfig,
Dial: func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
host, port, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}

resolveStart := time.Now()
ips, err := net.DefaultResolver.LookupIPAddr(ctx, host)
t.dnsParseTime = float64(time.Since(resolveStart)) / float64(time.Microsecond)
if err != nil {
return nil, err
}
if len(ips) == 0 {
return nil, fmt.Errorf("no IP addresses found for %q", host)
}
ips = preferIPv4(ips)

dialTLSConfig := tlsCfg
if dialTLSConfig == nil {
dialTLSConfig = &tls.Config{MinVersion: tls.VersionTLS12}
} else {
dialTLSConfig = dialTLSConfig.Clone()
}
if dialTLSConfig.ServerName == "" {
dialTLSConfig.ServerName = host
}
dialQUICConfig := limitHTTP3HandshakeTimeout(cfg, httpTimeout)

var lastErr error
for _, ip := range ips {
resolvedAddr := net.JoinHostPort(ip.IP.String(), port)
connectStart := time.Now()
conn, err := quic.DialAddrEarly(ctx, resolvedAddr, dialTLSConfig, dialQUICConfig)
// For HTTP/3, response_connection is QUIC connection setup time, not TCP connect time.
t.connectionTime = float64(time.Since(connectStart)) / float64(time.Microsecond)
if err != nil {
lastErr = err
continue
}

handshakeStart := time.Now()
select {
case <-conn.HandshakeComplete():
// For HTTP/3, response_ssl is QUIC/TLS handshake time, not standalone TCP TLS handshake time.
t.sslTime = float64(time.Since(handshakeStart)) / float64(time.Microsecond)
t.destIP = ip.IP.String()
return conn, nil
case <-ctx.Done():
_ = conn.CloseWithError(quic.ApplicationErrorCode(0), ctx.Err().Error())
return nil, ctx.Err()
case <-conn.Context().Done():
lastErr = conn.Context().Err()
continue
}
}

return nil, lastErr
},
}
}

func limitHTTP3HandshakeTimeout(cfg *quic.Config, httpTimeout time.Duration) *quic.Config {
var dialQUICConfig *quic.Config
if cfg == nil {
dialQUICConfig = &quic.Config{}
} else {
dialQUICConfig = cfg.Clone()
}

handshakeTimeout := HTTP3HandshakeTimeout
if httpTimeout > 0 && httpTimeout < HTTP3HandshakeTimeout {
handshakeTimeout = httpTimeout
} else if dialQUICConfig.HandshakeIdleTimeout == 0 {
return dialQUICConfig
}
if dialQUICConfig.HandshakeIdleTimeout > handshakeTimeout {
dialQUICConfig.HandshakeIdleTimeout = handshakeTimeout
}
return dialQUICConfig
}

func preferIPv4(ips []net.IPAddr) []net.IPAddr {
if len(ips) <= 1 {
return ips
}

sorted := make([]net.IPAddr, 0, len(ips))
for _, ip := range ips {
if ip.IP.To4() != nil {
sorted = append(sorted, ip)
}
}
for _, ip := range ips {
if ip.IP.To4() == nil {
sorted = append(sorted, ip)
}
}
return sorted
}

func (t *HTTPTask) getHostName() ([]string, error) {
if hostName, err := getHostName(t.URL); err != nil {
return nil, err
Expand Down
Loading
Loading