Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 41 additions & 1 deletion templates/templates.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ package templates
import (
"bytes"
"encoding/json"
"fmt"
"reflect"
"sort"
"strings"
"text/template"
)
Expand All @@ -26,7 +29,7 @@ var basicFunctions = template.FuncMap{
return strings.TrimSpace(buf.String())
},
"split": strings.Split,
"join": strings.Join,
"join": joinElements,
"title": strings.Title, //nolint:nolintlint,staticcheck // strings.Title is deprecated, but we only use it for ASCII, so replacing with golang.org/x/text is out of scope
"lower": strings.ToLower,
"upper": strings.ToUpper,
Expand Down Expand Up @@ -103,3 +106,40 @@ func truncateWithLength(source string, length int) string {
}
return source[:length]
}

// joinElements joins a slice of items with the given separator. It uses
// [strings.Join] if it's a slice of strings, otherwise uses [fmt.Sprint]
// to join each item to the output.
func joinElements(elems any, sep string) (string, error) {
if elems == nil {
return "", nil
}

if ss, ok := elems.([]string); ok {
return strings.Join(ss, sep), nil
}

switch rv := reflect.ValueOf(elems); rv.Kind() { //nolint:exhaustive // ignore: too many options to make exhaustive
case reflect.Array, reflect.Slice:
var b strings.Builder
for i := range rv.Len() {
if i > 0 {
b.WriteString(sep)
}
_, _ = fmt.Fprint(&b, rv.Index(i).Interface())
}
return b.String(), nil

case reflect.Map:
var out []string
for _, k := range rv.MapKeys() {
out = append(out, fmt.Sprint(rv.MapIndex(k).Interface()))
}
// Not ideal, but trying to keep a consistent order
sort.Strings(out)
return strings.Join(out, sep), nil

default:
return "", fmt.Errorf("expected slice, got %T", elems)
}
}
90 changes: 90 additions & 0 deletions templates/templates_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package templates
import (
"bytes"
"testing"
"text/template"

"gotest.tools/v3/assert"
is "gotest.tools/v3/assert/cmp"
Expand Down Expand Up @@ -139,3 +140,92 @@ func TestHeaderFunctions(t *testing.T) {
})
}
}

type stringerString string

func (s stringerString) String() string {
return "stringer" + string(s)
}

type stringerAndError string

func (s stringerAndError) String() string {
return "stringer" + string(s)
}

func (s stringerAndError) Error() string {
return "error" + string(s)
}

func TestJoinElements(t *testing.T) {
tests := []struct {
doc string
data any
expOut string
expErr string
}{
{
doc: "nil",
data: nil,
expOut: `output: ""`,
},
{
doc: "non-slice",
data: "hello",
expOut: `output: "`,
expErr: `error calling join: expected slice, got string`,
},
{
doc: "structs",
data: []struct{ A, B string }{{"1", "2"}, {"3", "4"}},
expOut: `output: "{1 2}, {3 4}"`,
},
{
doc: "map with strings",
data: map[string]string{"A": "1", "B": "2", "C": "3"},
expOut: `output: "1, 2, 3"`,
},
{
doc: "map with stringers",
data: map[string]stringerString{"A": "1", "B": "2", "C": "3"},
expOut: `output: "stringer1, stringer2, stringer3"`,
},
{
doc: "map with errors",
data: []stringerAndError{"1", "2", "3"},
expOut: `output: "error1, error2, error3"`,
},
{
doc: "stringers",
data: []stringerString{"1", "2", "3"},
expOut: `output: "stringer1, stringer2, stringer3"`,
},
{
doc: "stringer with errors",
data: []stringerAndError{"1", "2", "3"},
expOut: `output: "error1, error2, error3"`,
},
{
doc: "slice of bools",
data: []bool{true, false, true},
expOut: `output: "true, false, true"`,
},
}

const formatStr = `output: "{{- join . ", " -}}"`
tmpl, err := New("my-template").Funcs(template.FuncMap{"join": joinElements}).Parse(formatStr)
assert.NilError(t, err)

for _, tc := range tests {
t.Run(tc.doc, func(t *testing.T) {
var b bytes.Buffer
err := tmpl.Execute(&b, tc.data)
if tc.expErr != "" {
assert.ErrorContains(t, err, tc.expErr)
} else {
assert.NilError(t, err)
}
assert.Equal(t, b.String(), tc.expOut)
})
}
}
Loading