Skip to content

Commit 6acfd04

Browse files
Ensure receiver and global matches cost estimates agree
1 parent f0ffa7e commit 6acfd04

4 files changed

Lines changed: 31 additions & 5 deletions

File tree

checker/cost.go

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -791,18 +791,26 @@ func (c *coster) functionCost(e ast.Expr, function, overloadID string, target *A
791791
return CallEstimate{CostEstimate: c.sizeOrUnknown(args[1]).MultiplyByCostFactor(1).Add(argCostSum())}
792792
}
793793
// O(nm) functions
794-
case overloads.MatchesString:
794+
case overloads.Matches, overloads.MatchesString:
795795
// https://swtch.com/~rsc/regexp/regexp1.html applies to RE2 implementation supported by CEL
796-
if target != nil && len(args) == 1 {
796+
var strNode, regexNode AstNode
797+
if overloadID == overloads.MatchesString && target != nil && len(args) == 1 {
798+
strNode = *target
799+
regexNode = args[0]
800+
} else if overloadID == overloads.Matches && target == nil && len(args) == 2 {
801+
strNode = args[0]
802+
regexNode = args[1]
803+
}
804+
if strNode != nil && regexNode != nil {
797805
// Add one to string length for purposes of cost calculation to prevent the product of string and regex from being 0
798806
// in case where string is empty but regex is still expensive.
799-
strCost := c.sizeOrUnknown(*target).Add(SizeEstimate{Min: 1, Max: 1}).MultiplyByCostFactor(common.StringTraversalCostFactor)
807+
strCost := c.sizeOrUnknown(strNode).Add(SizeEstimate{Min: 1, Max: 1}).MultiplyByCostFactor(common.StringTraversalCostFactor)
800808
// We don't know how many expressions are in the regex, just the string length (a huge
801809
// improvement here would be to somehow get a count of the number of expressions in the regex or
802810
// how many states are in the regex state machine and use that to measure regex cost).
803811
// For now, we're making a guess that each expression in a regex is typically at least 4 chars
804812
// in length.
805-
regexCost := c.sizeOrUnknown(args[0]).MultiplyByCostFactor(common.RegexStringLengthCostFactor)
813+
regexCost := c.sizeOrUnknown(regexNode).MultiplyByCostFactor(common.RegexStringLengthCostFactor)
806814
return CallEstimate{CostEstimate: strCost.Multiply(regexCost).Add(argCostSum())}
807815
}
808816
case overloads.ContainsString:

checker/cost_test.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,15 @@ func TestCost(t *testing.T) {
314314
hints: map[string]uint64{"input": 500},
315315
wanted: CostEstimate{Min: 3, Max: 103},
316316
},
317+
{
318+
name: "matches global",
319+
expr: `matches(input, '\\d+a\\d+b')`,
320+
vars: []*decls.VariableDecl{
321+
decls.NewVariable("input", types.StringType),
322+
},
323+
hints: map[string]uint64{"input": 500},
324+
wanted: CostEstimate{Min: 3, Max: 103},
325+
},
317326
{
318327
name: "startsWith",
319328
expr: `input.startsWith(arg1)`,

interpreter/runtimecost.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,7 @@ func (c *CostTracker) costCall(call InterpretableCall, args []ref.Val, result re
285285
// In the worst case scenario, we would need to reallocate a new backing store and copy both operands over.
286286
cost += uint64(math.Ceil(float64(actualSize(args[0])+actualSize(args[1])) * common.StringTraversalCostFactor))
287287
// O(nm) functions
288-
case overloads.MatchesString:
288+
case overloads.Matches, overloads.MatchesString:
289289
// https://swtch.com/~rsc/regexp/regexp1.html applies to RE2 implementation supported by CEL
290290
// Add one to string length for purposes of cost calculation to prevent the product of string and regex from being 0
291291
// in case where string is empty but regex is still expensive.

interpreter/runtimecost_test.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -606,6 +606,15 @@ func TestRuntimeCost(t *testing.T) {
606606
want: 103,
607607
in: map[string]any{"input": string(randSeq(500)), "arg1": string(randSeq(500))},
608608
},
609+
{
610+
name: "matches global",
611+
expr: `matches(input, '\\d+a\\d+b')`,
612+
vars: []*decls.VariableDecl{
613+
decls.NewVariable("input", types.StringType),
614+
},
615+
want: 103,
616+
in: map[string]any{"input": string(randSeq(500))},
617+
},
609618
{
610619
name: "startsWith",
611620
expr: `input.startsWith(arg1)`,

0 commit comments

Comments
 (0)