Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 78 additions & 0 deletions common/ast/ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,13 @@ func MaxID(a *AST) int64 {
return visitor.maxID + 1
}

// Heights computes the heights of all AST expressions and returns a map from expression id to height.
func Heights(a *AST) map[int64]int {
visitor := make(heightVisitor)
PostOrderVisit(a.Expr(), visitor)
return visitor
}

// NewSourceInfo creates a simple SourceInfo object from an input common.Source value.
func NewSourceInfo(src common.Source) *SourceInfo {
var lineOffsets []int32
Expand Down Expand Up @@ -455,3 +462,74 @@ func (v *maxIDVisitor) VisitEntryExpr(e EntryExpr) {
v.maxID = e.ID()
}
}

type heightVisitor map[int64]int

// VisitExpr computes the height of a given node as the max height of its children plus one.
//
// Identifiers and literals are treated as having a height of zero.
func (hv heightVisitor) VisitExpr(e Expr) {
Comment thread
TristonianJones marked this conversation as resolved.
// default includes IdentKind, LiteralKind
hv[e.ID()] = 0
switch e.Kind() {
case SelectKind:
hv[e.ID()] = 1 + hv[e.AsSelect().Operand().ID()]
case CallKind:
c := e.AsCall()
height := hv.maxHeight(c.Args()...)
if c.IsMemberFunction() {
tHeight := hv[c.Target().ID()]
if tHeight > height {
height = tHeight
}
}
hv[e.ID()] = 1 + height
case ListKind:
l := e.AsList()
hv[e.ID()] = 1 + hv.maxHeight(l.Elements()...)
case MapKind:
m := e.AsMap()
hv[e.ID()] = 1 + hv.maxEntryHeight(m.Entries()...)
case StructKind:
s := e.AsStruct()
hv[e.ID()] = 1 + hv.maxEntryHeight(s.Fields()...)
case ComprehensionKind:
comp := e.AsComprehension()
hv[e.ID()] = 1 + hv.maxHeight(comp.IterRange(), comp.AccuInit(), comp.LoopCondition(), comp.LoopStep(), comp.Result())
}
}

// VisitEntryExpr computes the max height of a map or struct entry and associates the height with the entry id.
func (hv heightVisitor) VisitEntryExpr(e EntryExpr) {
hv[e.ID()] = 0
switch e.Kind() {
case MapEntryKind:
me := e.AsMapEntry()
hv[e.ID()] = hv.maxHeight(me.Value(), me.Key())
case StructFieldKind:
sf := e.AsStructField()
hv[e.ID()] = hv[sf.Value().ID()]
}
}

func (hv heightVisitor) maxHeight(exprs ...Expr) int {
max := 0
for _, e := range exprs {
h := hv[e.ID()]
if h > max {
max = h
}
}
return max
}

func (hv heightVisitor) maxEntryHeight(entries ...EntryExpr) int {
max := 0
for _, e := range entries {
h := hv[e.ID()]
if h > max {
max = h
}
}
return max
}
25 changes: 25 additions & 0 deletions common/ast/ast_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,31 @@ func TestMaxID(t *testing.T) {
}
}

func TestHeights(t *testing.T) {
tests := []struct {
expr string
height int
}{
{`'a' == 'b'`, 1},
{`'a'.size()`, 1},
{`[1, 2].size()`, 2},
{`size('a')`, 1},
{`has({'a': 1}.a)`, 2},
{`{'a': 1}`, 1},
{`{'a': 1}['a']`, 2},
{`[1, 2, 3].exists(i, i % 2 == 1)`, 4},
{`google.expr.proto3.test.TestAllTypes{}`, 1},
{`google.expr.proto3.test.TestAllTypes{repeated_int32: [1, 2]}`, 2},
}
for _, tst := range tests {
checked := mustTypeCheck(t, tst.expr)
maxHeight := ast.Heights(checked)[checked.Expr().ID()]
if maxHeight != tst.height {
t.Errorf("ast.Heights(%q) got max height %d, wanted %d", tst.expr, maxHeight, tst.height)
}
}
}

