@@ -30,6 +30,71 @@ func chiURLParamsToMap(chiCtx *chi.Context) map[string]string {
3030 return util .Iif (len (m ) == 0 , nil , m )
3131}
3232
33+ type testResult struct {
34+ method string
35+ pathParams map [string ]string
36+ handlerMarks []string
37+ chiRoutePattern * string
38+ }
39+
40+ type testRecorder struct {
41+ res testResult
42+ }
43+
44+ func (r * testRecorder ) reset () {
45+ r .res = testResult {}
46+ }
47+
48+ func (r * testRecorder ) handle (optMark ... string ) func (resp http.ResponseWriter , req * http.Request ) {
49+ mark := util .OptionalArg (optMark , "" )
50+ return func (resp http.ResponseWriter , req * http.Request ) {
51+ chiCtx := chi .RouteContext (req .Context ())
52+ r .res .method = req .Method
53+ r .res .pathParams = chiURLParamsToMap (chiCtx )
54+ r .res .chiRoutePattern = new (chiCtx.RoutePattern ())
55+ if mark != "" {
56+ r .res .handlerMarks = append (r .res .handlerMarks , mark )
57+ }
58+ }
59+ }
60+
61+ func (r * testRecorder ) provider (optMark ... string ) func (next http.Handler ) http.Handler {
62+ return func (next http.Handler ) http.Handler {
63+ return http .HandlerFunc (func (resp http.ResponseWriter , req * http.Request ) {
64+ r .handle (optMark ... )(resp , req )
65+ next .ServeHTTP (resp , req )
66+ })
67+ }
68+ }
69+
70+ func (r * testRecorder ) stop (optMark ... string ) func (resp http.ResponseWriter , req * http.Request ) {
71+ mark := util .OptionalArg (optMark , "" )
72+ return func (resp http.ResponseWriter , req * http.Request ) {
73+ if stop := req .FormValue ("stop" ); stop != "" && (mark == "" || mark == stop ) {
74+ r .handle (stop )(resp , req )
75+ resp .WriteHeader (http .StatusOK )
76+ } else if mark != "" {
77+ r .res .handlerMarks = append (r .res .handlerMarks , mark )
78+ }
79+ }
80+ }
81+
82+ func (r * testRecorder ) test (t * testing.T , rt * Router , methodPath string , expected testResult ) {
83+ r .reset ()
84+ methodPathFields := strings .Fields (methodPath )
85+ req , err := http .NewRequest (methodPathFields [0 ], methodPathFields [1 ], nil )
86+ assert .NoError (t , err )
87+
88+ buff := & bytes.Buffer {}
89+ httpRecorder := httptest .NewRecorder ()
90+ httpRecorder .Body = buff
91+ rt .ServeHTTP (httpRecorder , req )
92+ if expected .chiRoutePattern == nil {
93+ r .res .chiRoutePattern = nil
94+ }
95+ assert .Equal (t , expected , r .res )
96+ }
97+
3398func TestPathProcessor (t * testing.T ) {
3499 testProcess := func (pattern , uri string , expectedPathParams map [string ]string ) {
35100 chiCtx := chi .NewRouteContext ()
@@ -51,42 +116,10 @@ func TestPathProcessor(t *testing.T) {
51116}
52117
53118func TestRouter (t * testing.T ) {
54- buff := & bytes.Buffer {}
55- recorder := httptest .NewRecorder ()
56- recorder .Body = buff
57-
58- type resultStruct struct {
59- method string
60- pathParams map [string ]string
61- handlerMarks []string
62- chiRoutePattern * string
63- }
64-
65- var res resultStruct
66- h := func (optMark ... string ) func (resp http.ResponseWriter , req * http.Request ) {
67- mark := util .OptionalArg (optMark , "" )
68- return func (resp http.ResponseWriter , req * http.Request ) {
69- chiCtx := chi .RouteContext (req .Context ())
70- res .method = req .Method
71- res .pathParams = chiURLParamsToMap (chiCtx )
72- res .chiRoutePattern = new (chiCtx.RoutePattern ())
73- if mark != "" {
74- res .handlerMarks = append (res .handlerMarks , mark )
75- }
76- }
77- }
78-
79- stopMark := func (optMark ... string ) func (resp http.ResponseWriter , req * http.Request ) {
80- mark := util .OptionalArg (optMark , "" )
81- return func (resp http.ResponseWriter , req * http.Request ) {
82- if stop := req .FormValue ("stop" ); stop != "" && (mark == "" || mark == stop ) {
83- h (stop )(resp , req )
84- resp .WriteHeader (http .StatusOK )
85- } else if mark != "" {
86- res .handlerMarks = append (res .handlerMarks , mark )
87- }
88- }
89- }
119+ type resultStruct = testResult
120+ resRecorder := & testRecorder {}
121+ h := resRecorder .handle
122+ stopMark := resRecorder .stop
90123
91124 r := NewRouter ()
92125 r .NotFound (h ("not-found:/" ))
@@ -123,15 +156,7 @@ func TestRouter(t *testing.T) {
123156
124157 testRoute := func (t * testing.T , methodPath string , expected resultStruct ) {
125158 t .Run (methodPath , func (t * testing.T ) {
126- res = resultStruct {}
127- methodPathFields := strings .Fields (methodPath )
128- req , err := http .NewRequest (methodPathFields [0 ], methodPathFields [1 ], nil )
129- assert .NoError (t , err )
130- r .ServeHTTP (recorder , req )
131- if expected .chiRoutePattern == nil {
132- res .chiRoutePattern = nil
133- }
134- assert .Equal (t , expected , res )
159+ resRecorder .test (t , r , methodPath , expected )
135160 })
136161 }
137162
@@ -273,3 +298,27 @@ func TestRouteNormalizePath(t *testing.T) {
273298 testPath ("/v2/" , paths {EscapedPath : "/v2" , RawPath : "/v2" , Path : "/v2" })
274299 testPath ("/v2/%2f" , paths {EscapedPath : "/v2/%2f" , RawPath : "/v2/%2f" , Path : "/v2//" })
275300}
301+
302+ func TestPreMiddlewareProvider (t * testing.T ) {
303+ resRecorder := & testRecorder {}
304+ h := resRecorder .handle
305+ p := resRecorder .provider
306+
307+ r := NewRouter ()
308+ r .Use (h ("static" ))
309+ r .Get ("/a/1" , h ("mid" ), PreMiddlewareProvider (p ("pre" )), h ("end1" ))
310+
311+ sub := NewRouter ()
312+ sub .Use (h ("sub" ))
313+ sub .Get ("/2" , h ("mid" ), PreMiddlewareProvider (p ("pre" )), h ("end2" ))
314+ r .Mount ("/a" , sub )
315+
316+ resRecorder .test (t , r , "GET /a/1" , testResult {
317+ method : "GET" ,
318+ handlerMarks : []string {"static" , "pre" , "mid" , "end1" },
319+ })
320+ resRecorder .test (t , r , "GET /a/2" , testResult {
321+ method : "GET" ,
322+ handlerMarks : []string {"static" , "sub" , "pre" , "mid" , "end2" },
323+ })
324+ }
0 commit comments