Skip to content

Commit e7c8e18

Browse files
dimfeldstuartclanJacob Stuart
authored
Put the route path in the request context (#75)
* Put full route path in the context * Update to include routes with no params (#74) Co-authored-by: Jacob Stuart <[email protected]> * Combine context route path and params into a single interface Co-authored-by: stuartclan <[email protected]> Co-authored-by: Jacob Stuart <[email protected]>
1 parent 8cec559 commit e7c8e18

File tree

3 files changed

+195
-87
lines changed

3 files changed

+195
-87
lines changed

README.md

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,15 @@ group.GET("/v1/:id", func(w http.ResponseWriter, r *http.Request, params map[str
3636
// UsingContext returns a version of the router or group with context support.
3737
ctxGroup := group.UsingContext() // sibling to 'group' node in tree
3838
ctxGroup.GET("/v2/:id", func(w http.ResponseWriter, r *http.Request) {
39-
params := httptreemux.ContextParams(r.Context())
39+
ctxData := httptreemux.ContextData(r.Context())
40+
params := ctxData.Params()
4041
id := params["id"]
41-
fmt.Fprintf(w, "GET /api/v2/%s", id)
42+
43+
// Useful for middleware to see which route was hit without dealing with wildcards
44+
routePath := ctxData.Route()
45+
46+
// Prints GET /api/v2/:id id=...
47+
fmt.Fprintf(w, "GET %s id=%s", routePath, id)
4248
})
4349

4450
http.ListenAndServe(":8080", router)
@@ -58,9 +64,15 @@ router.GET("/:page", func(w http.ResponseWriter, r *http.Request) {
5864

5965
group := router.NewGroup("/api")
6066
group.GET("/v1/:id", func(w http.ResponseWriter, r *http.Request) {
61-
params := httptreemux.ContextParams(r.Context())
67+
ctxData := httptreemux.ContextData(r.Context())
68+
params := ctxData.Params()
6269
id := params["id"]
63-
fmt.Fprintf(w, "GET /api/v1/%s", id)
70+
71+
// Useful for middleware to see which route was hit without dealing with wildcards
72+
routePath := ctxData.Route()
73+
74+
// Prints GET /api/v1/:id id=...
75+
fmt.Fprintf(w, "GET %s id=%s", routePath, id)
6476
})
6577

6678
http.ListenAndServe(":8080", router)

context.go

Lines changed: 56 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -58,21 +58,27 @@ func (cg *ContextGroup) NewGroup(path string) *ContextGroup {
5858
// Handle allows handling HTTP requests via an http.HandlerFunc, as opposed to an httptreemux.HandlerFunc.
5959
// Any parameters from the request URL are stored in a map[string]string in the request's context.
6060
func (cg *ContextGroup) Handle(method, path string, handler http.HandlerFunc) {
61+
fullPath := cg.group.path + path
6162
cg.group.Handle(method, path, func(w http.ResponseWriter, r *http.Request, params map[string]string) {
62-
if params != nil {
63-
r = r.WithContext(AddParamsToContext(r.Context(), params))
63+
routeData := contextData{
64+
route: fullPath,
65+
params: params,
6466
}
67+
r = r.WithContext(AddRouteDataToContext(r.Context(), routeData))
6568
handler(w, r)
6669
})
6770
}
6871

6972
// Handler allows handling HTTP requests via an http.Handler interface, as opposed to an httptreemux.HandlerFunc.
7073
// Any parameters from the request URL are stored in a map[string]string in the request's context.
7174
func (cg *ContextGroup) Handler(method, path string, handler http.Handler) {
75+
fullPath := cg.group.path + path
7276
cg.group.Handle(method, path, func(w http.ResponseWriter, r *http.Request, params map[string]string) {
73-
if params != nil {
74-
r = r.WithContext(AddParamsToContext(r.Context(), params))
77+
routeData := contextData{
78+
route: fullPath,
79+
params: params,
7580
}
81+
r = r.WithContext(AddRouteDataToContext(r.Context(), routeData))
7682
handler.ServeHTTP(w, r)
7783
})
7884
}
@@ -112,22 +118,61 @@ func (cg *ContextGroup) OPTIONS(path string, handler http.HandlerFunc) {
112118
cg.Handle("OPTIONS", path, handler)
113119
}
114120

121+
type contextData struct {
122+
route string
123+
params map[string]string
124+
}
125+
126+
func (cd contextData) Route() string {
127+
return cd.route
128+
}
129+
130+
func (cd contextData) Params() map[string]string {
131+
if cd.params != nil {
132+
return cd.params
133+
}
134+
return map[string]string{}
135+
}
136+
137+
// ContextData is the information associated with
138+
type ContextRouteData interface {
139+
Route() string
140+
Params() map[string]string
141+
}
142+
115143
// ContextParams returns the params map associated with the given context if one exists. Otherwise, an empty map is returned.
116144
func ContextParams(ctx context.Context) map[string]string {
117-
if p, ok := ctx.Value(paramsContextKey).(map[string]string); ok {
118-
return p
145+
if p, ok := ctx.Value(routeContextKey).(ContextRouteData); ok {
146+
return p.Params()
119147
}
120148
return map[string]string{}
121149
}
122150

151+
// ContextData returns the full route path associated with the given context, without wildcard expansion.
152+
func ContextData(ctx context.Context) ContextRouteData {
153+
if p, ok := ctx.Value(routeContextKey).(ContextRouteData); ok {
154+
return p
155+
}
156+
return nil
157+
}
158+
159+
func AddRouteDataToContext(ctx context.Context, data ContextRouteData) context.Context {
160+
return context.WithValue(ctx, routeContextKey, data)
161+
}
162+
123163
// AddParamsToContext inserts a parameters map into a context using
124-
// the package's internal context key. Clients of this package should
125-
// really only use this for unit tests.
164+
// the package's internal context key. This function is deprecated.
165+
// Use AddRouteDataToContext instead.
126166
func AddParamsToContext(ctx context.Context, params map[string]string) context.Context {
127-
return context.WithValue(ctx, paramsContextKey, params)
167+
data := contextData{
168+
route: "",
169+
params: params,
170+
}
171+
return AddRouteDataToContext(ctx, data)
128172
}
129173

130174
type contextKey int
131175

132-
// paramsContextKey is used to retrieve a path's params map from a request's context.
133-
const paramsContextKey contextKey = 0
176+
// paramsContextKey and routeContextKey are used to retrieve a path's params map and route
177+
// from a request's context.
178+
const routeContextKey contextKey = 0

context_test.go

Lines changed: 123 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ package httptreemux
44

55
import (
66
"context"
7+
"fmt"
78
"net/http"
89
"net/http/httptest"
910
"testing"
@@ -23,105 +24,155 @@ type IContextGroup interface {
2324
}
2425

2526
func TestContextParams(t *testing.T) {
26-
m := map[string]string{"id": "123"}
27-
ctx := context.WithValue(context.Background(), paramsContextKey, m)
27+
m := contextData{
28+
params: map[string]string{"id": "123"},
29+
route: "",
30+
}
31+
32+
ctx := context.WithValue(context.Background(), routeContextKey, m)
2833

2934
params := ContextParams(ctx)
3035
if params == nil {
3136
t.Errorf("expected '%#v', but got '%#v'", m, params)
3237
}
3338

3439
if v := params["id"]; v != "123" {
35-
t.Errorf("expected '%s', but got '%#v'", m["id"], params["id"])
40+
t.Errorf("expected '%s', but got '%#v'", m.params["id"], params["id"])
41+
}
42+
}
43+
44+
func TestContextData(t *testing.T) {
45+
p := contextData{
46+
route: "route/path",
47+
params: map[string]string{"id": "123"},
48+
}
49+
50+
ctx := context.WithValue(context.Background(), routeContextKey, p)
51+
52+
ctxData := ContextData(ctx)
53+
pathValue := ctxData.Route()
54+
if pathValue != p.route {
55+
t.Errorf("expected '%s', but got '%s'", p, pathValue)
56+
}
57+
58+
params := ctxData.Params()
59+
if v := params["id"]; v != "123" {
60+
t.Errorf("expected '%s', but got '%#v'", p.params["id"], params["id"])
61+
}
62+
}
63+
64+
func TestContextDataWithEmptyParams(t *testing.T) {
65+
p := contextData{
66+
route: "route/path",
67+
params: nil,
68+
}
69+
70+
ctx := context.WithValue(context.Background(), routeContextKey, p)
71+
params := ContextData(ctx).Params()
72+
if params == nil {
73+
t.Errorf("ContextData.Params should never return nil")
3674
}
3775
}
3876

3977
func TestContextGroupMethods(t *testing.T) {
4078
for _, scenario := range scenarios {
41-
t.Log(scenario.description)
42-
testContextGroupMethods(t, scenario.RequestCreator, true, false)
43-
testContextGroupMethods(t, scenario.RequestCreator, false, false)
44-
testContextGroupMethods(t, scenario.RequestCreator, true, true)
45-
testContextGroupMethods(t, scenario.RequestCreator, false, true)
79+
t.Run(scenario.description, func(t *testing.T) {
80+
testContextGroupMethods(t, scenario.RequestCreator, true, false)
81+
testContextGroupMethods(t, scenario.RequestCreator, false, false)
82+
testContextGroupMethods(t, scenario.RequestCreator, true, true)
83+
testContextGroupMethods(t, scenario.RequestCreator, false, true)
84+
})
4685
}
4786
}
4887

4988
func testContextGroupMethods(t *testing.T, reqGen RequestCreator, headCanUseGet bool, useContextRouter bool) {
50-
t.Logf("Running test: headCanUseGet %v, useContextRouter %v", headCanUseGet, useContextRouter)
89+
t.Run(fmt.Sprintf("headCanUseGet %v, useContextRouter %v", headCanUseGet, useContextRouter), func(t *testing.T) {
90+
var result string
91+
makeHandler := func(method, expectedRoutePath string, hasParam bool) http.HandlerFunc {
92+
return func(w http.ResponseWriter, r *http.Request) {
93+
result = method
94+
95+
// Test Legacy Accessor
96+
var v string
97+
v, ok := ContextParams(r.Context())["param"]
98+
if hasParam && !ok {
99+
t.Error("missing key 'param' in context from ContextParams")
100+
}
101+
102+
ctxData := ContextData(r.Context())
103+
v, ok = ctxData.Params()["param"]
104+
if hasParam && !ok {
105+
t.Error("missing key 'param' in context from ContextData")
106+
}
107+
108+
routePath := ctxData.Route()
109+
if routePath != expectedRoutePath {
110+
t.Errorf("Expected context to have route path '%s', saw %s", expectedRoutePath, routePath)
111+
}
112+
113+
if headCanUseGet && (method == "GET" || v == "HEAD") {
114+
return
115+
}
116+
if hasParam && v != method {
117+
t.Errorf("invalid key 'param' in context; expected '%s' but got '%s'", method, v)
118+
}
119+
}
120+
}
51121

52-
var result string
53-
makeHandler := func(method string) http.HandlerFunc {
54-
return func(w http.ResponseWriter, r *http.Request) {
55-
result = method
122+
var router http.Handler
123+
var rootGroup IContextGroup
56124

57-
v, ok := ContextParams(r.Context())["param"]
58-
if !ok {
59-
t.Error("missing key 'param' in context")
60-
}
125+
if useContextRouter {
126+
root := NewContextMux()
127+
root.HeadCanUseGet = headCanUseGet
128+
t.Log(root.TreeMux.HeadCanUseGet)
129+
router = root
130+
rootGroup = root
131+
} else {
132+
root := New()
133+
root.HeadCanUseGet = headCanUseGet
134+
router = root
135+
rootGroup = root.UsingContext()
136+
}
61137

62-
if headCanUseGet && (method == "GET" || v == "HEAD") {
63-
return
138+
cg := rootGroup.NewGroup("/base").NewGroup("/user")
139+
cg.GET("/:param", makeHandler("GET", cg.group.path+"/:param", true))
140+
cg.POST("/:param", makeHandler("POST", cg.group.path+"/:param", true))
141+
cg.PATCH("/PATCH", makeHandler("PATCH", cg.group.path+"/PATCH", false))
142+
cg.PUT("/:param", makeHandler("PUT", cg.group.path+"/:param", true))
143+
cg.Handler("DELETE", "/:param", http.HandlerFunc(makeHandler("DELETE", cg.group.path+"/:param", true)))
144+
145+
testMethod := func(method, expect string) {
146+
result = ""
147+
w := httptest.NewRecorder()
148+
r, _ := reqGen(method, "/base/user/"+method, nil)
149+
router.ServeHTTP(w, r)
150+
if expect == "" && w.Code != http.StatusMethodNotAllowed {
151+
t.Errorf("Method %s not expected to match but saw code %d", method, w.Code)
64152
}
65153

66-
if v != method {
67-
t.Errorf("invalid key 'param' in context; expected '%s' but got '%s'", method, v)
154+
if result != expect {
155+
t.Errorf("Method %s got result %s", method, result)
68156
}
69157
}
70-
}
71-
72-
var router http.Handler
73-
var rootGroup IContextGroup
74-
75-
if useContextRouter {
76-
root := NewContextMux()
77-
root.HeadCanUseGet = headCanUseGet
78-
t.Log(root.TreeMux.HeadCanUseGet)
79-
router = root
80-
rootGroup = root
81-
} else {
82-
root := New()
83-
root.HeadCanUseGet = headCanUseGet
84-
router = root
85-
rootGroup = root.UsingContext()
86-
}
87158

88-
cg := rootGroup.NewGroup("/base").NewGroup("/user")
89-
cg.GET("/:param", makeHandler("GET"))
90-
cg.POST("/:param", makeHandler("POST"))
91-
cg.PATCH("/:param", makeHandler("PATCH"))
92-
cg.PUT("/:param", makeHandler("PUT"))
93-
cg.DELETE("/:param", makeHandler("DELETE"))
159+
testMethod("GET", "GET")
160+
testMethod("POST", "POST")
161+
testMethod("PATCH", "PATCH")
162+
testMethod("PUT", "PUT")
163+
testMethod("DELETE", "DELETE")
94164

95-
testMethod := func(method, expect string) {
96-
result = ""
97-
w := httptest.NewRecorder()
98-
r, _ := reqGen(method, "/base/user/"+method, nil)
99-
router.ServeHTTP(w, r)
100-
if expect == "" && w.Code != http.StatusMethodNotAllowed {
101-
t.Errorf("Method %s not expected to match but saw code %d", method, w.Code)
102-
}
103-
104-
if result != expect {
105-
t.Errorf("Method %s got result %s", method, result)
165+
if headCanUseGet {
166+
t.Log("Test implicit HEAD with HeadCanUseGet = true")
167+
testMethod("HEAD", "GET")
168+
} else {
169+
t.Log("Test implicit HEAD with HeadCanUseGet = false")
170+
testMethod("HEAD", "")
106171
}
107-
}
108-
109-
testMethod("GET", "GET")
110-
testMethod("POST", "POST")
111-
testMethod("PATCH", "PATCH")
112-
testMethod("PUT", "PUT")
113-
testMethod("DELETE", "DELETE")
114-
115-
if headCanUseGet {
116-
t.Log("Test implicit HEAD with HeadCanUseGet = true")
117-
testMethod("HEAD", "GET")
118-
} else {
119-
t.Log("Test implicit HEAD with HeadCanUseGet = false")
120-
testMethod("HEAD", "")
121-
}
122172

123-
cg.HEAD("/:param", makeHandler("HEAD"))
124-
testMethod("HEAD", "HEAD")
173+
cg.HEAD("/:param", makeHandler("HEAD", cg.group.path+"/:param", true))
174+
testMethod("HEAD", "HEAD")
175+
})
125176
}
126177

127178
func TestNewContextGroup(t *testing.T) {

0 commit comments

Comments
 (0)