Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 25 additions & 28 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,12 @@ type Options struct {
Verbose bool
// KillIdleConn specifies if all keep-alive connections gets killed
KillIdleConn bool
// Custom CheckRetry policy
CheckRetry CheckRetry
// Custom Backoff policy
Backoff Backoff
// Custom http client
HttpClient *http.Client
}

// DefaultOptionsSpraying contains the default options for host spraying
Expand All @@ -76,11 +82,15 @@ var DefaultOptionsSingle = Options{
KillIdleConn: false,
}

// Caller can invoke this and pass the retryPolicy and backoff.
// The function takes variadic arguments with retryPolicy and backoff.
// Anything passed outside, will still return a default setting.
func CreateClient(options Options, arguments ...interface{}) *Client {
httpclient := DefaultClient()
// NewClient creates a new Client with default settings.
func NewClient(options Options) *Client {
var httpclient *http.Client
if options.HttpClient != nil {
httpclient = options.HttpClient
} else {
httpclient = DefaultClient()
}

httpclient2 := DefaultClient()
if err := http2.ConfigureTransport(httpclient2.Transport.(*http.Transport)); err != nil {
return nil
Expand All @@ -89,24 +99,14 @@ func CreateClient(options Options, arguments ...interface{}) *Client {
var retryPolicy CheckRetry
var backoff Backoff

for _, arg := range arguments {
switch v := arg.(type) {
case CheckRetry:
retryPolicy = v
case Backoff:
backoff = v
default:
break
}
retryPolicy = DefaultRetryPolicy()
if options.CheckRetry != nil {
retryPolicy = options.CheckRetry
}

// In any case no policy or backoff found set the default config.
if retryPolicy == nil {
retryPolicy = DefaultRetryPolicy()
}

if backoff == nil {
backoff = DefaultBackoff()
backoff = DefaultBackoff()
if options.Backoff != nil {
backoff = options.Backoff
}

// if necessary adjusts per-request timeout proportionally to general timeout (30%)
Expand All @@ -126,14 +126,11 @@ func CreateClient(options Options, arguments ...interface{}) *Client {
return c
}

// NewClient creates a new Client with default settings.
func NewClient(options Options) *Client {
return CreateClient(options, DefaultRetryPolicy(), DefaultBackoff())
}

// NewWithHTTPClient creates a new Client with default settings and provided http.Client
// NewWithHTTPClient creates a new Client with custom http client
// Deprecated: Use options.HttpClient
func NewWithHTTPClient(client *http.Client, options Options) *Client {
return CreateClient(options, DefaultRetryPolicy(), DefaultBackoff())
options.HttpClient = client
return NewClient(options)
}

// setKillIdleConnections sets the kill idle conns switch in two scenarios
Expand Down
9 changes: 1 addition & 8 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,6 @@ func TestClient_Do(t *testing.T) {

// Request to /foo => 200 + valid body
func testClientSuccess_Do(t *testing.T, body interface{}) {

// start buggyhttp
buggyhttp.Listen(8080)
defer buggyhttp.Stop()
Expand Down Expand Up @@ -180,7 +179,7 @@ func testClientSuccess_Do(t *testing.T, body interface{}) {
select {
case <-doneCh:
// client should have completed
case <-time.After(200 * time.Millisecond):
case <-time.After(time.Second):
t.Fatalf("successful request should have been completed")
case error := <-errCh:
t.Fatalf("err: %v", error)
Expand All @@ -196,7 +195,6 @@ func testClientSuccess_Do(t *testing.T, body interface{}) {
// Expected: Some recoverable network failures and after 5 retries the library should be able to get Status Code 200 + Valid Body with various backoff stategies
// Request to /successafter => 5 attempts recoverable + at 6th attempt 200 + valid body
func TestClientRetry_Do(t *testing.T) {

// start buggyhttp
buggyhttp.Listen(8080)
defer buggyhttp.Stop()
Expand Down Expand Up @@ -227,7 +225,6 @@ func TestClientRetry_Do(t *testing.T) {
// TestClientEmptyResponse_Do tests a generic endpoint that simulates the server hanging connection immediately (http connection closed by peer)
// Expected: The library should keep on retrying until the final timeout or maximum retries amount
func TestClientEmptyResponse_Do(t *testing.T) {

// start buggyhttp
buggyhttp.Listen(8080)
defer buggyhttp.Stop()
Expand Down Expand Up @@ -256,7 +253,6 @@ func TestClientEmptyResponse_Do(t *testing.T) {
// TestClientUnexpectedEOF_Do tests a generic endpoint that simulates the server hanging the connection in the middle of a valid response (connection failure)
// Expected: The library should keep on retrying until the final timeout or maximum retries amount
func TestClientUnexpectedEOF_Do(t *testing.T) {

// start buggyhttp
buggyhttp.Listen(8080)
defer buggyhttp.Stop()
Expand Down Expand Up @@ -285,7 +281,6 @@ func TestClientUnexpectedEOF_Do(t *testing.T) {
// TestClientEndlessBody_Do tests a generic endpoint that simulates the server delivering an infinite content body
// Expected: The library should read until a certain limit with return code 200
func TestClientEndlessBody_Do(t *testing.T) {

// start buggyhttp
buggyhttp.Listen(8080)
defer buggyhttp.Stop()
Expand Down Expand Up @@ -319,7 +314,6 @@ func TestClientEndlessBody_Do(t *testing.T) {
// TestClientMessyHeaders_Do tests a generic endpoint that simulates the server sending infinite headers
// Expected: The library should stop reading headers after a certain amount or go into timeout
func TestClientMessyHeaders_Do(t *testing.T) {

// start buggyhttp
buggyhttp.Listen(8080)
defer buggyhttp.Stop()
Expand Down Expand Up @@ -352,7 +346,6 @@ func TestClientMessyHeaders_Do(t *testing.T) {
// TestClientMessyEncoding_Do tests a generic endpoint that simulates the server sending weird encodings in headers
// Expected: The library should be successful as all strings are treated as runes
func TestClientMessyEncoding_Do(t *testing.T) {

// start buggyhttp
buggyhttp.Listen(8080)
defer buggyhttp.Stop()
Expand Down
19 changes: 4 additions & 15 deletions retry.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package retryablehttp
import (
"context"
"crypto/x509"
"errors"
"net/http"
"net/url"
"regexp"
Expand Down Expand Up @@ -35,29 +34,20 @@ type CheckRetry func(ctx context.Context, resp *http.Response, err error) (bool,
// will retry on connection errors and server errors.
func DefaultRetryPolicy() func(ctx context.Context, resp *http.Response, err error) (bool, error) {
return func(ctx context.Context, resp *http.Response, err error) (bool, error) {
return checkErrors(ctx, resp, err)
}
}

// HTTPErrorRetryPolicy is to retry for HTTPCodes >= 500.
func HTTPErrorRetryPolicy() func(ctx context.Context, resp *http.Response, err error) (bool, error) {
return func(ctx context.Context, resp *http.Response, err error) (bool, error) {
if resp.StatusCode >= 500 {
return true, errors.New(resp.Status)
}
return checkErrors(ctx, resp, err)
return CheckRecoverableErrors(ctx, resp, err)
}
}

// HostSprayRetryPolicy provides a callback for Client.CheckRetry, which
// will retry on connection errors and server errors.
func HostSprayRetryPolicy() func(ctx context.Context, resp *http.Response, err error) (bool, error) {
return func(ctx context.Context, resp *http.Response, err error) (bool, error) {
return checkErrors(ctx, resp, err)
return CheckRecoverableErrors(ctx, resp, err)
}
}

func checkErrors(ctx context.Context, resp *http.Response, err error) (bool, error) {
// Check recoverable errors
func CheckRecoverableErrors(ctx context.Context, resp *http.Response, err error) (bool, error) {
// do not retry on context.Canceled or context.DeadlineExceeded
if ctx.Err() != nil {
return false, ctx.Err()
Expand All @@ -80,7 +70,6 @@ func checkErrors(ctx context.Context, resp *http.Response, err error) (bool, err
return false, nil
}
}

// The error is likely recoverable so retry.
return true, nil
}
Expand Down