Skip to content

Adding sync.Pool to Decompress middleware #1699

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Dec 11, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions middleware/compress.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc {
config.Level = DefaultGzipConfig.Level
}

pool := gzipPool(config)
pool := gzipCompressPool(config)

return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
Expand Down Expand Up @@ -133,7 +133,7 @@ func (w *gzipResponseWriter) Push(target string, opts *http.PushOptions) error {
return http.ErrNotSupported
}

func gzipPool(config GzipConfig) sync.Pool {
func gzipCompressPool(config GzipConfig) sync.Pool {
return sync.Pool{
New: func() interface{} {
w, err := gzip.NewWriterLevel(ioutil.Discard, config.Level)
Expand Down
74 changes: 68 additions & 6 deletions middleware/decompress.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,53 +3,115 @@ package middleware
import (
"bytes"
"compress/gzip"
"github.com/labstack/echo/v4"
"io"
"io/ioutil"
"net/http"
"sync"

"github.com/labstack/echo/v4"
)

type (
// DecompressConfig defines the config for Decompress middleware.
DecompressConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper

// GzipDecompressPool defines an interface to provide the sync.Pool used to create/store Gzip readers
GzipDecompressPool Decompressor
}
)

//GZIPEncoding content-encoding header if set to "gzip", decompress body contents.
const GZIPEncoding string = "gzip"

// Decompressor is used to get the sync.Pool used by the middleware to get Gzip readers
type Decompressor interface {
gzipDecompressPool() sync.Pool
}

var (
//DefaultDecompressConfig defines the config for decompress middleware
DefaultDecompressConfig = DecompressConfig{Skipper: DefaultSkipper}
DefaultDecompressConfig = DecompressConfig{
Skipper: DefaultSkipper,
GzipDecompressPool: &DefaultGzipDecompressPool{},
}
)

// DefaultGzipDecompressPool is the default implementation of Decompressor interface
type DefaultGzipDecompressPool struct {
}

func (d *DefaultGzipDecompressPool) gzipDecompressPool() sync.Pool {
return sync.Pool{
New: func() interface{} {
// create with an empty reader (but with GZIP header)
w, err := gzip.NewWriterLevel(ioutil.Discard, gzip.BestSpeed)
if err != nil {
return err
}

b := new(bytes.Buffer)
w.Reset(b)
w.Flush()
w.Close()

r, err := gzip.NewReader(bytes.NewReader(b.Bytes()))
if err != nil {
return err
}
return r
},
}
}

//Decompress decompresses request body based if content encoding type is set to "gzip" with default config
func Decompress() echo.MiddlewareFunc {
return DecompressWithConfig(DefaultDecompressConfig)
}

//DecompressWithConfig decompresses request body based if content encoding type is set to "gzip" with config
func DecompressWithConfig(config DecompressConfig) echo.MiddlewareFunc {
// Defaults
if config.Skipper == nil {
config.Skipper = DefaultGzipConfig.Skipper
}
if config.GzipDecompressPool == nil {
config.GzipDecompressPool = DefaultDecompressConfig.GzipDecompressPool
}

return func(next echo.HandlerFunc) echo.HandlerFunc {
pool := config.GzipDecompressPool.gzipDecompressPool()
return func(c echo.Context) error {
if config.Skipper(c) {
return next(c)
}
switch c.Request().Header.Get(echo.HeaderContentEncoding) {
case GZIPEncoding:
gr, err := gzip.NewReader(c.Request().Body)
if err != nil {
b := c.Request().Body

i := pool.Get()
gr, ok := i.(*gzip.Reader)
if !ok {
return echo.NewHTTPError(http.StatusInternalServerError, i.(error).Error())
}

if err := gr.Reset(b); err != nil {
pool.Put(gr)
if err == io.EOF { //ignore if body is empty
return next(c)
}
return err
}
defer gr.Close()
var buf bytes.Buffer
io.Copy(&buf, gr)

gr.Close()
pool.Put(gr)
Copy link
Contributor

@arun0009 arun0009 Dec 1, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@pafuent - you are putting gr back to pool twice? On line 100 and 110? Also, whats resetting gr before putting it back in pool? of is it the usage of gr.Reset making sure that it's reset before reading new request body?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Line 100 because the call to Reset returned an error, so the middleware is skipped (for empty bodies) or returns the error (all the rest of the errors). And the line 110 is the happy path 😉
I'm not using a defer just to return it to the Pool as fast as is possible to avoid an allocation if a new request find the Pool empty.
Reset() requires the io.Reader, so I need to perform that after getting one instance from the Pool


b.Close() // http.Request.Body is closed by the Server, but because we are replacing it, it must be closed here
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good catch!


r := ioutil.NopCloser(&buf)
defer r.Close()
c.Request().Body = r
}
return next(c)
Expand Down
61 changes: 61 additions & 0 deletions middleware/decompress_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@ package middleware
import (
"bytes"
"compress/gzip"
"errors"
"io/ioutil"
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"

"github.com/labstack/echo/v4"
Expand Down Expand Up @@ -43,6 +45,35 @@ func TestDecompress(t *testing.T) {
assert.Equal(body, string(b))
}

func TestDecompressDefaultConfig(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test"))
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)

h := DecompressWithConfig(DecompressConfig{})(func(c echo.Context) error {
c.Response().Write([]byte("test")) // For Content-Type sniffing
return nil
})
h(c)

assert := assert.New(t)
assert.Equal("test", rec.Body.String())

// Decompress
body := `{"name": "echo"}`
gz, _ := gzipString(body)
req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz)))
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
h(c)
assert.Equal(GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
b, err := ioutil.ReadAll(req.Body)
assert.NoError(err)
assert.Equal(body, string(b))
}

func TestCompressRequestWithoutDecompressMiddleware(t *testing.T) {
e := echo.New()
body := `{"name":"echo"}`
Expand Down Expand Up @@ -108,6 +139,36 @@ func TestDecompressSkipper(t *testing.T) {
assert.Equal(t, body, string(reqBody))
}

type TestDecompressPoolWithError struct {
}

func (d *TestDecompressPoolWithError) gzipDecompressPool() sync.Pool {
return sync.Pool{
New: func() interface{} {
return errors.New("pool error")
},
}
}

func TestDecompressPoolError(t *testing.T) {
e := echo.New()
e.Use(DecompressWithConfig(DecompressConfig{
Skipper: DefaultSkipper,
GzipDecompressPool: &TestDecompressPoolWithError{},
}))
body := `{"name": "echo"}`
req := httptest.NewRequest(http.MethodPost, "/echo", strings.NewReader(body))
req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
e.ServeHTTP(rec, req)
assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
reqBody, err := ioutil.ReadAll(c.Request().Body)
assert.NoError(t, err)
assert.Equal(t, body, string(reqBody))
assert.Equal(t, rec.Code, http.StatusInternalServerError)
}

func BenchmarkDecompress(b *testing.B) {
e := echo.New()
body := `{"name": "echo"}`
Expand Down