func mockRelativeSource(t testing.TB, text string, lineOffsets []int32, baseLocation common.Location) common.Source {
t.Helper()
return &mockSource{
Expand Down
7 changes: 6 additions & 1 deletion common/ast/navigable.go
Original file line number Diff line number Diff line change
Expand Up @@ -237,8 +237,13 @@ func visit(expr Expr, visitor Visitor, order visitOrder, depth, maxDepth int) {
case StructKind:
s := expr.AsStruct()
for _, f := range s.Fields() {
visitor.VisitEntryExpr(f)
if order == preOrder {
visitor.VisitEntryExpr(f)
}
visit(f.AsStructField().Value(), visitor, order, depth+1, maxDepth)
if order == postOrder {
visitor.VisitEntryExpr(f)
}
}
}
if order == postOrder {
Expand Down
3 changes: 2 additions & 1 deletion policy/compiler.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,8 @@ func Compile(env *cel.Env, p *Policy, opts ...CompilerOption) (*cel.Ast, *cel.Is
if iss.Err() != nil {
return nil, iss
}
composer := NewRuleComposer(env, p)
// An error cannot happen when composing without supplying options
composer, _ := NewRuleComposer(env)
return composer.Compose(rule)
}

Expand Down
175 changes: 125 additions & 50 deletions policy/compiler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,16 +31,64 @@ import (

func TestCompile(t *testing.T) {
for _, tst := range policyTests {
t.Run(tst.name, func(t *testing.T) {
r := newRunner(t, tst.name, tst.expr, tst.parseOpts, tst.envOpts...)
tc := tst
t.Run(tc.name, func(t *testing.T) {
r := newRunner(tc.name, tc.expr, tc.parseOpts)
env, ast, iss := r.compile(t, tc.envOpts, []CompilerOption{})
if iss.Err() != nil {
t.Fatalf("Compile(%s) failed: %v", r.name, iss.Err())
}
r.setup(t, env, ast)
r.run(t)
})
}
}

func TestRuleComposerError(t *testing.T) {
env, err := cel.NewEnv()
if err != nil {
t.Fatalf("NewEnv() failed: %v", err)
}
_, err = NewRuleComposer(env, ExpressionUnnestHeight(-1))
if err == nil || !strings.Contains(err.Error(), "invalid unnest") {
t.Errorf("NewRuleComposer() got %v, wanted 'invalid unnest'", err)
}
}

func TestRuleComposerUnnest(t *testing.T) {
for _, tst := range composerUnnestTests {
tc := tst
t.Run(tc.name, func(t *testing.T) {
r := newRunner(tc.name, tc.expr, []ParserOption{})
env, rule, iss := r.compileRule(t)
if iss.Err() != nil {
t.Fatalf("CompileRule() failed: %v", iss.Err())
}
rc, err := NewRuleComposer(env, tc.composerOpts...)
if err != nil {
t.Fatalf("NewRuleComposer() failed: %v", err)
}
ast, iss := rc.Compose(rule)
if iss.Err() != nil {
t.Fatalf("Compose(rule) failed: %v", iss.Err())
}
unparsed, err := cel.AstToString(ast)
if err != nil {
t.Fatalf("cel.AstToString() failed: %v", err)
}
if normalize(unparsed) != normalize(tc.composed) {
t.Errorf("cel.AstToString() got %s, wanted %s", unparsed, tc.composed)
}
r.setup(t, env, ast)
r.run(t)
})
}
}

func TestCompileError(t *testing.T) {
for _, tst := range policyErrorTests {
_, _, iss := compile(t, tst.name, []ParserOption{}, []cel.EnvOption{}, tst.compilerOpts)
policy := parsePolicy(t, tst.name, []ParserOption{})
_, _, iss := compile(t, tst.name, policy, []cel.EnvOption{}, tst.compilerOpts)
if iss.Err() == nil {
t.Fatalf("compile(%s) did not error, wanted %s", tst.name, tst.err)
}
Expand Down Expand Up @@ -98,7 +146,8 @@ func TestMaxNestedExpressions_Error(t *testing.T) {
wantError := `ERROR: testdata/required_labels/policy.yaml:15:8: error configuring compiler option: nested expression limit must be non-negative, non-zero value: -1
| name: "required_labels"
| .......^`
_, _, iss := compile(t, policyName, []ParserOption{}, []cel.EnvOption{}, []CompilerOption{MaxNestedExpressions(-1)})
policy := parsePolicy(t, policyName, []ParserOption{})
_, _, iss := compile(t, policyName, policy, []cel.EnvOption{}, []CompilerOption{MaxNestedExpressions(-1)})
if iss.Err() == nil {
t.Fatalf("compile(%s) did not error, wanted %s", policyName, wantError)
}
Expand All @@ -109,55 +158,40 @@ func TestMaxNestedExpressions_Error(t *testing.T) {

func BenchmarkCompile(b *testing.B) {
for _, tst := range policyTests {
r := newRunner(b, tst.name, tst.expr, tst.parseOpts, tst.envOpts...)
r := newRunner(tst.name, tst.expr, tst.parseOpts)
env, ast, iss := r.compile(b, tst.envOpts, []CompilerOption{})
if iss.Err() != nil {
b.Fatalf("Compile() failed: %v", iss.Err())
}
r.setup(b, env, ast)
r.bench(b)
}
}

func newRunner(t testing.TB, name, expr string, parseOpts []ParserOption, opts ...cel.EnvOption) *runner {
r := &runner{
func newRunner(name, expr string, parseOpts []ParserOption, opts ...cel.EnvOption) *runner {
return &runner{
name: name,
envOpts: opts,
parseOpts: parseOpts,
expr: expr}
r.setup(t)
return r
}

type runner struct {
name string
envOpts []cel.EnvOption
parseOpts []ParserOption
compilerOpts []CompilerOption
env *cel.Env
expr string
prg cel.Program
name string
parseOpts []ParserOption
env *cel.Env
expr string
prg cel.Program
}

func mustCompileExpr(t testing.TB, env *cel.Env, expr string) *cel.Ast {
t.Helper()
out, iss := env.Compile(expr)
if iss.Err() != nil {
t.Fatalf("env.Compile(%s) failed: %v", expr, iss.Err())
}
return out
func (r *runner) compile(t testing.TB, envOpts []cel.EnvOption, compilerOpts []CompilerOption) (*cel.Env, *cel.Ast, *cel.Issues) {
policy := parsePolicy(t, r.name, r.parseOpts)
return compile(t, r.name, policy, envOpts, compilerOpts)
}

func compile(t testing.TB, name string, parseOpts []ParserOption, envOpts []cel.EnvOption, compilerOpts []CompilerOption) (*cel.Env, *cel.Ast, *cel.Issues) {
func (r *runner) compileRule(t testing.TB) (*cel.Env, *CompiledRule, *cel.Issues) {
t.Helper()
config := readPolicyConfig(t, fmt.Sprintf("testdata/%s/config.yaml", name))
srcFile := readPolicy(t, fmt.Sprintf("testdata/%s/policy.yaml", name))
parser, err := NewParser(parseOpts...)
if err != nil {
t.Fatalf("NewParser() failed: %v", err)
}
policy, iss := parser.Parse(srcFile)
if iss.Err() != nil {
t.Fatalf("Parse() failed: %v", iss.Err())
}
if policy.name.Value != name {
t.Errorf("policy name is %v, wanted %q", policy.name, name)
}
config := readPolicyConfig(t, fmt.Sprintf("testdata/%s/config.yaml", r.name))
policy := parsePolicy(t, r.name, r.parseOpts)
env, err := cel.NewCustomEnv(
cel.OptionalTypes(),
cel.EnableMacroCallTracking(),
Expand All @@ -166,26 +200,17 @@ func compile(t testing.TB, name string, parseOpts []ParserOption, envOpts []cel.
if err != nil {
t.Fatalf("cel.NewEnv() failed: %v", err)
}
// Configure any custom environment options.
env, err = env.Extend(envOpts...)
if err != nil {
t.Fatalf("env.Extend() with env options %v, failed: %v", config, err)
}
// Configure declarations
env, err = env.Extend(FromConfig(config))
if err != nil {
t.Fatalf("env.Extend() with config options %v, failed: %v", config, err)
}
ast, iss := Compile(env, policy, compilerOpts...)
return env, ast, iss
rule, iss := CompileRule(env, policy)
return env, rule, iss
}

func (r *runner) setup(t testing.TB) {
func (r *runner) setup(t testing.TB, env *cel.Env, ast *cel.Ast) {
t.Helper()
env, ast, iss := compile(t, r.name, r.parseOpts, r.envOpts, r.compilerOpts)
if iss.Err() != nil {
t.Fatalf("Compile(%s) failed: %v", r.name, iss.Err())
}
pExpr, err := cel.AstToString(ast)
if err != nil {
t.Fatalf("cel.AstToString() failed: %v", err)
Expand Down Expand Up @@ -323,6 +348,56 @@ func (r *runner) eval(t testing.TB, expr string) ref.Val {
return out
}

func mustCompileExpr(t testing.TB, env *cel.Env, expr string) *cel.Ast {
t.Helper()
out, iss := env.Compile(expr)
if iss.Err() != nil {
t.Fatalf("env.Compile(%s) failed: %v", expr, iss.Err())
}
return out
}

func parsePolicy(t testing.TB, name string, parseOpts []ParserOption) *Policy {
t.Helper()
srcFile := readPolicy(t, fmt.Sprintf("testdata/%s/policy.yaml", name))
parser, err := NewParser(parseOpts...)
if err != nil {
t.Fatalf("NewParser() failed: %v", err)
}
policy, iss := parser.Parse(srcFile)
if iss.Err() != nil {
t.Fatalf("Parse() failed: %v", iss.Err())
}
if policy.name.Value != name {
t.Errorf("policy name is %v, wanted %q", policy.name, name)
}
return policy
}

func compile(t testing.TB, name string, policy *Policy, envOpts []cel.EnvOption, compilerOpts []CompilerOption) (*cel.Env, *cel.Ast, *cel.Issues) {
config := readPolicyConfig(t, fmt.Sprintf("testdata/%s/config.yaml", name))
env, err := cel.NewCustomEnv(
cel.OptionalTypes(),
cel.EnableMacroCallTracking(),
cel.ExtendedValidations(),
ext.Bindings())
if err != nil {
t.Fatalf("cel.NewEnv() failed: %v", err)
}
// Configure any custom environment options.
env, err = env.Extend(envOpts...)
if err != nil {
t.Fatalf("env.Extend() with env options %v, failed: %v", config, err)
}
// Configure declarations
env, err = env.Extend(FromConfig(config))
if err != nil {
t.Fatalf("env.Extend() with config options %v, failed: %v", config, err)
}
ast, iss := Compile(env, policy, compilerOpts...)
return env, ast, iss
}

func normalize(s string) string {
return strings.ReplaceAll(
strings.ReplaceAll(
Expand Down
Loading