Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion checker/checker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2063,7 +2063,7 @@ _&&_(_==_(list~type(list(dyn))^list,
"b"
)~optional_type(string)^select_optional_field
)~type(optional_type(string))^type,
optional_type~type(optional_type)^optional_type
optional_type~type(optional_type(dyn))^optional_type
)~bool^equals`,
},
{
Expand Down
5 changes: 4 additions & 1 deletion common/types/optional.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import (

var (
// OptionalType indicates the runtime type of an optional value.
OptionalType = NewOpaqueType("optional_type")
OptionalType = NewOpaqueType("optional_type", DynType)

// OptionalNone is a sentinel value which is used to indicate an empty optional value.
OptionalNone = &Optional{}
Expand Down Expand Up @@ -59,6 +59,9 @@ func (o *Optional) ConvertToNative(typeDesc reflect.Type) (any, error) {
if !o.HasValue() {
return nil, errors.New("optional.none() dereference")
}
if typeDesc == reflect.TypeFor[*Optional]() {
return o, nil
}
return o.value.ConvertToNative(typeDesc)
}

Expand Down
8 changes: 8 additions & 0 deletions common/types/optional_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,14 @@ func TestOptionalConvertToNative(t *testing.T) {
if out != "hello" {
t.Errorf("OptionalOf('hello').ConvertToNative(string) got %v, wanted 'hello'", out)
}
optInt := OptionalOf(Int(20))
out, err = optInt.ConvertToNative(reflect.TypeFor[*Optional]())
if err != nil {
t.Fatalf("OptionalOf(20).ConvertToNative(optional_type) failed: %v", err)
}
if out != optInt {
t.Errorf("OptionalOf(20) got %v, wanted original value", out)
}
}

func TestOptionalConvertToType(t *testing.T) {
Expand Down
19 changes: 16 additions & 3 deletions ext/native.go
Original file line number Diff line number Diff line change
Expand Up @@ -434,10 +434,18 @@ func convertToCelType(refType reflect.Type) (*cel.Type, bool) {
if refType == timestampType {
return cel.TimestampType, true
}
if refType.Implements(refValType) {
emptyCelVal := reflect.New(refType).Elem().Interface().(ref.Val)
return emptyCelVal.Type().(*cel.Type), true
}
return cel.ObjectType(
fmt.Sprintf("%s.%s", simplePkgAlias(refType.PkgPath()), refType.Name()),
), true
case reflect.Pointer:
if refType.Implements(refValType) {
emptyCelVal := reflect.New(refType.Elem()).Interface().(ref.Val)
return emptyCelVal.Type().(*cel.Type), true
}
if refType.Implements(pbMsgInterfaceType) {
pbMsg := reflect.New(refType.Elem()).Interface().(protoreflect.ProtoMessage)
return cel.ObjectType(string(pbMsg.ProtoReflect().Descriptor().FullName())), true
Expand Down Expand Up @@ -608,6 +616,10 @@ func newNativeTypes(fieldNameHandler NativeTypesFieldNameHandler, rawType reflec
alreadySeen := make(map[string]struct{})
var iterateStructMembers func(reflect.Type)
iterateStructMembers = func(t reflect.Type) {
if t.Implements(reflect.TypeFor[ref.Val]()) {
// skip this field since it's a CEL ref.Val instance.
return
}
if k := t.Kind(); k == reflect.Pointer || k == reflect.Slice || k == reflect.Array || k == reflect.Map {
iterateStructMembers(t.Elem())
return
Expand Down Expand Up @@ -791,7 +803,8 @@ func isSupportedType(refType reflect.Type) bool {
}

var (
pbMsgInterfaceType = reflect.TypeOf((*protoreflect.ProtoMessage)(nil)).Elem()
timestampType = reflect.TypeOf(time.Now())
durationType = reflect.TypeOf(time.Nanosecond)
pbMsgInterfaceType = reflect.TypeFor[protoreflect.ProtoMessage]()
refValType = reflect.TypeFor[ref.Val]()
timestampType = reflect.TypeFor[time.Time]()
durationType = reflect.TypeFor[time.Duration]()
)
65 changes: 59 additions & 6 deletions ext/native_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,7 @@ func TestNativeTypesJsonSerialization(t *testing.T) {
NestedListVal: ["first", "second"],
},
StringVal: "string",
custom_name: "name",
custom_name: "name",
}`,
out: `{
"BoolVal": true,
Expand All @@ -495,7 +495,7 @@ func TestNativeTypesJsonSerialization(t *testing.T) {
]
},
"StringVal": "string",
"custom_name": "name"
"custom_name": "name"
}`,
additionalEnvOptions: []any{ParseStructTags(true)},
},
Expand Down Expand Up @@ -816,6 +816,52 @@ func TestNativeTypesWithOptional(t *testing.T) {
}
}

