@@ -18,6 +18,13 @@ import (
1818 "github.com/go-chi/chi/v5"
1919)
2020
21+ // PreMiddlewareProvider is a special middleware provider which will be executed
22+ // before other middlewares on the same "routing" level (AfterRouting/Group/Methods/Any, but not BeforeRouting).
23+ // A route can do something (e.g.: set middleware options) at the place where it is declared,
24+ // and the code will be executed before other middlewares which are added before the declaration.
25+ // Use cases: mark a route with some meta info, set some options for middlewares, etc.
26+ type PreMiddlewareProvider func (next http.Handler ) http.Handler
27+
2128// Bind binding an obj to a handler's context data
2229func Bind [T any ](_ T ) http.HandlerFunc {
2330 return func (resp http.ResponseWriter , req * http.Request ) {
@@ -41,7 +48,10 @@ func GetForm(dataStore reqctx.RequestDataStore) any {
4148
4249// Router defines a route based on chi's router
4350type Router struct {
44- chiRouter * chi.Mux
51+ chiRouter * chi.Mux
52+
53+ afterRouting []any
54+
4555 curGroupPrefix string
4656 curMiddlewares []any
4757}
@@ -52,16 +62,23 @@ func NewRouter() *Router {
5262 return & Router {chiRouter : r }
5363}
5464
55- // Use supports two middlewares
56- func (r * Router ) Use (middlewares ... any ) {
65+ // BeforeRouting adds middlewares which will be executed before the request path gets routed
66+ // It should only be used for framework-level global middlewares when it needs to change request method & path.
67+ func (r * Router ) BeforeRouting (middlewares ... any ) {
5768 for _ , m := range middlewares {
5869 if ! isNilOrFuncNil (m ) {
5970 r .chiRouter .Use (toHandlerProvider (m ))
6071 }
6172 }
6273}
6374
64- // Group mounts a sub-Router along a `pattern` string.
75+ // AfterRouting adds middlewares which will be executed after the request path gets routed
76+ // It can see the routed path and resolved path parameters
77+ func (r * Router ) AfterRouting (middlewares ... any ) {
78+ r .afterRouting = append (r .afterRouting , middlewares ... )
79+ }
80+
81+ // Group mounts a sub-router along a "pattern" string.
6582func (r * Router ) Group (pattern string , fn func (), middlewares ... any ) {
6683 previousGroupPrefix := r .curGroupPrefix
6784 previousMiddlewares := r .curMiddlewares
@@ -93,36 +110,54 @@ func isNilOrFuncNil(v any) bool {
93110 return r .Kind () == reflect .Func && r .IsNil ()
94111}
95112
96- func wrapMiddlewareAndHandler (curMiddlewares , h []any ) ([]func (http.Handler ) http.Handler , http.HandlerFunc ) {
97- handlerProviders := make ([]func (http.Handler ) http.Handler , 0 , len (curMiddlewares )+ len (h )+ 1 )
98- for _ , m := range curMiddlewares {
99- if ! isNilOrFuncNil (m ) {
100- handlerProviders = append (handlerProviders , toHandlerProvider (m ))
113+ func wrapMiddlewareAppendPre (all []middlewareProvider , middlewares []any ) []middlewareProvider {
114+ for _ , m := range middlewares {
115+ if h , ok := m .(PreMiddlewareProvider ); ok && h != nil {
116+ all = append (all , toHandlerProvider (middlewareProvider (h )))
117+ }
118+ }
119+ return all
120+ }
121+
122+ func wrapMiddlewareAppendNormal (all []middlewareProvider , middlewares []any ) []middlewareProvider {
123+ for _ , m := range middlewares {
124+ if _ , ok := m .(PreMiddlewareProvider ); ! ok && ! isNilOrFuncNil (m ) {
125+ all = append (all , toHandlerProvider (m ))
101126 }
102127 }
128+ return all
129+ }
130+
131+ func wrapMiddlewareAndHandler (useMiddlewares , curMiddlewares , h []any ) (_ []middlewareProvider , _ http.HandlerFunc , hasPreMiddlewares bool ) {
103132 if len (h ) == 0 {
104133 panic ("no endpoint handler provided" )
105134 }
106- for i , m := range h {
107- if ! isNilOrFuncNil (m ) {
108- handlerProviders = append (handlerProviders , toHandlerProvider (m ))
109- } else if i == len (h )- 1 {
110- panic ("endpoint handler can't be nil" )
111- }
135+ if isNilOrFuncNil (h [len (h )- 1 ]) {
136+ panic ("endpoint handler can't be nil" )
112137 }
138+
139+ handlerProviders := make ([]middlewareProvider , 0 , len (useMiddlewares )+ len (curMiddlewares )+ len (h )+ 1 )
140+ handlerProviders = wrapMiddlewareAppendPre (handlerProviders , useMiddlewares )
141+ handlerProviders = wrapMiddlewareAppendPre (handlerProviders , curMiddlewares )
142+ handlerProviders = wrapMiddlewareAppendPre (handlerProviders , h )
143+ hasPreMiddlewares = len (handlerProviders ) > 0
144+ handlerProviders = wrapMiddlewareAppendNormal (handlerProviders , useMiddlewares )
145+ handlerProviders = wrapMiddlewareAppendNormal (handlerProviders , curMiddlewares )
146+ handlerProviders = wrapMiddlewareAppendNormal (handlerProviders , h )
147+
113148 middlewares := handlerProviders [:len (handlerProviders )- 1 ]
114149 handlerFunc := handlerProviders [len (handlerProviders )- 1 ](nil ).ServeHTTP
115150 mockPoint := RouterMockPoint (MockAfterMiddlewares )
116151 if mockPoint != nil {
117152 middlewares = append (middlewares , mockPoint )
118153 }
119- return middlewares , handlerFunc
154+ return middlewares , handlerFunc , hasPreMiddlewares
120155}
121156
122157// Methods adds the same handlers for multiple http "methods" (separated by ",").
123158// If any method is invalid, the lower level router will panic.
124159func (r * Router ) Methods (methods , pattern string , h ... any ) {
125- middlewares , handlerFunc := wrapMiddlewareAndHandler (r .curMiddlewares , h )
160+ middlewares , handlerFunc , _ := wrapMiddlewareAndHandler (r . afterRouting , r .curMiddlewares , h )
126161 fullPattern := r .getPattern (pattern )
127162 if strings .Contains (methods , "," ) {
128163 methods := strings .SplitSeq (methods , "," )
@@ -134,15 +169,19 @@ func (r *Router) Methods(methods, pattern string, h ...any) {
134169 }
135170}
136171
137- // Mount attaches another Router along . /pattern/*
172+ // Mount attaches another Router along " /pattern/*"
138173func (r * Router ) Mount (pattern string , subRouter * Router ) {
139- subRouter .Use (r .curMiddlewares ... )
140- r .chiRouter .Mount (r .getPattern (pattern ), subRouter .chiRouter )
174+ handlerProviders := make ([]middlewareProvider , 0 , len (r .afterRouting )+ len (r .curMiddlewares ))
175+ handlerProviders = wrapMiddlewareAppendPre (handlerProviders , r .afterRouting )
176+ handlerProviders = wrapMiddlewareAppendPre (handlerProviders , r .curMiddlewares )
177+ handlerProviders = wrapMiddlewareAppendNormal (handlerProviders , r .afterRouting )
178+ handlerProviders = wrapMiddlewareAppendNormal (handlerProviders , r .curMiddlewares )
179+ r .chiRouter .With (handlerProviders ... ).Mount (r .getPattern (pattern ), subRouter .chiRouter )
141180}
142181
143182// Any delegate requests for all methods
144183func (r * Router ) Any (pattern string , h ... any ) {
145- middlewares , handlerFunc := wrapMiddlewareAndHandler (r .curMiddlewares , h )
184+ middlewares , handlerFunc , _ := wrapMiddlewareAndHandler (r . afterRouting , r .curMiddlewares , h )
146185 r .chiRouter .With (middlewares ... ).HandleFunc (r .getPattern (pattern ), handlerFunc )
147186}
148187
@@ -178,12 +217,16 @@ func (r *Router) Patch(pattern string, h ...any) {
178217
179218// ServeHTTP implements http.Handler
180219func (r * Router ) ServeHTTP (w http.ResponseWriter , req * http.Request ) {
220+ // TODO: need to move it to the top-level common middleware, otherwise each "Mount" will cause it to be executed multiple times, which is inefficient.
181221 r .normalizeRequestPath (w , req , r .chiRouter )
182222}
183223
184224// NotFound defines a handler to respond whenever a route could not be found.
185225func (r * Router ) NotFound (h http.HandlerFunc ) {
186- r .chiRouter .NotFound (h )
226+ middlewares , handlerFunc , _ := wrapMiddlewareAndHandler (r .afterRouting , r .curMiddlewares , []any {h })
227+ r .chiRouter .NotFound (func (w http.ResponseWriter , r * http.Request ) {
228+ executeMiddlewaresHandler (w , r , middlewares , handlerFunc )
229+ })
187230}
188231
189232func (r * Router ) normalizeRequestPath (resp http.ResponseWriter , req * http.Request , next http.Handler ) {
0 commit comments