diff --git a/query/encode.go b/query/encode.go index f877b1c..d553373 100644 --- a/query/encode.go +++ b/query/encode.go @@ -30,9 +30,19 @@ import ( "time" ) -var timeType = reflect.TypeOf(time.Time{}) +var ( + timeType = reflect.TypeOf(time.Time{}) // Type representation of time.Time + encoderType = reflect.TypeOf(new(Encoder)).Elem() // Type representation of Encoder interface + structCache = make(map[reflect.Type][]fieldInfo) // Cache for storing struct field information +) -var encoderType = reflect.TypeOf(new(Encoder)).Elem() +// fieldInfo holds metadata about a struct field, including its name, tag options, index, and whether it should be omitted if empty. +type fieldInfo struct { + name string + opts tagOptions + index []int + omitempty bool +} // Encoder is an interface implemented by any type that wishes to encode // itself into URL values in a non-standard way. @@ -149,130 +159,214 @@ func Values(v interface{}) (url.Values, error) { // Embedded structs are followed recursively (using the rules defined in the // Values function documentation) breadth-first. func reflectValue(values url.Values, val reflect.Value, scope string) error { - var embedded []reflect.Value - typ := val.Type() - for i := 0; i < typ.NumField(); i++ { - sf := typ.Field(i) - if sf.PkgPath != "" && !sf.Anonymous { // unexported - continue - } + fields, err := getCachedFields(typ) + if err != nil { + return err + } - sv := val.Field(i) - tag := sf.Tag.Get("url") - if tag == "-" { - continue + for _, fi := range fields { + if err := processField(values, val, scope, fi); err != nil { + return err } - name, opts := parseTag(tag) + } - if name == "" { - if sf.Anonymous { - v := reflect.Indirect(sv) - if v.IsValid() && v.Kind() == reflect.Struct { - // save embedded struct for later processing - embedded = append(embedded, v) - continue - } - } + return nil +} - name = sf.Name - } +// processField processes a single struct field and adds its value to the url.Values. +// It handles custom encoders, slices/arrays, time.Time, and nested structs. +func processField(values url.Values, val reflect.Value, scope string, fi fieldInfo) error { + sv := val.FieldByIndex(fi.index) + if fi.omitempty && isEmptyValue(sv) { + return nil + } - if scope != "" { - name = scope + "[" + name + "]" - } + name := getFieldName(scope, fi, val.Type()) + if fi.opts.Contains("brackets") { + name = name + "[]" + } - if opts.Contains("omitempty") && isEmptyValue(sv) { - continue + if sv.Type().Implements(encoderType) { + return handleCustomEncoder(values, sv, name) + } + + sv = dereferencePointers(sv) + if sv.Kind() == reflect.Slice || sv.Kind() == reflect.Array { + return handleSliceOrArray(values, sv, name, fi, val.Type()) + } + + if sv.Type() == timeType { + values.Add(name, valueString(sv, fi.opts, val.Type().FieldByIndex(fi.index))) + return nil + } + + if sv.Kind() == reflect.Struct { + return reflectValue(values, sv, name) + } + + values.Add(name, valueString(sv, fi.opts, val.Type().FieldByIndex(fi.index))) + return nil +} + +// getFieldName constructs the URL parameter name for a field, taking into account the scope and field options. +func getFieldName(scope string, fi fieldInfo, typ reflect.Type) string { + name := fi.name + if scope != "" && !typ.FieldByIndex(fi.index).Anonymous { + name = scope + "[" + name + "]" + } + return name +} + +// handleCustomEncoder handles encoding of fields that implement the Encoder interface. +func handleCustomEncoder(values url.Values, sv reflect.Value, name string) error { + if !reflect.Indirect(sv).IsValid() && sv.Type().Elem().Implements(encoderType) { + sv = reflect.New(sv.Type().Elem()) + } + + m := sv.Interface().(Encoder) + return m.EncodeValues(name, &values) +} + +// dereferencePointers dereferences pointer values until a non-pointer value is reached. +func dereferencePointers(sv reflect.Value) reflect.Value { + for sv.Kind() == reflect.Ptr { + if sv.IsNil() { + break } + sv = sv.Elem() + } + return sv +} - if sv.Type().Implements(encoderType) { - // if sv is a nil pointer and the custom encoder is defined on a non-pointer - // method receiver, set sv to the zero value of the underlying type - if !reflect.Indirect(sv).IsValid() && sv.Type().Elem().Implements(encoderType) { - sv = reflect.New(sv.Type().Elem()) - } +// handleSliceOrArray handles encoding of slice and array fields, applying delimiter options if specified. +func handleSliceOrArray(values url.Values, sv reflect.Value, name string, fi fieldInfo, typ reflect.Type) error { + if sv.Len() == 0 { + return nil + } - m := sv.Interface().(Encoder) - if err := m.EncodeValues(name, &values); err != nil { - return err + del := getDelimiter(fi.opts, fi.index, typ) + if del != "" { + s := new(bytes.Buffer) + first := true + for i := 0; i < sv.Len(); i++ { + if first { + first = false + } else { + s.WriteString(del) } - continue + s.WriteString(valueString(sv.Index(i), fi.opts, typ.FieldByIndex(fi.index))) } - - // recursively dereference pointers. break on nil pointers - for sv.Kind() == reflect.Ptr { - if sv.IsNil() { - break + values.Add(name, s.String()) + } else { + for i := 0; i < sv.Len(); i++ { + k := name + if fi.opts.Contains("numbered") { + k = fmt.Sprintf("%s%d", name, i) } - sv = sv.Elem() + values.Add(k, valueString(sv.Index(i), fi.opts, typ.FieldByIndex(fi.index))) } + } - if sv.Kind() == reflect.Slice || sv.Kind() == reflect.Array { - if sv.Len() == 0 { - // skip if slice or array is empty - continue - } + return nil +} - var del string - if opts.Contains("comma") { - del = "," - } else if opts.Contains("space") { - del = " " - } else if opts.Contains("semicolon") { - del = ";" - } else if opts.Contains("brackets") { - name = name + "[]" - } else { - del = sf.Tag.Get("del") - } +// getCachedFields retrieves or computes the field information for a given type, using a cache to avoid repeated computation. +func getCachedFields(typ reflect.Type) ([]fieldInfo, error) { + if fields, ok := structCache[typ]; ok { + return fields, nil + } - if del != "" { - s := new(bytes.Buffer) - first := true - for i := 0; i < sv.Len(); i++ { - if first { - first = false - } else { - s.WriteString(del) - } - s.WriteString(valueString(sv.Index(i), opts, sf)) - } - values.Add(name, s.String()) - } else { - for i := 0; i < sv.Len(); i++ { - k := name - if opts.Contains("numbered") { - k = fmt.Sprintf("%s%d", name, i) - } - values.Add(k, valueString(sv.Index(i), opts, sf)) - } - } + fields, embeddedFields, err := extractFields(typ) + if err != nil { + return nil, err + } + + // Append embedded fields after non-embedded fields + fields = append(fields, embeddedFields...) + + structCache[typ] = fields + return fields, nil +} + +// extractFields extracts field information from a struct type, including handling embedded structs. +func extractFields(typ reflect.Type) ([]fieldInfo, []fieldInfo, error) { + var fields []fieldInfo + var embeddedFields []fieldInfo + + for i := 0; i < typ.NumField(); i++ { + sf := typ.Field(i) + if shouldSkipField(sf) { continue } - if sv.Type() == timeType { - values.Add(name, valueString(sv, opts, sf)) + tag := sf.Tag.Get("url") + if tag == "-" { continue } - if sv.Kind() == reflect.Struct { - if err := reflectValue(values, sv, name); err != nil { - return err + name, opts := parseTag(tag) + if name == "" { + if sf.Anonymous { + embeddedFieldsInfo, err := handleEmbeddedStruct(sf, i) + if err != nil { + return nil, nil, err + } + embeddedFields = append(embeddedFields, embeddedFieldsInfo...) + continue } - continue + name = sf.Name } - values.Add(name, valueString(sv, opts, sf)) + fields = append(fields, fieldInfo{ + name: name, + opts: opts, + index: sf.Index, + omitempty: opts.Contains("omitempty"), + }) } - for _, f := range embedded { - if err := reflectValue(values, f, scope); err != nil { - return err + return fields, embeddedFields, nil +} + +// shouldSkipField determines if a struct field should be skipped based on its visibility and other criteria. +func shouldSkipField(sf reflect.StructField) bool { + return sf.PkgPath != "" && !sf.Anonymous +} + +// handleEmbeddedStruct processes embedded struct fields, extracting their field information recursively. +func handleEmbeddedStruct(sf reflect.StructField, index int) ([]fieldInfo, error) { + var embeddedFields []fieldInfo + embeddedType := sf.Type + if embeddedType.Kind() == reflect.Ptr { + embeddedType = embeddedType.Elem() // Dereference the pointer + } + + if embeddedType.Kind() == reflect.Struct { + embeddedFieldsInfo, err := getCachedFields(embeddedType) + if err != nil { + return nil, err + } + for _, ef := range embeddedFieldsInfo { + ef.index = append([]int{index}, ef.index...) + embeddedFields = append(embeddedFields, ef) } } + return embeddedFields, nil +} - return nil +// getDelimiter returns the delimiter for slice/array fields based on the tag options. +func getDelimiter(opts tagOptions, index []int, typ reflect.Type) string { + if opts.Contains("comma") { + return "," + } else if opts.Contains("space") { + return " " + } else if opts.Contains("semicolon") { + return ";" + } else if opts.Contains("brackets") { + return "" + } + return typ.FieldByIndex(index).Tag.Get("del") } // valueString returns the string representation of a value.