diff --git a/protocol.go b/protocol.go index 270b90d..57ea10f 100644 --- a/protocol.go +++ b/protocol.go @@ -2,6 +2,7 @@ package proxyproto import ( "bufio" + "bytes" "errors" "fmt" "io" @@ -51,7 +52,6 @@ type Conn struct { once sync.Once readErr error conn net.Conn - bufReader *bufio.Reader reader io.Reader header *Header ProxyHeaderPolicy Policy @@ -151,16 +151,8 @@ func (p *Listener) Addr() net.Addr { // NewConn is used to wrap a net.Conn that may be speaking // the proxy protocol into a proxyproto.Conn func NewConn(conn net.Conn, opts ...func(*Conn)) *Conn { - // For v1 the header length is at most 108 bytes. - // For v2 the header length is at most 52 bytes plus the length of the TLVs. - // We use 256 bytes to be safe. - const bufSize = 256 - br := bufio.NewReaderSize(conn, bufSize) - pConn := &Conn{ - bufReader: br, - reader: io.MultiReader(br, conn), - conn: conn, + conn: conn, } for _, opt := range opts { @@ -297,7 +289,25 @@ func (p *Conn) readHeader() error { } } - header, err := Read(p.bufReader) + // For v1 the header length is at most 108 bytes. + // For v2 the header length is at most 52 bytes plus the length of the TLVs. + // We use 256 bytes to be safe. + const bufSize = 256 + + bb := bytes.NewBuffer(make([]byte, 0, bufSize)) + br := bufio.NewReaderSize(io.TeeReader(p.conn, bb), bufSize) + + header, err := Read(br) + + if err == nil { + _ = bb.Next(bb.Len() - br.Buffered()) // skip header + } + + if bb.Len() == 0 { + p.reader = p.conn + } else { + p.reader = io.MultiReader(bb, p.conn) + } // If the connection's readHeaderTimeout is more than 0, undo the change to the // deadline that we made above. Because we retain the readDeadline as part of our @@ -363,27 +373,5 @@ func (p *Conn) WriteTo(w io.Writer) (int64, error) { if p.readErr != nil { return 0, p.readErr } - - b := make([]byte, p.bufReader.Buffered()) - if _, err := p.bufReader.Read(b); err != nil { - return 0, err // this should never as we read buffered data - } - - var n int64 - { - nn, err := w.Write(b) - n += int64(nn) - if err != nil { - return n, err - } - } - { - nn, err := io.Copy(w, p.conn) - n += nn - if err != nil { - return n, err - } - } - - return n, nil + return io.Copy(w, p.reader) }