func TestNativeTypesWithCELTypedFields(t *testing.T) {
var nativeTests = []struct {
expr string
}{
{
expr: `ext.TestRefValFieldType{optional_name: optional.of('my name')}.optional_name.orValue('') == 'my name'`,
},
{
expr: `ext.TestRefValFieldType{IntVal: 2}.IntVal >= 1`,
},
{
expr: `ext.TestRefValFieldType{time: timestamp('2001-01-01T00:00:00Z')}.time > timestamp('1970-01-01T00:00:00Z')`,
},
}
env := testNativeEnv(t, cel.OptionalTypes(), ParseStructTag("cel"))
for i, tst := range nativeTests {
tc := tst
t.Run(fmt.Sprintf("[%d]", i), func(t *testing.T) {
var asts []*cel.Ast
pAst, iss := env.Parse(tc.expr)
if iss.Err() != nil {
t.Fatalf("env.Parse(%v) failed: %v", tc.expr, iss.Err())
}
asts = append(asts, pAst)
cAst, iss := env.Check(pAst)
if iss.Err() != nil {
t.Fatalf("env.Check(%v) failed: %v", tc.expr, iss.Err())
}
asts = append(asts, cAst)
for _, ast := range asts {
prg, err := env.Program(ast)
if err != nil {
t.Fatal(err)
}
out, _, err := prg.Eval(cel.NoVars())
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(out.Value(), true) {
t.Errorf("got %v, wanted true for expr: %s", out.Value(), tc.expr)
}
}
})
}
}

func TestNativeTypeConvertToType(t *testing.T) {
var nativeTests = []struct {
tag string
Expand Down Expand Up @@ -1009,10 +1055,6 @@ func TestNativeTypesVersion(t *testing.T) {
}
}

type Custom struct {
Name string `cel:"name"`
}

func TestTypeResolutionRace(t *testing.T) {
customType := reflect.TypeFor[*Custom]()
env, err := cel.NewEnv(
Expand Down Expand Up @@ -1065,6 +1107,7 @@ func testNativeEnv(t *testing.T, opts ...any) *cel.Env {
}
nativeOpts := []any{
reflect.ValueOf(&TestAllTypes{}),
reflect.ValueOf(&TestRefValFieldType{}),
}
for _, o := range opts {
switch opt := o.(type) {
Expand Down Expand Up @@ -1098,6 +1141,10 @@ func mustParseTime(t *testing.T, timestamp string) time.Time {
return out
}

type Custom struct {
Name string `cel:"name"`
}

type TestStructWithMultipleSameNames struct {
Name string
custom_name string `cel:"Name"`
Expand Down Expand Up @@ -1152,3 +1199,9 @@ type TestMapVal struct {
type TestEmbeddedTypes struct {
TestNestedType `json:"embedded,omitempty"`
}

type TestRefValFieldType struct {
OptionalName *types.Optional `cel:"optional_name"`
IntVal types.Int
CELTime types.Timestamp `cel:"time"`
}