Skip to content

Commit 391d034

Browse files
authored
Merge pull request #36 from projectdiscovery/dev
0.0.8
2 parents cdf6d96 + f89a181 commit 391d034

18 files changed

Lines changed: 300 additions & 120 deletions

File tree

client.go

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,16 @@ func (c *Client) DoRawWithOptions(method, url, uripath string, headers map[strin
9292
return c.do(method, url, uripath, headers, body, redirectstatus, options)
9393
}
9494

95+
func (c *Client) getConn(protocol, host string, options Options) (Conn, error) {
96+
if options.Proxy != "" {
97+
return c.dialer.DialWithProxy(protocol, host, c.Options.Proxy, c.Options.ProxyDialTimeout)
98+
}
99+
if options.Timeout < 0 {
100+
options.Timeout = 0
101+
}
102+
return c.dialer.DialTimeout(protocol, host, options.Timeout)
103+
}
104+
95105
func (c *Client) do(method, url, uripath string, headers map[string][]string, body io.Reader, redirectstatus *RedirectStatus, options Options) (*http.Response, error) {
96106
protocol := "http"
97107
if strings.HasPrefix(strings.ToLower(url), "https://") {
@@ -137,7 +147,7 @@ func (c *Client) do(method, url, uripath string, headers map[string][]string, bo
137147
protocol = "https"
138148
}
139149

140-
conn, err := c.dialer.Dial(protocol, host)
150+
conn, err := c.getConn(protocol, host, options)
141151
if err != nil {
142152
return nil, err
143153
}
@@ -148,13 +158,13 @@ func (c *Client) do(method, url, uripath string, headers map[string][]string, bo
148158

149159
// set timeout if any
150160
if options.Timeout > 0 {
151-
conn.SetDeadline(time.Now().Add(options.Timeout))
161+
_ = conn.SetDeadline(time.Now().Add(options.Timeout))
152162
}
153163

154164
if err := conn.WriteRequest(req); err != nil {
155165
return nil, err
156166
}
157-
resp, err := conn.ReadResponse()
167+
resp, err := conn.ReadResponse(options.ForceReadAllBody)
158168
if err != nil {
159169
return nil, err
160170
}

client/client.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ const readerBuffer = 4096
7070
// HTTP but connection pooling is expected to be handled at a higher layer.
7171
type Client interface {
7272
WriteRequest(*Request) error
73-
ReadResponse() (*Response, error)
73+
ReadResponse(forceReadAll bool) (*Response, error)
7474
}
7575

7676
// NewClient returns a Client implementation which uses rw to communicate.
@@ -122,7 +122,7 @@ func (c *client) WriteRequest(req *Request) error {
122122
}
123123

124124
// ReadResponse unmarshalls a HTTP response.
125-
func (c *client) ReadResponse() (*Response, error) {
125+
func (c *client) ReadResponse(forceReadAll bool) (*Response, error) {
126126
version, code, msg, err := c.ReadStatusLine()
127127
var headers []Header
128128
if err != nil {
@@ -148,7 +148,7 @@ func (c *client) ReadResponse() (*Response, error) {
148148
Headers: headers,
149149
Body: c.ReadBody(),
150150
}
151-
if l := resp.ContentLength(); l >= 0 {
151+
if l := resp.ContentLength(); l >= 0 && !forceReadAll {
152152
resp.Body = io.LimitReader(resp.Body, l)
153153
} else if resp.TransferEncoding() == "chunked" {
154154
resp.Body = httputil.NewChunkedReader(resp.Body)

client/status.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,6 @@ type Status struct {
6161
Reason string
6262
}
6363

64-
var invalidStatus Status
6564

6665
func (s Status) String() string { return fmt.Sprintf("%d %s", s.Code, s.Reason) }
6766

client_test.go

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
package rawhttp
2+
3+
import (
4+
"net/http"
5+
"net/http/httptest"
6+
"testing"
7+
"time"
8+
9+
"github.com/julienschmidt/httprouter"
10+
"github.com/projectdiscovery/stringsutil"
11+
)
12+
13+
func getTestHttpServer(timeout time.Duration) *httptest.Server {
14+
var ts *httptest.Server
15+
router := httprouter.New()
16+
router.GET("/rawhttp", httprouter.Handle(func(w http.ResponseWriter, r *http.Request, p httprouter.Params) {
17+
time.Sleep(timeout)
18+
}))
19+
ts = httptest.NewServer(router)
20+
return ts
21+
}
22+
23+
// run with go test -timeout 45s -run ^TestDialDefaultTimeout$ github.com/projectdiscovery/rawhttp
24+
func TestDialDefaultTimeout(t *testing.T) {
25+
timeout := 30 * time.Second
26+
ts := getTestHttpServer(45 * time.Second)
27+
defer ts.Close()
28+
29+
startTime := time.Now()
30+
client := NewClient(DefaultOptions)
31+
_, err := client.DoRaw("GET", ts.URL, "/rawhttp", nil, nil)
32+
if !stringsutil.ContainsAny(err.Error(), "i/o timeout") || time.Now().Before(startTime.Add(timeout)) {
33+
t.Error("default timeout error")
34+
}
35+
}
36+
37+
func TestDialWithCustomTimeout(t *testing.T) {
38+
timeout := 5 * time.Second
39+
ts := getTestHttpServer(10 * time.Second)
40+
defer ts.Close()
41+
42+
startTime := time.Now()
43+
client := NewClient(DefaultOptions)
44+
options := DefaultOptions
45+
options.Timeout = timeout
46+
_, err := client.DoRawWithOptions("GET", ts.URL, "/rawhttp", nil, nil, options)
47+
if !stringsutil.ContainsAny(err.Error(), "i/o timeout") || time.Now().Before(startTime.Add(timeout)) {
48+
t.Error("custom timeout error")
49+
}
50+
}

clientpipeline/client.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,19 @@ import (
1616
const DefaultMaxConnsPerHost = 512
1717
const DefaultMaxIdleConnDuration = 10 * time.Second
1818
const DefaultMaxIdemponentCallAttempts = 5
19+
const defaultReadBufferSize = 4096
20+
const defaultWriteBufferSize = 4096
1921

2022
type DialFunc func(addr string) (net.Conn, error)
2123
type RetryIfFunc func(request *Request) bool
2224

23-
var errorChPool sync.Pool
24-
2525
var (
2626
ErrNoFreeConns = errors.New("no free connections available to host")
2727
ErrConnectionClosed = errors.New("the server closed connection before returning the first response byte. " +
2828
"Make sure the server returns 'Connection: close' response header before closing the connection")
29+
// ErrGetOnly is returned when server expects only GET requests,
30+
// but some other type of request came (Server.GetOnly option is true).
31+
ErrGetOnly = errors.New("non-GET request received")
2932
)
3033

3134
type timeoutError struct {

clientpipeline/http.go

Lines changed: 0 additions & 81 deletions
This file was deleted.

clientpipeline/response.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ func (resp *Response) Read(r *bufio.Reader) error {
8080
}
8181
if key == "" {
8282
// empty header values are valid, rfc 2616 s4.2.
83-
err = errors.New("invalid header")
83+
err = errors.New("invalid header") //nolint
8484
break
8585
}
8686
headers = append(headers, Header{key, value})
@@ -222,7 +222,7 @@ func (resp *Response) ReadBody(r *bufio.Reader) io.Reader {
222222
l := resp.ContentLength()
223223
if l > 0 {
224224
resp.body = make([]byte, l)
225-
io.ReadFull(r, resp.body)
225+
io.ReadFull(r, resp.body) //nolint
226226

227227
return bytes.NewReader(resp.body)
228228
}

conn.go

Lines changed: 71 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,25 @@ package rawhttp
22

33
import (
44
"crypto/tls"
5+
"fmt"
56
"io"
67
"net"
8+
"net/url"
9+
"strings"
710
"sync"
811
"time"
912

1013
"github.com/projectdiscovery/rawhttp/client"
14+
"github.com/projectdiscovery/rawhttp/proxy"
1115
)
1216

1317
// Dialer can dial a remote HTTP server.
1418
type Dialer interface {
1519
// Dial dials a remote http server returning a Conn.
1620
Dial(protocol, addr string) (Conn, error)
21+
DialWithProxy(protocol, addr, proxyURL string, timeout time.Duration) (Conn, error)
22+
// Dial dials a remote http server with timeout returning a Conn.
23+
DialTimeout(protocol, addr string, timeout time.Duration) (Conn, error)
1724
}
1825

1926
type dialer struct {
@@ -22,35 +29,91 @@ type dialer struct {
2229
}
2330

2431
func (d *dialer) Dial(protocol, addr string) (Conn, error) {
32+
return d.dialTimeout(protocol, addr, 0)
33+
}
34+
35+
func (d *dialer) DialTimeout(protocol, addr string, timeout time.Duration) (Conn, error) {
36+
return d.dialTimeout(protocol, addr, timeout)
37+
}
38+
39+
func (d *dialer) dialTimeout(protocol, addr string, timeout time.Duration) (Conn, error) {
2540
d.Lock()
2641
if d.conns == nil {
2742
d.conns = make(map[string][]Conn)
2843
}
2944
if c, ok := d.conns[addr]; ok {
3045
if len(c) > 0 {
3146
conn := c[0]
32-
c[0], c = c[len(c)-1], c[:len(c)-1]
47+
c[0] = c[len(c)-1]
3348
d.Unlock()
3449
return conn, nil
3550
}
3651
}
3752
d.Unlock()
38-
c, err := clientDial(protocol, addr)
53+
c, err := clientDial(protocol, addr, timeout)
3954
return &conn{
4055
Client: client.NewClient(c),
4156
Conn: c,
4257
dialer: d,
4358
}, err
4459
}
4560

46-
func clientDial(protocol, addr string) (net.Conn, error) {
47-
// http
48-
if protocol == "http" {
49-
return net.Dial("tcp", addr)
61+
func (d *dialer) DialWithProxy(protocol, addr, proxyURL string, timeout time.Duration) (Conn, error) {
62+
var c net.Conn
63+
u, err := url.Parse(proxyURL)
64+
if err != nil {
65+
return nil, fmt.Errorf("unsupported proxy error: %w", err)
66+
}
67+
switch u.Scheme {
68+
case "http":
69+
c, err = proxy.HTTPDialer(proxyURL, timeout)(addr)
70+
case "socks5", "socks5h":
71+
c, err = proxy.Socks5Dialer(proxyURL, timeout)(addr)
72+
default:
73+
return nil, fmt.Errorf("unsupported proxy protocol: %s", proxyURL)
74+
}
75+
if err != nil {
76+
return nil, fmt.Errorf("proxy error: %w", err)
5077
}
78+
if protocol == "https" {
79+
if c, err = TlsHandshake(c, addr); err != nil {
80+
return nil, fmt.Errorf("tls handshake error: %w", err)
81+
}
82+
}
83+
return &conn{
84+
Client: client.NewClient(c),
85+
Conn: c,
86+
dialer: d,
87+
}, err
88+
}
89+
90+
func clientDial(protocol, addr string, timeout time.Duration) (net.Conn, error) {
91+
conn, err := net.DialTimeout("tcp", addr, timeout)
92+
if protocol == "https" {
93+
if conn, err = TlsHandshake(conn, addr); err != nil {
94+
return nil, fmt.Errorf("tls handshake error: %w", err)
95+
}
96+
}
97+
return conn, err
98+
}
5199

52-
// https
53-
return tls.Dial("tcp", addr, &tls.Config{InsecureSkipVerify: true})
100+
// TlsHandshake tls handshake on a plain connection
101+
func TlsHandshake(conn net.Conn, addr string) (net.Conn, error) {
102+
colonPos := strings.LastIndex(addr, ":")
103+
if colonPos == -1 {
104+
colonPos = len(addr)
105+
}
106+
hostname := addr[:colonPos]
107+
108+
tlsConn := tls.Client(conn, &tls.Config{
109+
InsecureSkipVerify: true,
110+
ServerName: hostname,
111+
})
112+
if err := tlsConn.Handshake(); err != nil {
113+
conn.Close()
114+
return nil, err
115+
}
116+
return tlsConn, nil
54117
}
55118

56119
// Conn is an interface implemented by a connection

example/server/server.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@ package main
33
import (
44
"fmt"
55
"net/http"
6-
)
76

8-
var i int
7+
"github.com/projectdiscovery/gologger"
8+
)
99

1010
func headers(w http.ResponseWriter, req *http.Request) {
1111
for name, headers := range req.Header {
@@ -17,5 +17,7 @@ func headers(w http.ResponseWriter, req *http.Request) {
1717

1818
func main() {
1919
http.HandleFunc("/headers", headers)
20-
http.ListenAndServe(":10000", nil)
20+
if err := http.ListenAndServe(":10000", nil); err != nil {
21+
gologger.Fatal().Msgf("Could not listen and serve: %s\n", err)
22+
}
2123
}

0 commit comments

Comments
 (0)