Skip to content

Fix functions that take a field path #488

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jan 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions internal/template/groupby.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package template

import (
"fmt"
"reflect"
"strings"

"github.com/nginx-proxy/docker-gen/internal/context"
Expand All @@ -18,7 +17,7 @@ func generalizedGroupBy(funcName string, entries interface{}, getValue func(inte

groups := make(map[string][]interface{})
for i := 0; i < entriesVal.Len(); i++ {
v := reflect.Indirect(entriesVal.Index(i)).Interface()
v := entriesVal.Index(i).Interface()
value, err := getValue(v)
if err != nil {
return nil, err
Expand Down Expand Up @@ -73,13 +72,13 @@ func groupByKeys(entries interface{}, key string) ([]string, error) {
// groupByLabel is the same as groupBy but over a given label
func groupByLabel(entries interface{}, label string) (map[string][]interface{}, error) {
getLabel := func(v interface{}) (interface{}, error) {
if container, ok := v.(context.RuntimeContainer); ok {
if container, ok := v.(*context.RuntimeContainer); ok {
if value, ok := container.Labels[label]; ok {
return value, nil
}
return nil, nil
}
return nil, fmt.Errorf("must pass an array or slice of RuntimeContainer to 'groupByLabel'; received %v", v)
return nil, fmt.Errorf("must pass an array or slice of *RuntimeContainer to 'groupByLabel'; received %v", v)
}
return generalizedGroupBy("groupByLabel", entries, getLabel, func(groups map[string][]interface{}, value interface{}, v interface{}) {
groups[value.(string)] = append(groups[value.(string)], v)
Expand Down
14 changes: 7 additions & 7 deletions internal/template/groupby_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func TestGroupByExistingKey(t *testing.T) {
assert.Len(t, groups, 2)
assert.Len(t, groups["demo1.localhost"], 2)
assert.Len(t, groups["demo2.localhost"], 1)
assert.Equal(t, "3", groups["demo2.localhost"][0].(context.RuntimeContainer).ID)
assert.Equal(t, "3", groups["demo2.localhost"][0].(*context.RuntimeContainer).ID)
}

func TestGroupByAfterWhere(t *testing.T) {
Expand Down Expand Up @@ -69,7 +69,7 @@ func TestGroupByAfterWhere(t *testing.T) {
assert.Len(t, groups, 2)
assert.Len(t, groups["demo1.localhost"], 1)
assert.Len(t, groups["demo2.localhost"], 1)
assert.Equal(t, "3", groups["demo2.localhost"][0].(context.RuntimeContainer).ID)
assert.Equal(t, "3", groups["demo2.localhost"][0].(*context.RuntimeContainer).ID)
}

func TestGroupByKeys(t *testing.T) {
Expand Down Expand Up @@ -149,7 +149,7 @@ func TestGroupByLabel(t *testing.T) {
assert.Len(t, groups["one"], 2)
assert.Len(t, groups[""], 1)
assert.Len(t, groups["two"], 1)
assert.Equal(t, "2", groups["two"][0].(context.RuntimeContainer).ID)
assert.Equal(t, "2", groups["two"][0].(*context.RuntimeContainer).ID)
}

func TestGroupByLabelError(t *testing.T) {
Expand Down Expand Up @@ -193,13 +193,13 @@ func TestGroupByMulti(t *testing.T) {
if len(groups["demo2.localhost"]) != 1 {
t.Fatalf("expected 1 got %d", len(groups["demo2.localhost"]))
}
if groups["demo2.localhost"][0].(context.RuntimeContainer).ID != "3" {
t.Fatalf("expected 2 got %s", groups["demo2.localhost"][0].(context.RuntimeContainer).ID)
if groups["demo2.localhost"][0].(*context.RuntimeContainer).ID != "3" {
t.Fatalf("expected 2 got %s", groups["demo2.localhost"][0].(*context.RuntimeContainer).ID)
}
if len(groups["demo3.localhost"]) != 1 {
t.Fatalf("expect 1 got %d", len(groups["demo3.localhost"]))
}
if groups["demo3.localhost"][0].(context.RuntimeContainer).ID != "2" {
t.Fatalf("expected 2 got %s", groups["demo3.localhost"][0].(context.RuntimeContainer).ID)
if groups["demo3.localhost"][0].(*context.RuntimeContainer).ID != "2" {
t.Fatalf("expected 2 got %s", groups["demo3.localhost"][0].(*context.RuntimeContainer).ID)
}
}
77 changes: 43 additions & 34 deletions internal/template/reflect.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,48 +2,57 @@ package template

import (
"log"
"math"
"reflect"
"strconv"
"strings"
)

func stripPrefix(s, prefix string) string {
path := s
for {
if strings.HasPrefix(path, ".") {
path = path[1:]
continue
}
break
func deepGetImpl(v reflect.Value, path []string) interface{} {
if !v.IsValid() {
log.Printf("invalid value\n")
return nil
}
return path
}

func deepGet(item interface{}, path string) interface{} {
if path == "" {
return item
if len(path) == 0 {
return v.Interface()
}

path = stripPrefix(path, ".")
parts := strings.Split(path, ".")
itemValue := reflect.ValueOf(item)

if len(parts) > 0 {
switch itemValue.Kind() {
case reflect.Struct:
fieldValue := itemValue.FieldByName(parts[0])
if fieldValue.IsValid() {
return deepGet(fieldValue.Interface(), strings.Join(parts[1:], "."))
}
case reflect.Map:
mapValue := itemValue.MapIndex(reflect.ValueOf(parts[0]))
if mapValue.IsValid() {
return deepGet(mapValue.Interface(), strings.Join(parts[1:], "."))
}
default:
log.Printf("Can't group by %s (value %v, kind %s)\n", path, itemValue, itemValue.Kind())
if v.Kind() == reflect.Pointer {
v = v.Elem()
}
if v.Kind() == reflect.Pointer {
log.Printf("unable to descend into pointer of a pointer\n")
return nil
}
switch v.Kind() {
case reflect.Struct:
return deepGetImpl(v.FieldByName(path[0]), path[1:])
case reflect.Map:
return deepGetImpl(v.MapIndex(reflect.ValueOf(path[0])), path[1:])
case reflect.Slice, reflect.Array:
iu64, err := strconv.ParseUint(path[0], 10, 64)
if err != nil {
log.Printf("non-negative decimal number required for array/slice index, got %#v\n", path[0])
return nil
}
if iu64 > math.MaxInt {
iu64 = math.MaxInt
}
i := int(iu64)
if i >= v.Len() {
log.Printf("index %v out of bounds", i)
return nil
}
return deepGetImpl(v.Index(i), path[1:])
default:
log.Printf("unable to index by %s (value %v, kind %s)\n", path[0], v, v.Kind())
return nil
}
}

return itemValue.Interface()
func deepGet(item interface{}, path string) interface{} {
var parts []string
if path != "" {
parts = strings.Split(strings.TrimPrefix(path, "."), ".")
}
return deepGetImpl(reflect.ValueOf(item), parts)
}
44 changes: 43 additions & 1 deletion internal/template/reflect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func TestDeepGetSimpleDotPrefix(t *testing.T) {
item := context.RuntimeContainer{
ID: "expected",
}
value := deepGet(item, "...ID")
value := deepGet(item, ".ID")
assert.IsType(t, "", value)

assert.Equal(t, "expected", value)
Expand All @@ -51,3 +51,45 @@ func TestDeepGetMap(t *testing.T) {

assert.Equal(t, "value", value)
}

func TestDeepGet(t *testing.T) {
s := struct{ X string }{"foo"}
sp := &s

for _, tc := range []struct {
desc string
item interface{}
path string
want interface{}
}{
{
"map key empty string",
map[string]map[string]map[string]string{
"": map[string]map[string]string{
"": map[string]string{
"": "foo",
},
},
},
"...",
"foo",
},
{"struct", s, "X", "foo"},
{"pointer to struct", sp, "X", "foo"},
{"double pointer to struct", &sp, ".X", nil},
{"slice index", []string{"foo", "bar"}, "1", "bar"},
{"slice index out of bounds", []string{}, "0", nil},
{"slice index negative", []string{}, "-1", nil},
{"slice index nonnumber", []string{}, "foo", nil},
{"array index", [2]string{"foo", "bar"}, "1", "bar"},
{"array index out of bounds", [1]string{"foo"}, "1", nil},
{"array index negative", [1]string{"foo"}, "-1", nil},
{"array index nonnumber", [1]string{"foo"}, "foo", nil},
} {
t.Run(tc.desc, func(t *testing.T) {
got := deepGet(tc.item, tc.path)
assert.IsType(t, tc.want, got)
assert.Equal(t, tc.want, got)
})
}
}
2 changes: 1 addition & 1 deletion internal/template/sort.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func (s *sortableByKey) set(funcName string, entries interface{}) (err error) {
}
s.data = make([]interface{}, entriesVal.Len())
for i := 0; i < entriesVal.Len(); i++ {
s.data[i] = reflect.Indirect(entriesVal.Index(i)).Interface()
s.data[i] = entriesVal.Index(i).Interface()
}
return
}
Expand Down
16 changes: 8 additions & 8 deletions internal/template/sort_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,15 +49,15 @@ func TestSortObjectsByKeysAsc(t *testing.T) {

assert.NoError(t, err)
assert.Len(t, sorted, 4)
assert.Equal(t, "foo.localhost", sorted[0].(context.RuntimeContainer).Env["VIRTUAL_HOST"])
assert.Equal(t, "9", sorted[3].(context.RuntimeContainer).ID)
assert.Equal(t, "foo.localhost", sorted[0].(*context.RuntimeContainer).Env["VIRTUAL_HOST"])
assert.Equal(t, "9", sorted[3].(*context.RuntimeContainer).ID)

sorted, err = sortObjectsByKeysAsc(sorted, "Env.VIRTUAL_HOST")

assert.NoError(t, err)
assert.Len(t, sorted, 4)
assert.Equal(t, "foo.localhost", sorted[3].(context.RuntimeContainer).Env["VIRTUAL_HOST"])
assert.Equal(t, "8", sorted[0].(context.RuntimeContainer).ID)
assert.Equal(t, "foo.localhost", sorted[3].(*context.RuntimeContainer).Env["VIRTUAL_HOST"])
assert.Equal(t, "8", sorted[0].(*context.RuntimeContainer).ID)
}

func TestSortObjectsByKeysDesc(t *testing.T) {
Expand Down Expand Up @@ -90,13 +90,13 @@ func TestSortObjectsByKeysDesc(t *testing.T) {

assert.NoError(t, err)
assert.Len(t, sorted, 4)
assert.Equal(t, "bar.localhost", sorted[0].(context.RuntimeContainer).Env["VIRTUAL_HOST"])
assert.Equal(t, "1", sorted[3].(context.RuntimeContainer).ID)
assert.Equal(t, "bar.localhost", sorted[0].(*context.RuntimeContainer).Env["VIRTUAL_HOST"])
assert.Equal(t, "1", sorted[3].(*context.RuntimeContainer).ID)

sorted, err = sortObjectsByKeysDesc(sorted, "Env.VIRTUAL_HOST")

assert.NoError(t, err)
assert.Len(t, sorted, 4)
assert.Equal(t, "", sorted[3].(context.RuntimeContainer).Env["VIRTUAL_HOST"])
assert.Equal(t, "1", sorted[0].(context.RuntimeContainer).ID)
assert.Equal(t, "", sorted[3].(*context.RuntimeContainer).Env["VIRTUAL_HOST"])
assert.Equal(t, "1", sorted[0].(*context.RuntimeContainer).ID)
}
2 changes: 1 addition & 1 deletion internal/template/template.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func getArrayValues(funcName string, entries interface{}) (*reflect.Value, error
kind := entriesVal.Kind()

if kind == reflect.Ptr {
entriesVal = reflect.Indirect(entriesVal)
entriesVal = entriesVal.Elem()
kind = entriesVal.Kind()
}

Expand Down
2 changes: 1 addition & 1 deletion internal/template/where.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ func generalizedWhere(funcName string, entries interface{}, key string, test fun

selection := make([]interface{}, 0)
for i := 0; i < entriesVal.Len(); i++ {
v := reflect.Indirect(entriesVal.Index(i)).Interface()
v := entriesVal.Index(i).Interface()

value := deepGet(v, key)
if test(value) {
Expand Down