diff --git a/checker/checker_test.go b/checker/checker_test.go index 44e671899..387c8bd75 100644 --- a/checker/checker_test.go +++ b/checker/checker_test.go @@ -2063,7 +2063,7 @@ _&&_(_==_(list~type(list(dyn))^list, "b" )~optional_type(string)^select_optional_field )~type(optional_type(string))^type, - optional_type~type(optional_type)^optional_type + optional_type~type(optional_type(dyn))^optional_type )~bool^equals`, }, { diff --git a/common/types/optional.go b/common/types/optional.go index b8685ebf5..0d861823d 100644 --- a/common/types/optional.go +++ b/common/types/optional.go @@ -25,7 +25,7 @@ import ( var ( // OptionalType indicates the runtime type of an optional value. - OptionalType = NewOpaqueType("optional_type") + OptionalType = NewOpaqueType("optional_type", DynType) // OptionalNone is a sentinel value which is used to indicate an empty optional value. OptionalNone = &Optional{} @@ -59,6 +59,9 @@ func (o *Optional) ConvertToNative(typeDesc reflect.Type) (any, error) { if !o.HasValue() { return nil, errors.New("optional.none() dereference") } + if typeDesc == reflect.TypeFor[*Optional]() { + return o, nil + } return o.value.ConvertToNative(typeDesc) } diff --git a/common/types/optional_test.go b/common/types/optional_test.go index b89dedb29..bbc0b740d 100644 --- a/common/types/optional_test.go +++ b/common/types/optional_test.go @@ -51,6 +51,14 @@ func TestOptionalConvertToNative(t *testing.T) { if out != "hello" { t.Errorf("OptionalOf('hello').ConvertToNative(string) got %v, wanted 'hello'", out) } + optInt := OptionalOf(Int(20)) + out, err = optInt.ConvertToNative(reflect.TypeFor[*Optional]()) + if err != nil { + t.Fatalf("OptionalOf(20).ConvertToNative(optional_type) failed: %v", err) + } + if out != optInt { + t.Errorf("OptionalOf(20) got %v, wanted original value", out) + } } func TestOptionalConvertToType(t *testing.T) { diff --git a/ext/native.go b/ext/native.go index 315567745..c30f26ad3 100644 --- a/ext/native.go +++ b/ext/native.go @@ -434,10 +434,18 @@ func convertToCelType(refType reflect.Type) (*cel.Type, bool) { if refType == timestampType { return cel.TimestampType, true } + if refType.Implements(refValType) { + emptyCelVal := reflect.New(refType).Elem().Interface().(ref.Val) + return emptyCelVal.Type().(*cel.Type), true + } return cel.ObjectType( fmt.Sprintf("%s.%s", simplePkgAlias(refType.PkgPath()), refType.Name()), ), true case reflect.Pointer: + if refType.Implements(refValType) { + emptyCelVal := reflect.New(refType.Elem()).Interface().(ref.Val) + return emptyCelVal.Type().(*cel.Type), true + } if refType.Implements(pbMsgInterfaceType) { pbMsg := reflect.New(refType.Elem()).Interface().(protoreflect.ProtoMessage) return cel.ObjectType(string(pbMsg.ProtoReflect().Descriptor().FullName())), true @@ -608,6 +616,10 @@ func newNativeTypes(fieldNameHandler NativeTypesFieldNameHandler, rawType reflec alreadySeen := make(map[string]struct{}) var iterateStructMembers func(reflect.Type) iterateStructMembers = func(t reflect.Type) { + if t.Implements(reflect.TypeFor[ref.Val]()) { + // skip this field since it's a CEL ref.Val instance. + return + } if k := t.Kind(); k == reflect.Pointer || k == reflect.Slice || k == reflect.Array || k == reflect.Map { iterateStructMembers(t.Elem()) return @@ -791,7 +803,8 @@ func isSupportedType(refType reflect.Type) bool { } var ( - pbMsgInterfaceType = reflect.TypeOf((*protoreflect.ProtoMessage)(nil)).Elem() - timestampType = reflect.TypeOf(time.Now()) - durationType = reflect.TypeOf(time.Nanosecond) + pbMsgInterfaceType = reflect.TypeFor[protoreflect.ProtoMessage]() + refValType = reflect.TypeFor[ref.Val]() + timestampType = reflect.TypeFor[time.Time]() + durationType = reflect.TypeFor[time.Duration]() ) diff --git a/ext/native_test.go b/ext/native_test.go index 421c9037b..cfddc35b1 100644 --- a/ext/native_test.go +++ b/ext/native_test.go @@ -474,7 +474,7 @@ func TestNativeTypesJsonSerialization(t *testing.T) { NestedListVal: ["first", "second"], }, StringVal: "string", - custom_name: "name", + custom_name: "name", }`, out: `{ "BoolVal": true, @@ -495,7 +495,7 @@ func TestNativeTypesJsonSerialization(t *testing.T) { ] }, "StringVal": "string", - "custom_name": "name" + "custom_name": "name" }`, additionalEnvOptions: []any{ParseStructTags(true)}, }, @@ -816,6 +816,52 @@ func TestNativeTypesWithOptional(t *testing.T) { } } +func TestNativeTypesWithCELTypedFields(t *testing.T) { + var nativeTests = []struct { + expr string + }{ + { + expr: `ext.TestRefValFieldType{optional_name: optional.of('my name')}.optional_name.orValue('') == 'my name'`, + }, + { + expr: `ext.TestRefValFieldType{IntVal: 2}.IntVal >= 1`, + }, + { + expr: `ext.TestRefValFieldType{time: timestamp('2001-01-01T00:00:00Z')}.time > timestamp('1970-01-01T00:00:00Z')`, + }, + } + env := testNativeEnv(t, cel.OptionalTypes(), ParseStructTag("cel")) + for i, tst := range nativeTests { + tc := tst + t.Run(fmt.Sprintf("[%d]", i), func(t *testing.T) { + var asts []*cel.Ast + pAst, iss := env.Parse(tc.expr) + if iss.Err() != nil { + t.Fatalf("env.Parse(%v) failed: %v", tc.expr, iss.Err()) + } + asts = append(asts, pAst) + cAst, iss := env.Check(pAst) + if iss.Err() != nil { + t.Fatalf("env.Check(%v) failed: %v", tc.expr, iss.Err()) + } + asts = append(asts, cAst) + for _, ast := range asts { + prg, err := env.Program(ast) + if err != nil { + t.Fatal(err) + } + out, _, err := prg.Eval(cel.NoVars()) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(out.Value(), true) { + t.Errorf("got %v, wanted true for expr: %s", out.Value(), tc.expr) + } + } + }) + } +} + func TestNativeTypeConvertToType(t *testing.T) { var nativeTests = []struct { tag string @@ -1009,10 +1055,6 @@ func TestNativeTypesVersion(t *testing.T) { } } -type Custom struct { - Name string `cel:"name"` -} - func TestTypeResolutionRace(t *testing.T) { customType := reflect.TypeFor[*Custom]() env, err := cel.NewEnv( @@ -1065,6 +1107,7 @@ func testNativeEnv(t *testing.T, opts ...any) *cel.Env { } nativeOpts := []any{ reflect.ValueOf(&TestAllTypes{}), + reflect.ValueOf(&TestRefValFieldType{}), } for _, o := range opts { switch opt := o.(type) { @@ -1098,6 +1141,10 @@ func mustParseTime(t *testing.T, timestamp string) time.Time { return out } +type Custom struct { + Name string `cel:"name"` +} + type TestStructWithMultipleSameNames struct { Name string custom_name string `cel:"Name"` @@ -1152,3 +1199,9 @@ type TestMapVal struct { type TestEmbeddedTypes struct { TestNestedType `json:"embedded,omitempty"` } + +type TestRefValFieldType struct { + OptionalName *types.Optional `cel:"optional_name"` + IntVal types.Int + CELTime types.Timestamp `cel:"time"` +}