diff --git a/checker/cost.go b/checker/cost.go index b9cd8a2ed..59be751c9 100644 --- a/checker/cost.go +++ b/checker/cost.go @@ -930,6 +930,9 @@ func (c *coster) computeSize(e ast.Expr) *SizeEstimate { if size, ok := c.computedSizes[e.ID()]; ok { return &size } + if size := computeExprSize(e); size != nil { + return size + } // Ensure size estimates are computed first as users may choose to override the costs that // CEL would otherwise ascribe to the type. node := astNode{expr: e, path: c.getPath(e), t: c.getType(e)} @@ -938,9 +941,6 @@ func (c *coster) computeSize(e ast.Expr) *SizeEstimate { c.computedSizes[e.ID()] = *size return size } - if size := computeExprSize(e); size != nil { - return size - } if size := computeTypeSize(c.getType(e)); size != nil { return size } diff --git a/checker/cost_test.go b/checker/cost_test.go index 2bec0e94a..f667ebe0e 100644 --- a/checker/cost_test.go +++ b/checker/cost_test.go @@ -715,6 +715,31 @@ func TestCost(t *testing.T) { expr: `self.val1 == 1.0`, wanted: FixedCostEstimate(3), }, + { + name: "bytes list max", + expr: "[bytes('012345678901'), bytes('012345678901'), bytes('012345678901'), bytes('012345678901'), bytes('012345678901')].max()", + options: []CostOption{ + OverloadCostEstimate("list_bytes_max", + func(estimator CostEstimator, target *AstNode, args []AstNode) *CallEstimate { + if target != nil { + // Charge 1 cost for comparing each element in the list + elCost := CostEstimate{Min: 1, Max: 1} + // If the list contains strings or bytes, add the cost of traversing all the strings/bytes as a way + // of estimating the additional comparison cost. + if elNode := listElementNode(*target); elNode != nil { + k := elNode.Type().Kind() + if k == types.StringKind || k == types.BytesKind { + sz := sizeEstimate(estimator, elNode) + elCost = elCost.Add(sz.MultiplyByCostFactor(common.StringTraversalCostFactor)) + } + return &CallEstimate{CostEstimate: sizeEstimate(estimator, *target).MultiplyByCost(elCost)} + } + } + return nil + }), + }, + wanted: CostEstimate{Min: 25, Max: 35}, + }, } for _, tst := range cases { @@ -745,6 +770,14 @@ func TestCost(t *testing.T) { if err != nil { t.Fatalf("environment creation error: %v", err) } + maxFunc, _ := decls.NewFunction("max", + decls.MemberOverload("list_bytes_max", + []*types.Type{types.NewListType(types.BytesType)}, + types.BytesType)) + err = e.AddFunctions(maxFunc) + if err != nil { + t.Fatalf("environment creation error: %v", err) + } err = e.AddIdents(tc.vars...) if err != nil { t.Fatalf("environment creation error: %s\n", err) @@ -773,6 +806,9 @@ func (tc testCostEstimator) EstimateSize(element AstNode) *SizeEstimate { if l, ok := tc.hints[strings.Join(element.Path(), ".")]; ok { return &SizeEstimate{Min: 0, Max: l} } + if element.Type() == types.BytesType { + return &SizeEstimate{Min: 0, Max: 12} + } return nil } @@ -793,3 +829,32 @@ func estimateSize(estimator CostEstimator, node AstNode) SizeEstimate { } return SizeEstimate{Min: 0, Max: math.MaxUint64} } + +func listElementNode(list AstNode) AstNode { + if params := list.Type().Parameters(); len(params) > 0 { + lt := params[0] + nodePath := list.Path() + if nodePath != nil { + // Provide path if we have it so that a OpenAPIv3 maxLength validation can be looked up, if it exists + // for this node. + path := make([]string, len(nodePath)+1) + copy(path, nodePath) + path[len(nodePath)] = "@items" + return &astNode{path: path, t: lt, expr: nil} + } else { + // Provide just the type if no path is available so that worst case size can be looked up based on type. + return &astNode{t: lt, expr: nil} + } + } + return nil +} + +func sizeEstimate(estimator CostEstimator, t AstNode) SizeEstimate { + if sz := t.ComputedSize(); sz != nil { + return *sz + } + if sz := estimator.EstimateSize(t); sz != nil { + return *sz + } + return SizeEstimate{Min: 0, Max: math.MaxUint64} +}