Skip to content

Commit bd1ec92

Browse files
Fix two-variable comprehension pruning (#1083)
* Fix two-variable comprehension pruning * Ensure only cel.bind() comprehensions are pruned
1 parent 6202a67 commit bd1ec92

5 files changed

Lines changed: 346 additions & 27 deletions

File tree

ext/comprehensions_test.go

Lines changed: 254 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ import (
2020
"testing"
2121

2222
"github.com/google/cel-go/cel"
23+
"github.com/google/cel-go/common/types"
24+
"github.com/google/cel-go/interpreter"
2325
)
2426

2527
func TestTwoVarComprehensions(t *testing.T) {
@@ -359,6 +361,258 @@ func TestTwoVarComprehensionsVersion(t *testing.T) {
359361
}
360362
}
361363

364+
func TestTwoVarComprehensionsUnparse(t *testing.T) {
365+
tests := []struct {
366+
name string
367+
expr string
368+
unparsed string
369+
}{
370+
{
371+
name: "transform map entry",
372+
expr: `[0, 0u].transformMapEntry(i, v, {v: i})`,
373+
unparsed: `[0, 0u].transformMapEntry(i, v, {v: i})`,
374+
},
375+
{
376+
name: "transform map",
377+
expr: `{'a': 'world', 'b': 'hello'}.transformMap(i, v, i == 'a' ? v.upperAscii() : v)`,
378+
unparsed: `{"a": "world", "b": "hello"}.transformMap(i, v, (i == "a") ? v.upperAscii() : v)`,
379+
},
380+
{
381+
name: "transform list",
382+
expr: `[1.0, 2.0, 2.0].transformList(i, v, i / 2.0 == 1.0)`,
383+
unparsed: `[1.0, 2.0, 2.0].transformList(i, v, i / 2.0 == 1.0)`,
384+
},
385+
{
386+
name: "existsOne",
387+
expr: `{'a': 'b', 'c': 'd'}.existsOne(k, v, k == 'b' || v == 'b')`,
388+
unparsed: `{"a": "b", "c": "d"}.existsOne(k, v, k == "b" || v == "b")`,
389+
},
390+
{
391+
name: "exists",
392+
expr: `{'a': 'b', 'c': 'd'}.exists(k, v, k == 'b' || v == 'b')`,
393+
unparsed: `{"a": "b", "c": "d"}.exists(k, v, k == "b" || v == "b")`,
394+
},
395+
{
396+
name: "all",
397+
expr: `[null, null, 'hello', string].all(i, v, i == 0 || type(v) != int)`,
398+
unparsed: `[null, null, "hello", string].all(i, v, i == 0 || type(v) != int)`,
399+
},
400+
}
401+
env := testCompreEnv(t)
402+
for _, tst := range tests {
403+
tc := tst
404+
t.Run(tc.name, func(t *testing.T) {
405+
ast, iss := env.Parse(tc.expr)
406+
if iss.Err() != nil {
407+
t.Fatalf("env.Parse(%q) failed: %v", tc.expr, iss.Err())
408+
}
409+
unparsed, err := cel.AstToString(ast)
410+
if err != nil {
411+
t.Fatalf("cel.AstToString() failed: %v", err)
412+
}
413+
if unparsed != tc.unparsed {
414+
t.Errorf("cel.AstToString() got %q, wanted %q", unparsed, tc.unparsed)
415+
}
416+
})
417+
}
418+
}
419+
420+
func TestTwoVarComprehensionsResidualAST(t *testing.T) {
421+
tests := []struct {
422+
name string
423+
in map[string]any
424+
varOpts []cel.EnvOption
425+
unks []*interpreter.AttributePattern
426+
expr string
427+
residual string
428+
}{
429+
{
430+
name: "transform map entry residual compare",
431+
varOpts: []cel.EnvOption{
432+
cel.Variable("x", cel.ListType(cel.DynType)),
433+
cel.Variable("y", cel.IntType),
434+
},
435+
in: map[string]any{
436+
"x": []any{0, uint(1)},
437+
},
438+
unks: []*interpreter.AttributePattern{cel.AttributePattern("y")},
439+
expr: `x.transformMapEntry(i, v, {v: i}).size() < y`,
440+
residual: `2 < y`,
441+
},
442+
{
443+
name: "transform map entry residual transform",
444+
varOpts: []cel.EnvOption{
445+
cel.Variable("x", cel.ListType(cel.DynType)),
446+
cel.Variable("y", cel.IntType),
447+
},
448+
in: map[string]any{
449+
"x": []any{0, uint(1)},
450+
},
451+
unks: []*interpreter.AttributePattern{cel.AttributePattern("y")},
452+
expr: `x.transformMapEntry(i, v, i < y, {v: i})`,
453+
residual: `[0, 1u].transformMapEntry(i, v, i < y, {v: i})`,
454+
},
455+
{
456+
name: "nested exists unknown inner range",
457+
varOpts: []cel.EnvOption{
458+
cel.Variable("x", cel.ListType(cel.IntType)),
459+
cel.Variable("y", cel.MapType(cel.IntType, cel.DynType)),
460+
},
461+
in: map[string]any{
462+
"x": []any{1, 2, 3},
463+
},
464+
unks: []*interpreter.AttributePattern{cel.AttributePattern("y")},
465+
expr: `x.exists(val, y.exists(key, _, key == val))`,
466+
residual: `[1, 2, 3].exists(val, y.exists(key, _, key == val))`,
467+
},
468+
{
469+
name: "nested exists unknown inner range",
470+
varOpts: []cel.EnvOption{
471+
cel.Variable("x", cel.ListType(cel.IntType)),
472+
cel.Variable("y", cel.MapType(cel.IntType, cel.DynType)),
473+
},
474+
in: map[string]any{
475+
"y": map[int]string{1: "hi", 2: "hello", 3: "howdy"},
476+
},
477+
unks: []*interpreter.AttributePattern{cel.AttributePattern("x")},
478+
expr: `x.exists(val, y.exists(key, _, key == val))`,
479+
residual: `x.exists(val, y.exists(key, _, key == val))`,
480+
},
481+
{
482+
name: "nested exists unknown outer range with extra predicate",
483+
varOpts: []cel.EnvOption{
484+
cel.Variable("x", cel.ListType(cel.IntType)),
485+
cel.Variable("y", cel.MapType(cel.IntType, cel.DynType)),
486+
},
487+
in: map[string]any{
488+
"y": map[int]string{1: "hi", 2: "hello", 3: "howdy"},
489+
},
490+
unks: []*interpreter.AttributePattern{cel.AttributePattern("x")},
491+
expr: `x.exists(val, y.exists(key, _, key == val)) && y.all(key, val, val.startsWith('h'))`,
492+
residual: `x.exists(val, y.exists(key, _, key == val))`,
493+
},
494+
{
495+
name: "nested exists partial unknown outer range",
496+
varOpts: []cel.EnvOption{
497+
cel.Variable("x", cel.ListType(cel.IntType)),
498+
cel.Variable("y", cel.MapType(cel.IntType, cel.DynType)),
499+
},
500+
in: map[string]any{
501+
"x": []int{42, 0, 43},
502+
"y": map[int]string{1: "hi", 2: "hello", 3: "howdy"},
503+
},
504+
unks: []*interpreter.AttributePattern{cel.AttributePattern("x").QualInt(1)},
505+
expr: `x.exists(val, y.exists(key, _, key == val)) || x[0] == 0 || x[1] == 1 || x[2] == 2`,
506+
residual: `x.exists(val, y.exists(key, _, key == val)) || x[1] == 1`,
507+
},
508+
{
509+
name: "nested exists partial unknown outer range with optionals",
510+
varOpts: []cel.EnvOption{
511+
cel.Variable("x", cel.ListType(cel.IntType)),
512+
cel.Variable("y", cel.MapType(cel.IntType, cel.DynType)),
513+
},
514+
in: map[string]any{
515+
"x": []int{42, 0, 43},
516+
"y": map[int]string{1: "hi", 2: "hello", 3: "howdy"},
517+
},
518+
unks: []*interpreter.AttributePattern{cel.AttributePattern("x").QualInt(1)},
519+
expr: `x.exists(val, y.exists(key, _, key == val)) || (x[?0].hasValue() && x[?1].hasValue())`,
520+
residual: `x.exists(val, y.exists(key, _, key == val)) || x[?1].hasValue()`,
521+
},
522+
{
523+
name: "inner value partial unknown two-var",
524+
varOpts: []cel.EnvOption{
525+
cel.Variable("x", cel.ListType(cel.StringType)),
526+
cel.Variable("y", cel.MapType(cel.IntType, cel.StringType)),
527+
},
528+
in: map[string]any{
529+
"x": []string{"howdy", "hello", "hi"},
530+
"y": map[int]string{0: "hi", 1: "hello", 2: "howdy"},
531+
},
532+
unks: []*interpreter.AttributePattern{cel.AttributePattern("y").QualInt(1)},
533+
expr: `x.exists(key, val, y[?key] == optional.of(val))`,
534+
residual: `["howdy", "hello", "hi"].exists(key, val, y[?key] == optional.of(val))`,
535+
},
536+
{
537+
name: "inner value partial unknown one-var",
538+
varOpts: []cel.EnvOption{
539+
cel.Variable("x", cel.ListType(cel.StringType)),
540+
cel.Variable("y", cel.MapType(cel.IntType, cel.StringType)),
541+
},
542+
in: map[string]any{
543+
"x": []string{"howdy"},
544+
"y": map[int]string{0: "hello"},
545+
},
546+
unks: []*interpreter.AttributePattern{cel.AttributePattern("x").QualInt(0)},
547+
expr: `y.exists(key, y[?key] == x[?key])`,
548+
residual: `{0: "hello"}.exists(key, y[?key] == x[?key])`,
549+
},
550+
{
551+
name: "simple bind",
552+
varOpts: []cel.EnvOption{
553+
cel.Variable("y", cel.MapType(cel.IntType, cel.StringType)),
554+
},
555+
in: map[string]any{
556+
"y": map[int]string{0: "hi", 1: "hello", 2: "howdy"},
557+
},
558+
unks: []*interpreter.AttributePattern{cel.AttributePattern("y").QualInt(1)},
559+
expr: `cel.bind(z, y[0], z + y[1])`,
560+
residual: `cel.bind(z, "hi", "hi" + y[1])`,
561+
},
562+
{
563+
name: "bind with comprehension",
564+
varOpts: []cel.EnvOption{
565+
cel.Variable("x", cel.ListType(cel.StringType)),
566+
cel.Variable("y", cel.MapType(cel.IntType, cel.StringType)),
567+
},
568+
in: map[string]any{
569+
"x": []string{"hi", "hello", "howdy"},
570+
"y": map[int]string{0: "hi", 1: "hello", 2: "howdy"},
571+
},
572+
unks: []*interpreter.AttributePattern{cel.AttributePattern("y").QualInt(1)},
573+
expr: `cel.bind(z, y[0], x.all(i, val, val == z || optional.of(val) == y[?i]))`,
574+
residual: `cel.bind(z, "hi", ["hi", "hello", "howdy"].all(i, val, val == z || optional.of(val) == y[?i]))`,
575+
},
576+
}
577+
for _, tst := range tests {
578+
tc := tst
579+
t.Run(tc.name, func(t *testing.T) {
580+
env := testCompreEnv(t, tc.varOpts...)
581+
ast, iss := env.Compile(tc.expr)
582+
if iss.Err() != nil {
583+
t.Fatalf("env.Compile(%q) failed: %v", tc.expr, iss.Err())
584+
}
585+
prg, err := env.Program(ast,
586+
cel.EvalOptions(cel.OptTrackState, cel.OptPartialEval))
587+
if err != nil {
588+
t.Fatalf("env.Program() failed: %v", err)
589+
}
590+
unkVars, err := cel.PartialVars(tc.in, tc.unks...)
591+
if err != nil {
592+
t.Fatalf("PartialVars() failed: %v", err)
593+
}
594+
out, det, err := prg.Eval(unkVars)
595+
if !types.IsUnknown(out) {
596+
t.Fatalf("got %v, expected unknown", out)
597+
}
598+
if err != nil {
599+
t.Fatalf("prg.Eval() failed: %v", err)
600+
}
601+
residual, err := env.ResidualAst(ast, det)
602+
if err != nil {
603+
t.Fatalf("env.ResidualAst() failed: %v", err)
604+
}
605+
expr, err := cel.AstToString(residual)
606+
if err != nil {
607+
t.Fatalf("cel.AstToString() failed: %v", err)
608+
}
609+
if expr != tc.residual {
610+
t.Errorf("got expr: %s, wanted %s", expr, tc.residual)
611+
}
612+
})
613+
}
614+
}
615+
362616
func testCompreEnv(t *testing.T, opts ...cel.EnvOption) *cel.Env {
363617
t.Helper()
364618
baseOpts := []cel.EnvOption{

interpreter/activation.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,11 @@ type PartialActivation interface {
156156
UnknownAttributePatterns() []*AttributePattern
157157
}
158158

159+
// partialActivationConverter indicates whether an Activation implementation supports conversion to a PartialActivation
160+
type partialActivationConverter interface {
161+
asPartialActivation() (PartialActivation, bool)
162+
}
163+
159164
// partActivation is the default implementations of the PartialActivation interface.
160165
type partActivation struct {
161166
Activation
@@ -166,3 +171,20 @@ type partActivation struct {
166171
func (a *partActivation) UnknownAttributePatterns() []*AttributePattern {
167172
return a.unknowns
168173
}
174+
175+
// asPartialActivation returns the partActivation as a PartialActivation interface.
176+
func (a *partActivation) asPartialActivation() (PartialActivation, bool) {
177+
return a, true
178+
}
179+
180+
func asPartialActivation(vars Activation) (PartialActivation, bool) {
181+
// Only internal activation instances may implement this interface
182+
if pv, ok := vars.(partialActivationConverter); ok {
183+
return pv.asPartialActivation()
184+
}
185+
// Since Activations may be hierarchical, test whether a parent converts to a PartialActivation
186+
if vars.Parent() != nil {
187+
return asPartialActivation(vars.Parent())
188+
}
189+
return nil, false
190+
}

interpreter/attribute_patterns.go

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -358,7 +358,7 @@ func (m *attributeMatcher) AddQualifier(qual Qualifier) (Attribute, error) {
358358
func (m *attributeMatcher) Resolve(vars Activation) (any, error) {
359359
id := m.NamespacedAttribute.ID()
360360
// Bug in how partial activation is resolved, should search parents as well.
361-
partial, isPartial := toPartialActivation(vars)
361+
partial, isPartial := asPartialActivation(vars)
362362
if isPartial {
363363
unk, err := m.fac.matchesUnknownPatterns(
364364
partial,
@@ -384,14 +384,3 @@ func (m *attributeMatcher) Qualify(vars Activation, obj any) (any, error) {
384384
func (m *attributeMatcher) QualifyIfPresent(vars Activation, obj any, presenceOnly bool) (any, bool, error) {
385385
return attrQualifyIfPresent(m.fac, vars, obj, m, presenceOnly)
386386
}
387-
388-
func toPartialActivation(vars Activation) (PartialActivation, bool) {
389-
pv, ok := vars.(PartialActivation)
390-
if ok {
391-
return pv, true
392-
}
393-
if vars.Parent() != nil {
394-
return toPartialActivation(vars.Parent())
395-
}
396-
return nil, false
397-
}

interpreter/interpretable.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -762,6 +762,9 @@ func (fold *evalFold) Eval(ctx Activation) ref.Val {
762762
defer releaseFolder(f)
763763

764764
foldRange := fold.iterRange.Eval(ctx)
765+
if types.IsUnknownOrError(foldRange) {
766+
return foldRange
767+
}
765768
if fold.iterVar2 != "" {
766769
var foldable traits.Foldable
767770
switch r := foldRange.(type) {
@@ -1363,6 +1366,26 @@ func (f *folder) Parent() Activation {
13631366
return f.activation
13641367
}
13651368

1369+
// UnknownAttributePatterns implements the PartialActivation interface returning the unknown patterns
1370+
// if they were provided to the input activation, or an empty set if the proxied activation is not partial.
1371+
func (f *folder) UnknownAttributePatterns() []*AttributePattern {
1372+
if pv, ok := f.activation.(partialActivationConverter); ok {
1373+
if partial, isPartial := pv.asPartialActivation(); isPartial {
1374+
return partial.UnknownAttributePatterns()
1375+
}
1376+
}
1377+
return []*AttributePattern{}
1378+
}
1379+
1380+
func (f *folder) asPartialActivation() (PartialActivation, bool) {
1381+
if pv, ok := f.activation.(partialActivationConverter); ok {
1382+
if _, isPartial := pv.asPartialActivation(); isPartial {
1383+
return f, true
1384+
}
1385+
}
1386+
return nil, false
1387+
}
1388+
13661389
// evalResult computes the final result of the fold after all entries have been folded and accumulated.
13671390
func (f *folder) evalResult() ref.Val {
13681391
f.computeResult = true

0 commit comments

Comments
 (0)