@@ -18,6 +18,11 @@ import (
1818 "github.com/go-chi/chi/v5"
1919)
2020
21+ // PreMiddlewareProvider is a special middleware provider which will be executed before other middlewares on the same "router" level.
22+ // then a route can set middleware options at the place where it is declared, and the options will be available for other middlewares which are added before the declaration.
23+ // It works for middlewares added by AfterRouting, Group, Methods, Any. It doesn't affect the middlewares added by BeforeRouting or cross-mount-level.
24+ type PreMiddlewareProvider func (next http.Handler ) http.Handler
25+
2126// Bind binding an obj to a handler's context data
2227func Bind [T any ](_ T ) http.HandlerFunc {
2328 return func (resp http.ResponseWriter , req * http.Request ) {
@@ -41,7 +46,10 @@ func GetForm(dataStore reqctx.RequestDataStore) any {
4146
4247// Router defines a route based on chi's router
4348type Router struct {
44- chiRouter * chi.Mux
49+ chiRouter * chi.Mux
50+
51+ afterRouting []any
52+
4553 curGroupPrefix string
4654 curMiddlewares []any
4755}
@@ -52,16 +60,23 @@ func NewRouter() *Router {
5260 return & Router {chiRouter : r }
5361}
5462
55- // Use supports two middlewares
56- func (r * Router ) Use (middlewares ... any ) {
63+ // BeforeRouting adds middlewares which will be executed before the request path gets routed
64+ // It should only be used for framework-level global middlewares when it needs to change request method & path.
65+ func (r * Router ) BeforeRouting (middlewares ... any ) {
5766 for _ , m := range middlewares {
5867 if ! isNilOrFuncNil (m ) {
5968 r .chiRouter .Use (toHandlerProvider (m ))
6069 }
6170 }
6271}
6372
64- // Group mounts a sub-Router along a `pattern` string.
73+ // AfterRouting adds middlewares which will be executed after the request path gets routed
74+ // It can see the routed path and resolved path parameters
75+ func (r * Router ) AfterRouting (middlewares ... any ) {
76+ r .afterRouting = append (r .afterRouting , middlewares ... )
77+ }
78+
79+ // Group mounts a sub-router along a "pattern" string.
6580func (r * Router ) Group (pattern string , fn func (), middlewares ... any ) {
6681 previousGroupPrefix := r .curGroupPrefix
6782 previousMiddlewares := r .curMiddlewares
@@ -93,36 +108,54 @@ func isNilOrFuncNil(v any) bool {
93108 return r .Kind () == reflect .Func && r .IsNil ()
94109}
95110
96- func wrapMiddlewareAndHandler (curMiddlewares , h []any ) ([]func (http.Handler ) http.Handler , http.HandlerFunc ) {
97- handlerProviders := make ([]func (http.Handler ) http.Handler , 0 , len (curMiddlewares )+ len (h )+ 1 )
98- for _ , m := range curMiddlewares {
99- if ! isNilOrFuncNil (m ) {
100- handlerProviders = append (handlerProviders , toHandlerProvider (m ))
111+ func wrapMiddlewareAppendPre (all []middlewareProvider , middlewares []any ) []middlewareProvider {
112+ for _ , m := range middlewares {
113+ if h , ok := m .(PreMiddlewareProvider ); ok && h != nil {
114+ all = append (all , toHandlerProvider (middlewareProvider (h )))
115+ }
116+ }
117+ return all
118+ }
119+
120+ func wrapMiddlewareAppendNormal (all []middlewareProvider , middlewares []any ) []middlewareProvider {
121+ for _ , m := range middlewares {
122+ if _ , ok := m .(PreMiddlewareProvider ); ! ok && ! isNilOrFuncNil (m ) {
123+ all = append (all , toHandlerProvider (m ))
101124 }
102125 }
126+ return all
127+ }
128+
129+ func wrapMiddlewareAndHandler (useMiddlewares , curMiddlewares , h []any ) (_ []middlewareProvider , _ http.HandlerFunc , hasPreMiddlewares bool ) {
103130 if len (h ) == 0 {
104131 panic ("no endpoint handler provided" )
105132 }
106- for i , m := range h {
107- if ! isNilOrFuncNil (m ) {
108- handlerProviders = append (handlerProviders , toHandlerProvider (m ))
109- } else if i == len (h )- 1 {
110- panic ("endpoint handler can't be nil" )
111- }
133+ if isNilOrFuncNil (h [len (h )- 1 ]) {
134+ panic ("endpoint handler can't be nil" )
112135 }
136+
137+ handlerProviders := make ([]middlewareProvider , 0 , len (useMiddlewares )+ len (curMiddlewares )+ len (h )+ 1 )
138+ handlerProviders = wrapMiddlewareAppendPre (handlerProviders , useMiddlewares )
139+ handlerProviders = wrapMiddlewareAppendPre (handlerProviders , curMiddlewares )
140+ handlerProviders = wrapMiddlewareAppendPre (handlerProviders , h )
141+ hasPreMiddlewares = len (handlerProviders ) > 0
142+ handlerProviders = wrapMiddlewareAppendNormal (handlerProviders , useMiddlewares )
143+ handlerProviders = wrapMiddlewareAppendNormal (handlerProviders , curMiddlewares )
144+ handlerProviders = wrapMiddlewareAppendNormal (handlerProviders , h )
145+
113146 middlewares := handlerProviders [:len (handlerProviders )- 1 ]
114147 handlerFunc := handlerProviders [len (handlerProviders )- 1 ](nil ).ServeHTTP
115148 mockPoint := RouterMockPoint (MockAfterMiddlewares )
116149 if mockPoint != nil {
117150 middlewares = append (middlewares , mockPoint )
118151 }
119- return middlewares , handlerFunc
152+ return middlewares , handlerFunc , hasPreMiddlewares
120153}
121154
122155// Methods adds the same handlers for multiple http "methods" (separated by ",").
123156// If any method is invalid, the lower level router will panic.
124157func (r * Router ) Methods (methods , pattern string , h ... any ) {
125- middlewares , handlerFunc := wrapMiddlewareAndHandler (r .curMiddlewares , h )
158+ middlewares , handlerFunc , _ := wrapMiddlewareAndHandler (r . afterRouting , r .curMiddlewares , h )
126159 fullPattern := r .getPattern (pattern )
127160 if strings .Contains (methods , "," ) {
128161 methods := strings .SplitSeq (methods , "," )
@@ -134,15 +167,19 @@ func (r *Router) Methods(methods, pattern string, h ...any) {
134167 }
135168}
136169
137- // Mount attaches another Router along . /pattern/*
170+ // Mount attaches another Router along " /pattern/*"
138171func (r * Router ) Mount (pattern string , subRouter * Router ) {
139- subRouter .Use (r .curMiddlewares ... )
140- r .chiRouter .Mount (r .getPattern (pattern ), subRouter .chiRouter )
172+ handlerProviders := make ([]middlewareProvider , 0 , len (r .afterRouting )+ len (r .curMiddlewares ))
173+ handlerProviders = wrapMiddlewareAppendPre (handlerProviders , r .afterRouting )
174+ handlerProviders = wrapMiddlewareAppendPre (handlerProviders , r .curMiddlewares )
175+ handlerProviders = wrapMiddlewareAppendNormal (handlerProviders , r .afterRouting )
176+ handlerProviders = wrapMiddlewareAppendNormal (handlerProviders , r .curMiddlewares )
177+ r .chiRouter .With (handlerProviders ... ).Mount (r .getPattern (pattern ), subRouter .chiRouter )
141178}
142179
143180// Any delegate requests for all methods
144181func (r * Router ) Any (pattern string , h ... any ) {
145- middlewares , handlerFunc := wrapMiddlewareAndHandler (r .curMiddlewares , h )
182+ middlewares , handlerFunc , _ := wrapMiddlewareAndHandler (r . afterRouting , r .curMiddlewares , h )
146183 r .chiRouter .With (middlewares ... ).HandleFunc (r .getPattern (pattern ), handlerFunc )
147184}
148185
@@ -178,12 +215,16 @@ func (r *Router) Patch(pattern string, h ...any) {
178215
179216// ServeHTTP implements http.Handler
180217func (r * Router ) ServeHTTP (w http.ResponseWriter , req * http.Request ) {
218+ // TODO: need to move it to the top-level common middleware, otherwise each "Mount" will cause it to be executed multiple times, which is inefficient.
181219 r .normalizeRequestPath (w , req , r .chiRouter )
182220}
183221
184222// NotFound defines a handler to respond whenever a route could not be found.
185223func (r * Router ) NotFound (h http.HandlerFunc ) {
186- r .chiRouter .NotFound (h )
224+ middlewares , handlerFunc , _ := wrapMiddlewareAndHandler (r .afterRouting , r .curMiddlewares , []any {h })
225+ r .chiRouter .NotFound (func (w http.ResponseWriter , r * http.Request ) {
226+ executeMiddlewaresHandler (w , r , middlewares , handlerFunc )
227+ })
187228}
188229
189230func (r * Router ) normalizeRequestPath (resp http.ResponseWriter , req * http.Request , next http.Handler ) {
0 commit comments