diff --git a/CHANGELOG.md b/CHANGELOG.md index e4f9b874ad..6cf38dfed0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -35,6 +35,7 @@ * [ENHANCEMENT] Ruler: Improve GetRules response time by refactoring mutexes and introducing a temporary rules cache in `ruler/manager.go`. #5805 * [ENHANCEMENT] Querier: Add context error check when merging slices from ingesters for GetLabel operations. #5837 * [ENHANCEMENT] Ring: Add experimental `-ingester.tokens-generator-strategy=minimize-spread` flag to enable the new minimize spread token generator strategy. #5855 +* [ENHANCEMENT] Query Frontend: Ensure error response returned by Query Frontend follows Prometheus API error response format. #5811 * [BUGFIX] Distributor: Do not use label with empty values for sharding #5717 * [BUGFIX] Query Frontend: queries with negative offset should check whether it is cacheable or not. #5719 * [BUGFIX] Redis Cache: pass `cache_size` config correctly. #5734 diff --git a/pkg/frontend/transport/handler.go b/pkg/frontend/transport/handler.go index 72d4a1564c..6da1686562 100644 --- a/pkg/frontend/transport/handler.go +++ b/pkg/frontend/transport/handler.go @@ -19,13 +19,13 @@ import ( "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" "github.com/weaveworks/common/httpgrpc" - "github.com/weaveworks/common/httpgrpc/server" "google.golang.org/grpc/status" querier_stats "github.com/cortexproject/cortex/pkg/querier/stats" "github.com/cortexproject/cortex/pkg/querier/tripperware" "github.com/cortexproject/cortex/pkg/tenant" "github.com/cortexproject/cortex/pkg/util" + util_api "github.com/cortexproject/cortex/pkg/util/api" util_log "github.com/cortexproject/cortex/pkg/util/log" ) @@ -239,8 +239,9 @@ func (f *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { writeServiceTimingHeader(queryResponseTime, hs, stats) } + logger := util_log.WithContext(r.Context(), f.log) if err != nil { - writeError(w, err, hs) + writeError(logger, w, err, 
hs) return } @@ -252,7 +253,7 @@ func (f *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // log copy response body error so that we will know even though success response code returned bytesCopied, err := io.Copy(w, resp.Body) if err != nil && !errors.Is(err, syscall.EPIPE) { - level.Error(util_log.WithContext(r.Context(), f.log)).Log("msg", "write response body error", "bytesCopied", bytesCopied, "err", err) + level.Error(logger).Log("msg", "write response body error", "bytesCopied", bytesCopied, "err", err) } } @@ -441,7 +442,7 @@ func formatQueryString(queryString url.Values) (fields []interface{}) { return fields } -func writeError(w http.ResponseWriter, err error, additionalHeaders http.Header) { +func writeError(logger log.Logger, w http.ResponseWriter, err error, additionalHeaders http.Header) { switch err { case context.Canceled: err = errCanceled @@ -453,21 +454,13 @@ func writeError(w http.ResponseWriter, err error, additionalHeaders http.Header) } } - resp, ok := httpgrpc.HTTPResponseFromError(err) - if ok { - for k, values := range additionalHeaders { - resp.Headers = append(resp.Headers, &httpgrpc.Header{Key: k, Values: values}) - } - _ = server.WriteResponse(w, resp) - } else { - headers := w.Header() - for k, values := range additionalHeaders { - for _, value := range values { - headers.Set(k, value) - } + headers := w.Header() + for k, values := range additionalHeaders { + for _, value := range values { + headers.Set(k, value) } - http.Error(w, err.Error(), http.StatusInternalServerError) } + util_api.RespondFromGRPCError(logger, w, err) } func writeServiceTimingHeader(queryResponseTime time.Duration, headers http.Header, stats *querier_stats.QueryStats) { @@ -488,7 +481,7 @@ func statsValue(name string, d time.Duration) string { func getStatusCodeFromError(err error) int { switch err { case context.Canceled: - return StatusClientClosedRequest + return util_api.StatusClientClosedRequest case context.DeadlineExceeded: return 
http.StatusGatewayTimeout default: diff --git a/pkg/frontend/transport/handler_test.go b/pkg/frontend/transport/handler_test.go index 4b84ebe9c4..d3e97d367b 100644 --- a/pkg/frontend/transport/handler_test.go +++ b/pkg/frontend/transport/handler_test.go @@ -3,6 +3,7 @@ package transport import ( "bytes" "context" + "encoding/json" "io" "net/http" "net/http/httptest" @@ -13,14 +14,18 @@ import ( "github.com/go-kit/log" "github.com/pkg/errors" + v1 "github.com/prometheus/client_golang/api/prometheus/v1" "github.com/prometheus/client_golang/prometheus" promtest "github.com/prometheus/client_golang/prometheus/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/weaveworks/common/httpgrpc" "github.com/weaveworks/common/user" + "google.golang.org/grpc/codes" querier_stats "github.com/cortexproject/cortex/pkg/querier/stats" + util_api "github.com/cortexproject/cortex/pkg/util/api" + util_log "github.com/cortexproject/cortex/pkg/util/log" ) type roundTripperFunc func(*http.Request) (*http.Response, error) @@ -34,19 +39,111 @@ func TestWriteError(t *testing.T) { status int err error additionalHeaders http.Header + expectedErrResp util_api.Response }{ - {http.StatusInternalServerError, errors.New("unknown"), http.Header{"User-Agent": []string{"Golang"}}}, - {http.StatusInternalServerError, errors.New("unknown"), nil}, - {http.StatusGatewayTimeout, context.DeadlineExceeded, nil}, - {StatusClientClosedRequest, context.Canceled, nil}, - {StatusClientClosedRequest, context.Canceled, http.Header{"User-Agent": []string{"Golang"}}}, - {StatusClientClosedRequest, context.Canceled, http.Header{"User-Agent": []string{"Golang"}, "Content-Type": []string{"application/json"}}}, - {http.StatusBadRequest, httpgrpc.Errorf(http.StatusBadRequest, ""), http.Header{}}, - {http.StatusRequestEntityTooLarge, errors.New("http: request body too large"), http.Header{}}, + { + http.StatusInternalServerError, + errors.New("unknown"), + 
http.Header{"User-Agent": []string{"Golang"}}, + util_api.Response{ + Status: "error", + ErrorType: v1.ErrServer, + Error: "unknown", + }, + }, + { + http.StatusInternalServerError, + errors.New("unknown"), + nil, + util_api.Response{ + Status: "error", + ErrorType: v1.ErrServer, + Error: "unknown", + }, + }, + { + http.StatusGatewayTimeout, + context.DeadlineExceeded, + nil, + util_api.Response{ + Status: "error", + ErrorType: v1.ErrTimeout, + Error: "", + }, + }, + { + StatusClientClosedRequest, + context.Canceled, + nil, + util_api.Response{ + Status: "error", + ErrorType: v1.ErrCanceled, + Error: "", + }, + }, + { + StatusClientClosedRequest, + context.Canceled, + http.Header{"User-Agent": []string{"Golang"}}, + util_api.Response{ + Status: "error", + ErrorType: v1.ErrCanceled, + Error: "", + }, + }, + { + StatusClientClosedRequest, + context.Canceled, + http.Header{"User-Agent": []string{"Golang"}, "Content-Type": []string{"application/json"}}, + util_api.Response{ + Status: "error", + ErrorType: v1.ErrCanceled, + Error: "", + }, + }, + {http.StatusBadRequest, + httpgrpc.Errorf(http.StatusBadRequest, ""), + http.Header{}, + util_api.Response{ + Status: "error", + ErrorType: v1.ErrBadData, + Error: "", + }, + }, + { + http.StatusRequestEntityTooLarge, + errors.New("http: request body too large"), + http.Header{}, + util_api.Response{ + Status: "error", + ErrorType: v1.ErrBadData, + Error: "http: request body too large", + }, + }, + { + http.StatusUnprocessableEntity, + httpgrpc.Errorf(http.StatusUnprocessableEntity, "limit hit"), + http.Header{}, + util_api.Response{ + Status: "error", + ErrorType: v1.ErrExec, + Error: "limit hit", + }, + }, + { + http.StatusUnprocessableEntity, + httpgrpc.Errorf(int(codes.PermissionDenied), "permission denied"), + http.Header{}, + util_api.Response{ + Status: "error", + ErrorType: v1.ErrBadData, + Error: "permission denied", + }, + }, } { t.Run(test.err.Error(), func(t *testing.T) { w := httptest.NewRecorder() - writeError(w, 
test.err, test.additionalHeaders) + writeError(util_log.Logger, w, test.err, test.additionalHeaders) require.Equal(t, test.status, w.Result().StatusCode) expectedAdditionalHeaders := test.additionalHeaders if expectedAdditionalHeaders != nil { @@ -56,6 +153,18 @@ func TestWriteError(t *testing.T) { } } } + data, err := io.ReadAll(w.Result().Body) + require.NoError(t, err) + var res util_api.Response + err = json.Unmarshal(data, &res) + require.NoError(t, err) + resp, ok := httpgrpc.HTTPResponseFromError(test.err) + if ok { + require.Equal(t, string(resp.Body), res.Error) + } else { + require.Equal(t, test.err.Error(), res.Error) + + } }) } } diff --git a/pkg/querier/tripperware/instantquery/limits.go b/pkg/querier/tripperware/instantquery/limits.go index 64c5f4e443..b92157a500 100644 --- a/pkg/querier/tripperware/instantquery/limits.go +++ b/pkg/querier/tripperware/instantquery/limits.go @@ -48,8 +48,7 @@ func (l limitsMiddleware) Do(ctx context.Context, r tripperware.Request) (trippe if maxQueryLength := validation.SmallestPositiveNonZeroDurationPerTenant(tenantIDs, l.MaxQueryLength); maxQueryLength > 0 { expr, err := parser.ParseExpr(r.GetQuery()) if err != nil { - // Let Querier propagates the parsing error. - return l.next.Do(ctx, r) + return nil, httpgrpc.Errorf(http.StatusBadRequest, err.Error()) } // Enforce query length across all selectors in the query. 
diff --git a/pkg/querier/tripperware/instantquery/limits_test.go b/pkg/querier/tripperware/instantquery/limits_test.go index 1900831388..4cce781a1f 100644 --- a/pkg/querier/tripperware/instantquery/limits_test.go +++ b/pkg/querier/tripperware/instantquery/limits_test.go @@ -2,12 +2,15 @@ package instantquery import ( "context" + "net/http" "testing" "time" + "github.com/prometheus/prometheus/promql/parser" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" + "github.com/weaveworks/common/httpgrpc" "github.com/weaveworks/common/user" "github.com/cortexproject/cortex/pkg/querier/tripperware" @@ -20,6 +23,9 @@ func TestLimitsMiddleware_MaxQueryLength(t *testing.T) { thirtyDays = 30 * 24 * time.Hour ) + wrongQuery := `up[` + _, parserErr := parser.ParseExpr(wrongQuery) + tests := map[string]struct { maxQueryLength time.Duration query string @@ -31,6 +37,7 @@ func TestLimitsMiddleware_MaxQueryLength(t *testing.T) { "even though failed to parse expression, should return no error since request will pass to next middleware": { query: `up[`, maxQueryLength: thirtyDays, + expectedErr: httpgrpc.Errorf(http.StatusBadRequest, parserErr.Error()).Error(), }, "should succeed on a query not exceeding time range": { query: `up`, diff --git a/pkg/querier/tripperware/queryrange/limits.go b/pkg/querier/tripperware/queryrange/limits.go index 5931501cd6..49249e8b4d 100644 --- a/pkg/querier/tripperware/queryrange/limits.go +++ b/pkg/querier/tripperware/queryrange/limits.go @@ -84,8 +84,7 @@ func (l limitsMiddleware) Do(ctx context.Context, r tripperware.Request) (trippe expr, err := parser.ParseExpr(r.GetQuery()) if err != nil { - // Let Querier propagates the parsing error. - return l.next.Do(ctx, r) + return nil, httpgrpc.Errorf(http.StatusBadRequest, err.Error()) } // Enforce query length across all selectors in the query. 
diff --git a/pkg/querier/tripperware/queryrange/limits_test.go b/pkg/querier/tripperware/queryrange/limits_test.go index 6d01883186..de1807b97e 100644 --- a/pkg/querier/tripperware/queryrange/limits_test.go +++ b/pkg/querier/tripperware/queryrange/limits_test.go @@ -2,12 +2,15 @@ package queryrange import ( "context" + "net/http" "testing" "time" + "github.com/prometheus/prometheus/promql/parser" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" + "github.com/weaveworks/common/httpgrpc" "github.com/weaveworks/common/user" "github.com/cortexproject/cortex/pkg/querier/tripperware" @@ -115,6 +118,9 @@ func TestLimitsMiddleware_MaxQueryLength(t *testing.T) { now := time.Now() + wrongQuery := `up[` + _, parserErr := parser.ParseExpr(wrongQuery) + tests := map[string]struct { maxQueryLength time.Duration query string @@ -132,6 +138,7 @@ func TestLimitsMiddleware_MaxQueryLength(t *testing.T) { reqStartTime: now.Add(-time.Hour), reqEndTime: now, maxQueryLength: thirtyDays, + expectedErr: httpgrpc.Errorf(http.StatusBadRequest, parserErr.Error()).Error(), }, "should succeed on a query on short time range, ending now": { maxQueryLength: thirtyDays, diff --git a/pkg/querier/tripperware/queryrange/split_by_interval.go b/pkg/querier/tripperware/queryrange/split_by_interval.go index 2717fa415e..7ff38f4ec3 100644 --- a/pkg/querier/tripperware/queryrange/split_by_interval.go +++ b/pkg/querier/tripperware/queryrange/split_by_interval.go @@ -47,12 +47,7 @@ func (s splitByInterval) Do(ctx context.Context, r tripperware.Request) (tripper // to line up the boundaries with step. reqs, err := splitQuery(r, s.interval(r)) if err != nil { - // If the query itself is bad, we don't return error but send the query - // to querier to return the expected error message. This is not very efficient - // but should be okay for now. 
- // TODO(yeya24): query frontend can reuse the Prometheus API handler and return - // expected error message locally without passing it to the querier through network. - return s.next.Do(ctx, r) + return nil, httpgrpc.Errorf(http.StatusBadRequest, err.Error()) } s.splitByCounter.Add(float64(len(reqs))) diff --git a/pkg/querier/tripperware/roundtrip.go b/pkg/querier/tripperware/roundtrip.go index 0c48aac026..fc2501b303 100644 --- a/pkg/querier/tripperware/roundtrip.go +++ b/pkg/querier/tripperware/roundtrip.go @@ -142,7 +142,7 @@ func NewQueryTripperware( tenantIDs, err := tenant.TenantIDs(r.Context()) // This should never happen anyways because we have auth middleware before this. if err != nil { - return nil, err + return nil, httpgrpc.Errorf(http.StatusBadRequest, err.Error()) } now := time.Now() userStr := tenant.JoinTenantIDs(tenantIDs) @@ -161,8 +161,7 @@ func NewQueryTripperware( expr, err := parser.ParseExpr(query) if err != nil { - // If query is invalid, no need to go through tripperwares for further splitting. 
- return next.RoundTrip(r) + return nil, httpgrpc.Errorf(http.StatusBadRequest, err.Error()) } reqStats := stats.FromContext(r.Context()) diff --git a/pkg/querier/tripperware/shard_by.go b/pkg/querier/tripperware/shard_by.go index 39bef61ca9..7336bd21ed 100644 --- a/pkg/querier/tripperware/shard_by.go +++ b/pkg/querier/tripperware/shard_by.go @@ -55,6 +55,7 @@ func (s shardBy) Do(ctx context.Context, r Request) (Response, error) { analysis, err := s.analyzer.Analyze(r.GetQuery()) if err != nil { level.Warn(logger).Log("msg", "error analyzing query", "q", r.GetQuery(), "err", err) + return nil, httpgrpc.Errorf(http.StatusBadRequest, err.Error()) } stats.AddExtraFields( @@ -63,7 +64,7 @@ func (s shardBy) Do(ctx context.Context, r Request) (Response, error) { "shard_by.sharding_labels", analysis.ShardingLabels(), ) - if err != nil || !analysis.IsShardable() { + if !analysis.IsShardable() { return s.next.Do(ctx, r) } diff --git a/pkg/querier/tripperware/test_shard_by_query_utils.go b/pkg/querier/tripperware/test_shard_by_query_utils.go index 5cbad93ca8..e8c9370992 100644 --- a/pkg/querier/tripperware/test_shard_by_query_utils.go +++ b/pkg/querier/tripperware/test_shard_by_query_utils.go @@ -31,6 +31,7 @@ func TestQueryShardQuery(t *testing.T, instantQueryCodec Codec, shardedPrometheu name string expression string shardingLabels []string + expectedErr bool } nonShardable := []queries{ @@ -66,8 +67,9 @@ func TestQueryShardQuery(t *testing.T, instantQueryCodec Codec, shardedPrometheu http_requests_total`, }, { - name: "problematic query", - expression: `sum(a by(lanel)`, + name: "problematic query", + expression: `sum(a by(lanel)`, + expectedErr: true, }, { name: "aggregate by expression with label_replace, sharding label is dynamic", @@ -289,6 +291,7 @@ http_requests_total`, responses []string response string shardingLabels []string + expectedErr bool } tests := []testCase{ { @@ -339,7 +342,8 @@ http_requests_total`, responses: []string{ 
`{"status":"success","data":{"resultType":"vector","result":[{"metric":{"__name__":"up","job":"foo"},"value":[1,"1"]}],"stats":{"samples":{"totalQueryableSamples":10,"totalQueryableSamplesPerStep":[[1,10]]}}}}`, }, - response: `{"status":"success","data":{"resultType":"vector","result":[{"metric":{"__name__":"up","job":"foo"},"value":[1,"1"]}],"stats":{"samples":{"totalQueryableSamples":10,"totalQueryableSamplesPerStep":[[1,10]]}}}}`, + response: `{"status":"success","data":{"resultType":"vector","result":[{"metric":{"__name__":"up","job":"foo"},"value":[1,"1"]}],"stats":{"samples":{"totalQueryableSamples":10,"totalQueryableSamplesPerStep":[[1,10]]}}}}`, + expectedErr: query.expectedErr, }) tests = append(tests, testCase{ name: fmt.Sprintf("non shardable query_range: %s", query.name), @@ -350,7 +354,8 @@ http_requests_total`, responses: []string{ `{"status":"success","data":{"resultType":"matrix","result":[{"metric":{"__job__":"a","__name__":"metric"},"values":[[1,"1"],[2,"2"],[3,"3"]]}],"stats":{"samples":{"totalQueryableSamples":6,"totalQueryableSamplesPerStep":[[1,1],[2,2],[3,3]]}}}}`, }, - response: `{"status":"success","data":{"resultType":"matrix","result":[{"metric":{"__job__":"a","__name__":"metric"},"values":[[1,"1"],[2,"2"],[3,"3"]]}],"stats":{"samples":{"totalQueryableSamples":6,"totalQueryableSamplesPerStep":[[1,1],[2,2],[3,3]]}}}}`, + response: `{"status":"success","data":{"resultType":"matrix","result":[{"metric":{"__job__":"a","__name__":"metric"},"values":[[1,"1"],[2,"2"],[3,"3"]]}],"stats":{"samples":{"totalQueryableSamples":6,"totalQueryableSamplesPerStep":[[1,1],[2,2],[3,3]]}}}}`, + expectedErr: query.expectedErr, }) } @@ -366,7 +371,8 @@ http_requests_total`, `{"status":"success","data":{"resultType":"vector","result":[{"metric":{"__name__":"up","job":"foo"},"value":[1,"1"]}],"stats":{"samples":{"totalQueryableSamples":10,"totalQueryableSamplesPerStep":[[1,10]]}}}}`, 
`{"status":"success","data":{"resultType":"vector","result":[{"metric":{"__name__":"up","job":"bar"},"value":[2,"2"]}],"stats":{"samples":{"totalQueryableSamples":10,"totalQueryableSamplesPerStep":[[1,10]]}}}}`, }, - response: `{"status":"success","data":{"resultType":"vector","result":[{"metric":{"__name__":"up","job":"bar"},"value":[2,"2"]},{"metric":{"__name__":"up","job":"foo"},"value":[1,"1"]}],"stats":{"samples":{"totalQueryableSamples":20,"totalQueryableSamplesPerStep":[[1,20]]}}}}`, + response: `{"status":"success","data":{"resultType":"vector","result":[{"metric":{"__name__":"up","job":"bar"},"value":[2,"2"]},{"metric":{"__name__":"up","job":"foo"},"value":[1,"1"]}],"stats":{"samples":{"totalQueryableSamples":20,"totalQueryableSamplesPerStep":[[1,20]]}}}}`, + expectedErr: query.expectedErr, }) tests = append(tests, testCase{ name: fmt.Sprintf("shardable query_range: %s", query.name), @@ -379,7 +385,8 @@ http_requests_total`, `{"status":"success","data":{"resultType":"matrix","result":[{"metric":{"__name__":"metric","__job__":"a"},"values":[[1,"1"],[2,"2"],[3,"3"]]}],"stats":{"samples":{"totalQueryableSamples":6,"totalQueryableSamplesPerStep":[[1,1],[2,2],[3,3]]}}}}`, `{"status":"success","data":{"resultType":"matrix","result":[{"metric":{"__name__":"metric","__job__":"b"},"values":[[1,"1"],[2,"2"],[3,"3"]]}],"stats":{"samples":{"totalQueryableSamples":6,"totalQueryableSamplesPerStep":[[1,1],[2,2],[3,3]]}}}}`, }, - response: `{"status":"success","data":{"resultType":"matrix","result":[{"metric":{"__job__":"a","__name__":"metric"},"values":[[1,"1"],[2,"2"],[3,"3"]]},{"metric":{"__job__":"b","__name__":"metric"},"values":[[1,"1"],[2,"2"],[3,"3"]]}],"stats":{"samples":{"totalQueryableSamples":12,"totalQueryableSamplesPerStep":[[1,2],[2,4],[3,6]]}}}}`, + response: 
`{"status":"success","data":{"resultType":"matrix","result":[{"metric":{"__job__":"a","__name__":"metric"},"values":[[1,"1"],[2,"2"],[3,"3"]]},{"metric":{"__job__":"b","__name__":"metric"},"values":[[1,"1"],[2,"2"],[3,"3"]]}],"stats":{"samples":{"totalQueryableSamples":12,"totalQueryableSamplesPerStep":[[1,2],[2,4],[3,6]]}}}}`, + expectedErr: query.expectedErr, }) } @@ -448,16 +455,19 @@ http_requests_total`, req, err := http.NewRequest("GET", tt.path, http.NoBody) req = req.WithContext(ctx) - - require.NoError(t, err) - resp, err := roundtripper.RoundTrip(req) - require.NoError(t, err) - require.NotNil(t, resp) - contents, err := io.ReadAll(resp.Body) - require.NoError(t, err) - require.Equal(t, tt.response, string(contents)) + resp, err := roundtripper.RoundTrip(req) + if tt.expectedErr { + require.Error(t, err) + } else { + require.NoError(t, err) + require.NotNil(t, resp) + + contents, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, tt.response, string(contents)) + } }) } } diff --git a/pkg/ruler/api.go b/pkg/ruler/api.go index bcb3df70dd..2e64d290e3 100644 --- a/pkg/ruler/api.go +++ b/pkg/ruler/api.go @@ -25,6 +25,7 @@ import ( "github.com/cortexproject/cortex/pkg/ruler/rulespb" "github.com/cortexproject/cortex/pkg/ruler/rulestore" "github.com/cortexproject/cortex/pkg/tenant" + util_api "github.com/cortexproject/cortex/pkg/util/api" util_log "github.com/cortexproject/cortex/pkg/util/log" ) @@ -32,13 +33,6 @@ import ( // This is required because the prometheus api implementation does not allow us to return errors // on rule lookups, which might fail in Cortex's case. -type response struct { - Status string `json:"status"` - Data interface{} `json:"data"` - ErrorType v1.ErrorType `json:"errorType"` - Error string `json:"error"` -} - // AlertDiscovery has info for all active alerts. 
type AlertDiscovery struct { Alerts []*Alert `json:"alerts"` @@ -103,46 +97,6 @@ type recordingRule struct { EvaluationTime float64 `json:"evaluationTime"` } -func respondError(logger log.Logger, w http.ResponseWriter, msg string) { - b, err := json.Marshal(&response{ - Status: "error", - ErrorType: v1.ErrServer, - Error: msg, - Data: nil, - }) - - if err != nil { - level.Error(logger).Log("msg", "error marshaling json response", "err", err) - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - - w.WriteHeader(http.StatusInternalServerError) - if n, err := w.Write(b); err != nil { - level.Error(logger).Log("msg", "error writing response", "bytesWritten", n, "err", err) - } -} - -func respondBadRequest(logger log.Logger, w http.ResponseWriter, msg string) { - b, err := json.Marshal(&response{ - Status: "error", - ErrorType: v1.ErrBadData, - Error: msg, - Data: nil, - }) - - if err != nil { - level.Error(logger).Log("msg", "error marshaling json response", "err", err) - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - - w.WriteHeader(http.StatusBadRequest) - if n, err := w.Write(b); err != nil { - level.Error(logger).Log("msg", "error writing response", "bytesWritten", n, "err", err) - } -} - // API is used to handle HTTP requests for the ruler service type API struct { ruler *Ruler @@ -165,19 +119,19 @@ func (a *API) PrometheusRules(w http.ResponseWriter, req *http.Request) { userID, err := tenant.TenantID(req.Context()) if err != nil || userID == "" { level.Error(logger).Log("msg", "error extracting org id from context", "err", err) - respondError(logger, w, "no valid org id found") + util_api.RespondError(logger, w, v1.ErrBadData, "no valid org id found", http.StatusBadRequest) return } if err := req.ParseForm(); err != nil { level.Error(logger).Log("msg", "error parsing form/query params", "err", err) - respondBadRequest(logger, w, "error parsing form/query params") + util_api.RespondError(logger, w, v1.ErrBadData, 
"error parsing form/query params", http.StatusBadRequest) return } typ := strings.ToLower(req.URL.Query().Get("type")) if typ != "" && typ != alertingRuleFilter && typ != recordingRuleFilter { - respondBadRequest(logger, w, fmt.Sprintf("unsupported rule type %q", typ)) + util_api.RespondError(logger, w, v1.ErrBadData, fmt.Sprintf("unsupported rule type %q", typ), http.StatusBadRequest) return } @@ -192,7 +146,7 @@ func (a *API) PrometheusRules(w http.ResponseWriter, req *http.Request) { rgs, err := a.ruler.GetRules(req.Context(), rulesRequest) if err != nil { - respondError(logger, w, err.Error()) + util_api.RespondError(logger, w, v1.ErrServer, err.Error(), http.StatusInternalServerError) return } @@ -261,13 +215,13 @@ func (a *API) PrometheusRules(w http.ResponseWriter, req *http.Request) { return groups[i].File < groups[j].File }) - b, err := json.Marshal(&response{ + b, err := json.Marshal(&util_api.Response{ Status: "success", Data: &RuleDiscovery{RuleGroups: groups}, }) if err != nil { level.Error(logger).Log("msg", "error marshaling json response", "err", err) - respondError(logger, w, "unable to marshal the requested data") + util_api.RespondError(logger, w, v1.ErrServer, "unable to marshal the requested data", http.StatusInternalServerError) return } w.Header().Set("Content-Type", "application/json") @@ -282,7 +236,7 @@ func (a *API) PrometheusAlerts(w http.ResponseWriter, req *http.Request) { userID, err := tenant.TenantID(req.Context()) if err != nil || userID == "" { level.Error(logger).Log("msg", "error extracting org id from context", "err", err) - respondError(logger, w, "no valid org id found") + util_api.RespondError(logger, w, v1.ErrBadData, "no valid org id found", http.StatusBadRequest) return } @@ -293,7 +247,7 @@ func (a *API) PrometheusAlerts(w http.ResponseWriter, req *http.Request) { rgs, err := a.ruler.GetRules(req.Context(), rulesRequest) if err != nil { - respondError(logger, w, err.Error()) + util_api.RespondError(logger, w, 
v1.ErrServer, err.Error(), http.StatusInternalServerError) return } @@ -319,13 +273,13 @@ func (a *API) PrometheusAlerts(w http.ResponseWriter, req *http.Request) { } } - b, err := json.Marshal(&response{ + b, err := json.Marshal(&util_api.Response{ Status: "success", Data: &AlertDiscovery{Alerts: alerts}, }) if err != nil { level.Error(logger).Log("msg", "error marshaling json response", "err", err) - respondError(logger, w, "unable to marshal the requested data") + util_api.RespondError(logger, w, v1.ErrServer, "unable to marshal the requested data", http.StatusInternalServerError) return } w.Header().Set("Content-Type", "application/json") @@ -362,12 +316,12 @@ func marshalAndSend(output interface{}, w http.ResponseWriter, logger log.Logger } func respondAccepted(w http.ResponseWriter, logger log.Logger) { - b, err := json.Marshal(&response{ + b, err := json.Marshal(&util_api.Response{ Status: "success", }) if err != nil { level.Error(logger).Log("msg", "error marshaling json response", "err", err) - respondError(logger, w, "unable to marshal the requested data") + util_api.RespondError(logger, w, v1.ErrServer, "unable to marshal the requested data", http.StatusInternalServerError) return } w.Header().Set("Content-Type", "application/json") @@ -444,7 +398,7 @@ func (a *API) ListRules(w http.ResponseWriter, req *http.Request) { userID, namespace, _, err := parseRequest(req, false, false) if err != nil { - respondError(logger, w, err.Error()) + util_api.RespondError(logger, w, v1.ErrBadData, err.Error(), http.StatusBadRequest) return } @@ -477,7 +431,7 @@ func (a *API) GetRuleGroup(w http.ResponseWriter, req *http.Request) { logger := util_log.WithContext(req.Context(), a.logger) userID, namespace, groupName, err := parseRequest(req, true, true) if err != nil { - respondError(logger, w, err.Error()) + util_api.RespondError(logger, w, v1.ErrBadData, err.Error(), http.StatusBadRequest) return } @@ -499,7 +453,7 @@ func (a *API) CreateRuleGroup(w http.ResponseWriter, 
req *http.Request) { logger := util_log.WithContext(req.Context(), a.logger) userID, namespace, _, err := parseRequest(req, true, false) if err != nil { - respondError(logger, w, err.Error()) + util_api.RespondError(logger, w, v1.ErrBadData, err.Error(), http.StatusBadRequest) return } @@ -581,7 +535,7 @@ func (a *API) DeleteNamespace(w http.ResponseWriter, req *http.Request) { userID, namespace, _, err := parseRequest(req, true, false) if err != nil { - respondError(logger, w, err.Error()) + util_api.RespondError(logger, w, v1.ErrBadData, err.Error(), http.StatusBadRequest) return } @@ -591,7 +545,7 @@ func (a *API) DeleteNamespace(w http.ResponseWriter, req *http.Request) { http.Error(w, err.Error(), http.StatusNotFound) return } - respondError(logger, w, err.Error()) + util_api.RespondError(logger, w, v1.ErrServer, err.Error(), http.StatusInternalServerError) return } @@ -603,7 +557,7 @@ func (a *API) DeleteRuleGroup(w http.ResponseWriter, req *http.Request) { userID, namespace, groupName, err := parseRequest(req, true, true) if err != nil { - respondError(logger, w, err.Error()) + util_api.RespondError(logger, w, v1.ErrBadData, err.Error(), http.StatusBadRequest) return } @@ -613,7 +567,7 @@ func (a *API) DeleteRuleGroup(w http.ResponseWriter, req *http.Request) { http.Error(w, err.Error(), http.StatusNotFound) return } - respondError(logger, w, err.Error()) + util_api.RespondError(logger, w, v1.ErrServer, err.Error(), http.StatusInternalServerError) return } diff --git a/pkg/ruler/api_test.go b/pkg/ruler/api_test.go index 0edc67450c..6cbe5b594a 100644 --- a/pkg/ruler/api_test.go +++ b/pkg/ruler/api_test.go @@ -16,6 +16,7 @@ import ( "github.com/weaveworks/common/user" "github.com/cortexproject/cortex/pkg/ruler/rulespb" + util_api "github.com/cortexproject/cortex/pkg/util/api" "github.com/cortexproject/cortex/pkg/util/services" ) @@ -36,14 +37,14 @@ func TestRuler_rules(t *testing.T) { body, _ := io.ReadAll(resp.Body) // Check status code and status response - 
responseJSON := response{} + responseJSON := util_api.Response{} err := json.Unmarshal(body, &responseJSON) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) require.Equal(t, responseJSON.Status, "success") // Testing the running rules for user1 in the mock store - expectedResponse, _ := json.Marshal(response{ + expectedResponse, _ := json.Marshal(util_api.Response{ Status: "success", Data: &RuleDiscovery{ RuleGroups: []*RuleGroup{ @@ -92,14 +93,14 @@ func TestRuler_rules_special_characters(t *testing.T) { body, _ := io.ReadAll(resp.Body) // Check status code and status response - responseJSON := response{} + responseJSON := util_api.Response{} err := json.Unmarshal(body, &responseJSON) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) require.Equal(t, responseJSON.Status, "success") // Testing the running rules for user1 in the mock store - expectedResponse, _ := json.Marshal(response{ + expectedResponse, _ := json.Marshal(util_api.Response{ Status: "success", Data: &RuleDiscovery{ RuleGroups: []*RuleGroup{ @@ -146,14 +147,14 @@ func TestRuler_rules_limit(t *testing.T) { resp := w.Result() body, _ := io.ReadAll(resp.Body) // Check status code and status response - responseJSON := response{} + responseJSON := util_api.Response{} err := json.Unmarshal(body, &responseJSON) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) require.Equal(t, responseJSON.Status, "success") // Testing the running rules for user1 in the mock store - expectedResponse, _ := json.Marshal(response{ + expectedResponse, _ := json.Marshal(util_api.Response{ Status: "success", Data: &RuleDiscovery{ RuleGroups: []*RuleGroup{ @@ -201,7 +202,7 @@ func TestRuler_alerts(t *testing.T) { body, _ := io.ReadAll(resp.Body) // Check status code and status response - responseJSON := response{} + responseJSON := util_api.Response{} err := json.Unmarshal(body, &responseJSON) require.NoError(t, err) require.Equal(t, http.StatusOK, 
resp.StatusCode) @@ -209,7 +210,7 @@ func TestRuler_alerts(t *testing.T) { // Currently there is not an easy way to mock firing alerts. The empty // response case is tested instead. - expectedResponse, _ := json.Marshal(response{ + expectedResponse, _ := json.Marshal(util_api.Response{ Status: "success", Data: &AlertDiscovery{ Alerts: []*Alert{}, diff --git a/pkg/ruler/ruler.go b/pkg/ruler/ruler.go index ba3c698e36..66531cdc57 100644 --- a/pkg/ruler/ruler.go +++ b/pkg/ruler/ruler.go @@ -15,6 +15,7 @@ import ( "github.com/go-kit/log" "github.com/go-kit/log/level" "github.com/pkg/errors" + v1 "github.com/prometheus/client_golang/api/prometheus/v1" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" "github.com/prometheus/prometheus/model/labels" @@ -32,6 +33,7 @@ import ( "github.com/cortexproject/cortex/pkg/ruler/rulestore" "github.com/cortexproject/cortex/pkg/tenant" "github.com/cortexproject/cortex/pkg/util" + util_api "github.com/cortexproject/cortex/pkg/util/api" "github.com/cortexproject/cortex/pkg/util/concurrency" "github.com/cortexproject/cortex/pkg/util/flagext" "github.com/cortexproject/cortex/pkg/util/grpcclient" @@ -1002,7 +1004,7 @@ func (r *Ruler) DeleteTenantConfiguration(w http.ResponseWriter, req *http.Reque err = r.store.DeleteNamespace(req.Context(), userID, "") // Empty namespace = delete all rule groups. 
if err != nil && !errors.Is(err, rulestore.ErrGroupNamespaceNotFound) { - respondError(logger, w, err.Error()) + util_api.RespondError(logger, w, v1.ErrServer, err.Error(), http.StatusInternalServerError) return }
diff --git a/pkg/util/api/response.go b/pkg/util/api/response.go
new file mode 100644
index 0000000000..16e016bd5b
--- /dev/null
+++ b/pkg/util/api/response.go
@@ -0,0 +1,103 @@
+package api
+
+import (
+	"encoding/json"
+	"net/http"
+
+	"github.com/go-kit/log"
+	"github.com/go-kit/log/level"
+	v1 "github.com/prometheus/client_golang/api/prometheus/v1"
+	"github.com/weaveworks/common/httpgrpc"
+	"google.golang.org/grpc/codes"
+)
+
+const (
+	// StatusClientClosedRequest is the status code used when a client cancels an HTTP request.
+	StatusClientClosedRequest = 499
+)
+
+// Response defines the Prometheus response format.
+type Response struct {
+	Status    string       `json:"status"`
+	Data      interface{}  `json:"data"`
+	ErrorType v1.ErrorType `json:"errorType"`
+	Error     string       `json:"error"`
+	Warnings  []string     `json:"warnings,omitempty"`
+}
+
+// RespondFromGRPCError writes err to w in Prometheus response format.
+// If err carries an httpgrpc response, the Prometheus error type is derived
+// from that response's code and its body is used as the error message.
+// Otherwise the error is reported as server_error with HTTP status 500.
+func RespondFromGRPCError(logger log.Logger, w http.ResponseWriter, err error) {
+	resp, ok := httpgrpc.HTTPResponseFromError(err)
+	if ok {
+		code := int(resp.Code)
+		var errTyp v1.ErrorType
+		switch resp.Code {
+		case http.StatusBadRequest, http.StatusRequestEntityTooLarge:
+			errTyp = v1.ErrBadData
+		case StatusClientClosedRequest:
+			errTyp = v1.ErrCanceled
+		case http.StatusGatewayTimeout:
+			errTyp = v1.ErrTimeout
+		case http.StatusUnprocessableEntity:
+			errTyp = v1.ErrExec
+		case int32(codes.PermissionDenied):
+			// Convert gRPC status code to HTTP status code.
+			code = http.StatusUnprocessableEntity
+			errTyp = v1.ErrBadData
+		default:
+			errTyp = v1.ErrServer
+		}
+		RespondError(logger, w, errTyp, string(resp.Body), code)
+	} else {
+		RespondError(logger, w, v1.ErrServer, err.Error(), http.StatusInternalServerError)
+	}
+}
+
+// RespondError writes an error to w in Prometheus response format using the
+// provided error type, message and HTTP status code.
+//
+// If msg is itself a JSON-encoded Prometheus error response (it decodes into
+// Response with status "error"), it is written through unchanged. Any other
+// msg, including JSON that merely happens to decode (e.g. "{}" or "null"), is
+// wrapped in a new error Response.
+//
+// A statusCode outside the range accepted by net/http (100-999) is replaced
+// with 500 so that a non-HTTP code cannot panic the handler.
+func RespondError(logger log.Logger, w http.ResponseWriter, errorType v1.ErrorType, msg string, statusCode int) {
+	var (
+		res Response
+		b   []byte
+		err error
+	)
+	b = []byte(msg)
+	// Try to deserialize response and see if it is already in Prometheus error format.
+	// Decoding alone is not enough: "{}" and "null" also decode without error.
+	if err = json.Unmarshal(b, &res); err != nil || res.Status != "error" {
+		b, err = json.Marshal(&Response{
+			Status:    "error",
+			ErrorType: errorType,
+			Error:     msg,
+			Data:      nil,
+		})
+	}
+
+	if err != nil {
+		level.Error(logger).Log("msg", "error marshaling json response", "err", err)
+		http.Error(w, err.Error(), http.StatusInternalServerError)
+		return
+	}
+
+	// net/http panics in WriteHeader on codes outside 100-999, e.g. a raw gRPC code.
+	if statusCode < 100 || statusCode > 999 {
+		statusCode = http.StatusInternalServerError
+	}
+
+	w.Header().Set("Content-Type", "application/json")
+	w.WriteHeader(statusCode)
+	if n, err := w.Write(b); err != nil {
+		level.Error(logger).Log("msg", "error writing response", "bytesWritten", n, "err", err)
+	}
+}
diff --git a/pkg/util/api/response_test.go b/pkg/util/api/response_test.go
new file mode 100644
index 0000000000..36250a1e12
--- /dev/null
+++ b/pkg/util/api/response_test.go
@@ -0,0 +1,232 @@
+package api
+
+import (
+	"encoding/json"
+	"errors"
+	"io"
+	"net/http"
+	"net/http/httptest"
+	"testing"
+
+	"github.com/go-kit/log"
+	v1 "github.com/prometheus/client_golang/api/prometheus/v1"
+	"github.com/stretchr/testify/require"
+	"github.com/weaveworks/common/httpgrpc"
+	"google.golang.org/grpc/codes"
+)
+
+func TestRespondFromGRPCError(t *testing.T) {
+	logger := log.NewNopLogger()
+	for _, tc := range []struct {
+		name         string
+		err          error
+		expectedResp *Response
+		code         int
+	}{
+		{
+			name: "non grpc error",
+			err:  errors.New("test"),
+			expectedResp: &Response{
+				Status:    "error",
+				ErrorType: v1.ErrServer,
+				Error:     "test",
+			},
+			code: 500,
+		},
+		{
+			name: "bad data",
+			err:  httpgrpc.Errorf(http.StatusBadRequest, "bad_data"),
+			expectedResp: &Response{
+				Status:    "error",
+				ErrorType: v1.ErrBadData,
+				Error:     "bad_data",
+			},
+			code: http.StatusBadRequest,
+		},
+		{
+			name: "413",
+			err:  httpgrpc.Errorf(http.StatusRequestEntityTooLarge, "bad_data"),
+			expectedResp: &Response{
+				Status:    "error",
+				ErrorType: v1.ErrBadData,
+				Error:     "bad_data",
+			},
+			code: http.StatusRequestEntityTooLarge,
+		},
+		{
+			name: "499",
+			err:  httpgrpc.Errorf(StatusClientClosedRequest, "bad_data"),
+			expectedResp: &Response{
+				Status:    "error",
+				ErrorType: v1.ErrCanceled,
+				Error:     "bad_data",
+			},
+			code: StatusClientClosedRequest,
+		},
+		{
+			name: "504",
+			err:  httpgrpc.Errorf(http.StatusGatewayTimeout, "bad_data"),
+			expectedResp: &Response{
+				Status:    "error",
+				ErrorType: v1.ErrTimeout,
+				Error:     "bad_data",
+			},
+			code: http.StatusGatewayTimeout,
+		},
+		{
+			name: "422",
+			err:  httpgrpc.Errorf(http.StatusUnprocessableEntity, "bad_data"),
+			expectedResp: &Response{
+				Status:    "error",
+				ErrorType: v1.ErrExec,
+				Error:     "bad_data",
+			},
+			code: http.StatusUnprocessableEntity,
+		},
+		{
+			name: "grpc status code",
+			err:  httpgrpc.Errorf(int(codes.PermissionDenied), "bad_data"),
+			expectedResp: &Response{
+				Status:    "error",
+				ErrorType: v1.ErrBadData,
+				Error:     "bad_data",
+			},
+			code: http.StatusUnprocessableEntity,
+		},
+		{
+			name: "other status code defaults to err server",
+			err:  httpgrpc.Errorf(http.StatusTooManyRequests, "bad_data"),
+			expectedResp: &Response{
+				Status:    "error",
+				ErrorType: v1.ErrServer,
+				Error:     "bad_data",
+			},
+			code: http.StatusTooManyRequests,
+		},
+		{
+			// A gRPC code that is not a valid HTTP status must not reach WriteHeader.
+			name: "non http status code falls back to 500",
+			err:  httpgrpc.Errorf(int(codes.Internal), "bad_data"),
+			expectedResp: &Response{
+				Status:    "error",
+				ErrorType: v1.ErrServer,
+				Error:     "bad_data",
+			},
+			code: http.StatusInternalServerError,
+		},
+	} {
+		t.Run(tc.name, func(t *testing.T) {
+			writer := httptest.NewRecorder()
+			RespondFromGRPCError(logger, writer, tc.err)
+			output, err := io.ReadAll(writer.Body)
+			require.NoError(t, err)
+			var res Response
+			err = json.Unmarshal(output, &res)
+			require.NoError(t, err)
+
+			require.Equal(t, tc.expectedResp.Status, res.Status)
+			require.Equal(t, tc.expectedResp.Error, res.Error)
+			require.Equal(t, tc.expectedResp.ErrorType, res.ErrorType)
+
+			require.Equal(t, tc.code, writer.Code)
+		})
+	}
+}
+
+func TestRespondError(t *testing.T) {
+	logger := log.NewNopLogger()
+	for _, tc := range []struct {
+		name         string
+		errorType    v1.ErrorType
+		msg          string
+		status       string
+		code         int
+		expectedResp *Response
+	}{
+		{
+			name:      "bad data",
+			errorType: v1.ErrBadData,
+			msg:       "test_msg",
+			status:    "error",
+			code:      400,
+		},
+		{
+			name:      "server error",
+			errorType: v1.ErrServer,
+			msg:       "test_msg",
+			status:    "error",
+			code:      500,
+		},
+		{
+			name:      "canceled",
+			errorType: v1.ErrCanceled,
+			msg:       "test_msg",
+			status:    "error",
+			code:      499,
+		},
+		{
+			name:      "timeout",
+			errorType: v1.ErrTimeout,
+			msg:       "test_msg",
+			status:    "error",
+			code:      502,
+		},
+		{
+			name: "prometheus_format_error",
+			expectedResp: &Response{
+				Status:    "error",
+				ErrorType: v1.ErrServer,
+				Error:     "server_error",
+			},
+			code:      400,
+			status:    "error",
+			errorType: v1.ErrBadData,
+		},
+		{
+			// If the input Prometheus error cannot be unmarshalled,
+			// use the error type and message provided in the function.
+			name:      "bad_prometheus_format_error",
+			msg:       `"status":"error","data":null,"errorType":"bad_data","error":"bad_data"}`,
+			code:      500,
+			status:    "error",
+			errorType: v1.ErrServer,
+		},
+		{
+			// Valid JSON that is not a Prometheus error response is wrapped
+			// as a plain message instead of being written through.
+			name:      "json_but_not_prometheus_format_error",
+			msg:       `{"foo":"bar"}`,
+			code:      500,
+			status:    "error",
+			errorType: v1.ErrServer,
+		},
+	} {
+		t.Run(tc.name, func(t *testing.T) {
+			msg := tc.msg
+			if tc.expectedResp != nil {
+				output, err := json.Marshal(tc.expectedResp)
+				require.NoError(t, err)
+				msg = string(output)
+			}
+			writer := httptest.NewRecorder()
+			RespondError(logger, writer, tc.errorType, msg, tc.code)
+			output, err := io.ReadAll(writer.Body)
+			require.NoError(t, err)
+			var res Response
+			err = json.Unmarshal(output, &res)
+			require.NoError(t, err)
+
+			if tc.expectedResp == nil {
+				require.Equal(t, tc.status, res.Status)
+				require.Equal(t, tc.msg, res.Error)
+				require.Equal(t, tc.errorType, res.ErrorType)
+			} else {
+				require.Equal(t, tc.expectedResp.Status, res.Status)
+				require.Equal(t, tc.expectedResp.Error, res.Error)
+				require.Equal(t, tc.expectedResp.ErrorType, res.ErrorType)
+			}
+
+			require.Equal(t, tc.code, writer.Code)
+		})
+	}
+}