Skip to content

Commit 616392c

Browse files
authored
Merge pull request #488 from rhansen/deepGet
Fix functions that take a field path
2 parents c7b991e + 298e650 commit 616392c

File tree

8 files changed

+107
-57
lines changed

8 files changed

+107
-57
lines changed

Diff for: internal/template/groupby.go

+3-4
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package template
22

33
import (
44
"fmt"
5-
"reflect"
65
"strings"
76

87
"github.com/nginx-proxy/docker-gen/internal/context"
@@ -18,7 +17,7 @@ func generalizedGroupBy(funcName string, entries interface{}, getValue func(inte
1817

1918
groups := make(map[string][]interface{})
2019
for i := 0; i < entriesVal.Len(); i++ {
21-
v := reflect.Indirect(entriesVal.Index(i)).Interface()
20+
v := entriesVal.Index(i).Interface()
2221
value, err := getValue(v)
2322
if err != nil {
2423
return nil, err
@@ -73,13 +72,13 @@ func groupByKeys(entries interface{}, key string) ([]string, error) {
7372
// groupByLabel is the same as groupBy but over a given label
7473
func groupByLabel(entries interface{}, label string) (map[string][]interface{}, error) {
7574
getLabel := func(v interface{}) (interface{}, error) {
76-
if container, ok := v.(context.RuntimeContainer); ok {
75+
if container, ok := v.(*context.RuntimeContainer); ok {
7776
if value, ok := container.Labels[label]; ok {
7877
return value, nil
7978
}
8079
return nil, nil
8180
}
82-
return nil, fmt.Errorf("must pass an array or slice of RuntimeContainer to 'groupByLabel'; received %v", v)
81+
return nil, fmt.Errorf("must pass an array or slice of *RuntimeContainer to 'groupByLabel'; received %v", v)
8382
}
8483
return generalizedGroupBy("groupByLabel", entries, getLabel, func(groups map[string][]interface{}, value interface{}, v interface{}) {
8584
groups[value.(string)] = append(groups[value.(string)], v)

Diff for: internal/template/groupby_test.go

+7-7
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ func TestGroupByExistingKey(t *testing.T) {
3535
assert.Len(t, groups, 2)
3636
assert.Len(t, groups["demo1.localhost"], 2)
3737
assert.Len(t, groups["demo2.localhost"], 1)
38-
assert.Equal(t, "3", groups["demo2.localhost"][0].(context.RuntimeContainer).ID)
38+
assert.Equal(t, "3", groups["demo2.localhost"][0].(*context.RuntimeContainer).ID)
3939
}
4040

4141
func TestGroupByAfterWhere(t *testing.T) {
@@ -69,7 +69,7 @@ func TestGroupByAfterWhere(t *testing.T) {
6969
assert.Len(t, groups, 2)
7070
assert.Len(t, groups["demo1.localhost"], 1)
7171
assert.Len(t, groups["demo2.localhost"], 1)
72-
assert.Equal(t, "3", groups["demo2.localhost"][0].(context.RuntimeContainer).ID)
72+
assert.Equal(t, "3", groups["demo2.localhost"][0].(*context.RuntimeContainer).ID)
7373
}
7474

7575
func TestGroupByKeys(t *testing.T) {
@@ -149,7 +149,7 @@ func TestGroupByLabel(t *testing.T) {
149149
assert.Len(t, groups["one"], 2)
150150
assert.Len(t, groups[""], 1)
151151
assert.Len(t, groups["two"], 1)
152-
assert.Equal(t, "2", groups["two"][0].(context.RuntimeContainer).ID)
152+
assert.Equal(t, "2", groups["two"][0].(*context.RuntimeContainer).ID)
153153
}
154154

155155
func TestGroupByLabelError(t *testing.T) {
@@ -193,13 +193,13 @@ func TestGroupByMulti(t *testing.T) {
193193
if len(groups["demo2.localhost"]) != 1 {
194194
t.Fatalf("expected 1 got %d", len(groups["demo2.localhost"]))
195195
}
196-
if groups["demo2.localhost"][0].(context.RuntimeContainer).ID != "3" {
197-
t.Fatalf("expected 2 got %s", groups["demo2.localhost"][0].(context.RuntimeContainer).ID)
196+
if groups["demo2.localhost"][0].(*context.RuntimeContainer).ID != "3" {
197+
t.Fatalf("expected 2 got %s", groups["demo2.localhost"][0].(*context.RuntimeContainer).ID)
198198
}
199199
if len(groups["demo3.localhost"]) != 1 {
200200
t.Fatalf("expect 1 got %d", len(groups["demo3.localhost"]))
201201
}
202-
if groups["demo3.localhost"][0].(context.RuntimeContainer).ID != "2" {
203-
t.Fatalf("expected 2 got %s", groups["demo3.localhost"][0].(context.RuntimeContainer).ID)
202+
if groups["demo3.localhost"][0].(*context.RuntimeContainer).ID != "2" {
203+
t.Fatalf("expected 2 got %s", groups["demo3.localhost"][0].(*context.RuntimeContainer).ID)
204204
}
205205
}

Diff for: internal/template/reflect.go

+43-34
Original file line numberDiff line numberDiff line change
@@ -2,48 +2,57 @@ package template
22

33
import (
44
"log"
5+
"math"
56
"reflect"
7+
"strconv"
68
"strings"
79
)
810

9-
func stripPrefix(s, prefix string) string {
10-
path := s
11-
for {
12-
if strings.HasPrefix(path, ".") {
13-
path = path[1:]
14-
continue
15-
}
16-
break
11+
func deepGetImpl(v reflect.Value, path []string) interface{} {
12+
if !v.IsValid() {
13+
log.Printf("invalid value\n")
14+
return nil
1715
}
18-
return path
19-
}
20-
21-
func deepGet(item interface{}, path string) interface{} {
22-
if path == "" {
23-
return item
16+
if len(path) == 0 {
17+
return v.Interface()
2418
}
25-
26-
path = stripPrefix(path, ".")
27-
parts := strings.Split(path, ".")
28-
itemValue := reflect.ValueOf(item)
29-
30-
if len(parts) > 0 {
31-
switch itemValue.Kind() {
32-
case reflect.Struct:
33-
fieldValue := itemValue.FieldByName(parts[0])
34-
if fieldValue.IsValid() {
35-
return deepGet(fieldValue.Interface(), strings.Join(parts[1:], "."))
36-
}
37-
case reflect.Map:
38-
mapValue := itemValue.MapIndex(reflect.ValueOf(parts[0]))
39-
if mapValue.IsValid() {
40-
return deepGet(mapValue.Interface(), strings.Join(parts[1:], "."))
41-
}
42-
default:
43-
log.Printf("Can't group by %s (value %v, kind %s)\n", path, itemValue, itemValue.Kind())
19+
if v.Kind() == reflect.Pointer {
20+
v = v.Elem()
21+
}
22+
if v.Kind() == reflect.Pointer {
23+
log.Printf("unable to descend into pointer of a pointer\n")
24+
return nil
25+
}
26+
switch v.Kind() {
27+
case reflect.Struct:
28+
return deepGetImpl(v.FieldByName(path[0]), path[1:])
29+
case reflect.Map:
30+
return deepGetImpl(v.MapIndex(reflect.ValueOf(path[0])), path[1:])
31+
case reflect.Slice, reflect.Array:
32+
iu64, err := strconv.ParseUint(path[0], 10, 64)
33+
if err != nil {
34+
log.Printf("non-negative decimal number required for array/slice index, got %#v\n", path[0])
35+
return nil
36+
}
37+
if iu64 > math.MaxInt {
38+
iu64 = math.MaxInt
39+
}
40+
i := int(iu64)
41+
if i >= v.Len() {
42+
log.Printf("index %v out of bounds", i)
43+
return nil
4444
}
45+
return deepGetImpl(v.Index(i), path[1:])
46+
default:
47+
log.Printf("unable to index by %s (value %v, kind %s)\n", path[0], v, v.Kind())
4548
return nil
4649
}
50+
}
4751

48-
return itemValue.Interface()
52+
func deepGet(item interface{}, path string) interface{} {
53+
var parts []string
54+
if path != "" {
55+
parts = strings.Split(strings.TrimPrefix(path, "."), ".")
56+
}
57+
return deepGetImpl(reflect.ValueOf(item), parts)
4958
}

Diff for: internal/template/reflect_test.go

+43-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ func TestDeepGetSimpleDotPrefix(t *testing.T) {
3434
item := context.RuntimeContainer{
3535
ID: "expected",
3636
}
37-
value := deepGet(item, "...ID")
37+
value := deepGet(item, ".ID")
3838
assert.IsType(t, "", value)
3939

4040
assert.Equal(t, "expected", value)
@@ -51,3 +51,45 @@ func TestDeepGetMap(t *testing.T) {
5151

5252
assert.Equal(t, "value", value)
5353
}
54+
55+
func TestDeepGet(t *testing.T) {
56+
s := struct{ X string }{"foo"}
57+
sp := &s
58+
59+
for _, tc := range []struct {
60+
desc string
61+
item interface{}
62+
path string
63+
want interface{}
64+
}{
65+
{
66+
"map key empty string",
67+
map[string]map[string]map[string]string{
68+
"": map[string]map[string]string{
69+
"": map[string]string{
70+
"": "foo",
71+
},
72+
},
73+
},
74+
"...",
75+
"foo",
76+
},
77+
{"struct", s, "X", "foo"},
78+
{"pointer to struct", sp, "X", "foo"},
79+
{"double pointer to struct", &sp, ".X", nil},
80+
{"slice index", []string{"foo", "bar"}, "1", "bar"},
81+
{"slice index out of bounds", []string{}, "0", nil},
82+
{"slice index negative", []string{}, "-1", nil},
83+
{"slice index nonnumber", []string{}, "foo", nil},
84+
{"array index", [2]string{"foo", "bar"}, "1", "bar"},
85+
{"array index out of bounds", [1]string{"foo"}, "1", nil},
86+
{"array index negative", [1]string{"foo"}, "-1", nil},
87+
{"array index nonnumber", [1]string{"foo"}, "foo", nil},
88+
} {
89+
t.Run(tc.desc, func(t *testing.T) {
90+
got := deepGet(tc.item, tc.path)
91+
assert.IsType(t, tc.want, got)
92+
assert.Equal(t, tc.want, got)
93+
})
94+
}
95+
}

Diff for: internal/template/sort.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ func (s *sortableByKey) set(funcName string, entries interface{}) (err error) {
4747
}
4848
s.data = make([]interface{}, entriesVal.Len())
4949
for i := 0; i < entriesVal.Len(); i++ {
50-
s.data[i] = reflect.Indirect(entriesVal.Index(i)).Interface()
50+
s.data[i] = entriesVal.Index(i).Interface()
5151
}
5252
return
5353
}

Diff for: internal/template/sort_test.go

+8-8
Original file line numberDiff line numberDiff line change
@@ -49,15 +49,15 @@ func TestSortObjectsByKeysAsc(t *testing.T) {
4949

5050
assert.NoError(t, err)
5151
assert.Len(t, sorted, 4)
52-
assert.Equal(t, "foo.localhost", sorted[0].(context.RuntimeContainer).Env["VIRTUAL_HOST"])
53-
assert.Equal(t, "9", sorted[3].(context.RuntimeContainer).ID)
52+
assert.Equal(t, "foo.localhost", sorted[0].(*context.RuntimeContainer).Env["VIRTUAL_HOST"])
53+
assert.Equal(t, "9", sorted[3].(*context.RuntimeContainer).ID)
5454

5555
sorted, err = sortObjectsByKeysAsc(sorted, "Env.VIRTUAL_HOST")
5656

5757
assert.NoError(t, err)
5858
assert.Len(t, sorted, 4)
59-
assert.Equal(t, "foo.localhost", sorted[3].(context.RuntimeContainer).Env["VIRTUAL_HOST"])
60-
assert.Equal(t, "8", sorted[0].(context.RuntimeContainer).ID)
59+
assert.Equal(t, "foo.localhost", sorted[3].(*context.RuntimeContainer).Env["VIRTUAL_HOST"])
60+
assert.Equal(t, "8", sorted[0].(*context.RuntimeContainer).ID)
6161
}
6262

6363
func TestSortObjectsByKeysDesc(t *testing.T) {
@@ -90,13 +90,13 @@ func TestSortObjectsByKeysDesc(t *testing.T) {
9090

9191
assert.NoError(t, err)
9292
assert.Len(t, sorted, 4)
93-
assert.Equal(t, "bar.localhost", sorted[0].(context.RuntimeContainer).Env["VIRTUAL_HOST"])
94-
assert.Equal(t, "1", sorted[3].(context.RuntimeContainer).ID)
93+
assert.Equal(t, "bar.localhost", sorted[0].(*context.RuntimeContainer).Env["VIRTUAL_HOST"])
94+
assert.Equal(t, "1", sorted[3].(*context.RuntimeContainer).ID)
9595

9696
sorted, err = sortObjectsByKeysDesc(sorted, "Env.VIRTUAL_HOST")
9797

9898
assert.NoError(t, err)
9999
assert.Len(t, sorted, 4)
100-
assert.Equal(t, "", sorted[3].(context.RuntimeContainer).Env["VIRTUAL_HOST"])
101-
assert.Equal(t, "1", sorted[0].(context.RuntimeContainer).ID)
100+
assert.Equal(t, "", sorted[3].(*context.RuntimeContainer).Env["VIRTUAL_HOST"])
101+
assert.Equal(t, "1", sorted[0].(*context.RuntimeContainer).ID)
102102
}

Diff for: internal/template/template.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ func getArrayValues(funcName string, entries interface{}) (*reflect.Value, error
3030
kind := entriesVal.Kind()
3131

3232
if kind == reflect.Ptr {
33-
entriesVal = reflect.Indirect(entriesVal)
33+
entriesVal = entriesVal.Elem()
3434
kind = entriesVal.Kind()
3535
}
3636

Diff for: internal/template/where.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ func generalizedWhere(funcName string, entries interface{}, key string, test fun
1818

1919
selection := make([]interface{}, 0)
2020
for i := 0; i < entriesVal.Len(); i++ {
21-
v := reflect.Indirect(entriesVal.Index(i)).Interface()
21+
v := entriesVal.Index(i).Interface()
2222

2323
value := deepGet(v, key)
2424
if test(value) {

0 commit comments

Comments
 (0)