Skip to content

Commit 61f45ae

Browse files
committed
fix
1 parent f3bdcc5 commit 61f45ae

18 files changed

Lines changed: 315 additions & 424 deletions

File tree

modules/web/handler.go

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,17 @@ func hasResponseBeenWritten(argsIn []reflect.Value) bool {
117117
return false
118118
}
119119

120-
func wrapHandlerProvider[T http.Handler](hp func(next http.Handler) T, funcInfo *routing.FuncInfo) func(next http.Handler) http.Handler {
120+
type middlewareProvider = func(next http.Handler) http.Handler
121+
122+
func executeMiddlewaresHandler(w http.ResponseWriter, r *http.Request, middlewares []middlewareProvider, endpoint http.HandlerFunc) {
123+
handler := endpoint
124+
for i := len(middlewares) - 1; i >= 0; i-- {
125+
handler = middlewares[i](handler).ServeHTTP
126+
}
127+
handler(w, r)
128+
}
129+
130+
func wrapHandlerProvider[T http.Handler](hp func(next http.Handler) T, funcInfo *routing.FuncInfo) middlewareProvider {
121131
return func(next http.Handler) http.Handler {
122132
h := hp(next) // this handle could be dynamically generated, so we can't use it for debug info
123133
return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) {
@@ -129,14 +139,14 @@ func wrapHandlerProvider[T http.Handler](hp func(next http.Handler) T, funcInfo
129139

130140
// toHandlerProvider converts a handler to a handler provider
131141
// A handler provider is a function that takes a "next" http.Handler, it can be used as a middleware
132-
func toHandlerProvider(handler any) func(next http.Handler) http.Handler {
142+
func toHandlerProvider(handler any) middlewareProvider {
133143
funcInfo := routing.GetFuncInfo(handler)
134144
fn := reflect.ValueOf(handler)
135145
if fn.Type().Kind() != reflect.Func {
136146
panic(fmt.Sprintf("handler must be a function, but got %s", fn.Type()))
137147
}
138148

139-
if hp, ok := handler.(func(next http.Handler) http.Handler); ok {
149+
if hp, ok := handler.(middlewareProvider); ok {
140150
return wrapHandlerProvider(hp, funcInfo)
141151
} else if hp, ok := handler.(func(http.Handler) http.HandlerFunc); ok {
142152
return wrapHandlerProvider(hp, funcInfo)

modules/web/router.go

Lines changed: 63 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,11 @@ import (
1818
"github.com/go-chi/chi/v5"
1919
)
2020

21+
// PreMiddlewareProvider is a special middleware provider which will be executed before other middlewares on the same "router" level.
22+
// then a route can set middleware options at the place where it is declared, and the options will be available for other middlewares which are added before the declaration.
23+
// It works for middlewares added by AfterRouting, Group, Methods, Any. It doesn't affect the middlewares added by BeforeRouting or cross-mount-level.
24+
type PreMiddlewareProvider func(next http.Handler) http.Handler
25+
2126
// Bind binding an obj to a handler's context data
2227
func Bind[T any](_ T) http.HandlerFunc {
2328
return func(resp http.ResponseWriter, req *http.Request) {
@@ -41,7 +46,10 @@ func GetForm(dataStore reqctx.RequestDataStore) any {
4146

4247
// Router defines a route based on chi's router
4348
type Router struct {
44-
chiRouter *chi.Mux
49+
chiRouter *chi.Mux
50+
51+
afterRouting []any
52+
4553
curGroupPrefix string
4654
curMiddlewares []any
4755
}
@@ -52,16 +60,23 @@ func NewRouter() *Router {
5260
return &Router{chiRouter: r}
5361
}
5462

55-
// Use supports two middlewares
56-
func (r *Router) Use(middlewares ...any) {
63+
// BeforeRouting adds middlewares which will be executed before the request path gets routed
64+
// It should only be used for framework-level global middlewares when it needs to change request method & path.
65+
func (r *Router) BeforeRouting(middlewares ...any) {
5766
for _, m := range middlewares {
5867
if !isNilOrFuncNil(m) {
5968
r.chiRouter.Use(toHandlerProvider(m))
6069
}
6170
}
6271
}
6372

64-
// Group mounts a sub-Router along a `pattern` string.
73+
// AfterRouting adds middlewares which will be executed after the request path gets routed
74+
// It can see the routed path and resolved path parameters
75+
func (r *Router) AfterRouting(middlewares ...any) {
76+
r.afterRouting = append(r.afterRouting, middlewares...)
77+
}
78+
79+
// Group mounts a sub-router along a "pattern" string.
6580
func (r *Router) Group(pattern string, fn func(), middlewares ...any) {
6681
previousGroupPrefix := r.curGroupPrefix
6782
previousMiddlewares := r.curMiddlewares
@@ -93,36 +108,54 @@ func isNilOrFuncNil(v any) bool {
93108
return r.Kind() == reflect.Func && r.IsNil()
94109
}
95110

96-
func wrapMiddlewareAndHandler(curMiddlewares, h []any) ([]func(http.Handler) http.Handler, http.HandlerFunc) {
97-
handlerProviders := make([]func(http.Handler) http.Handler, 0, len(curMiddlewares)+len(h)+1)
98-
for _, m := range curMiddlewares {
99-
if !isNilOrFuncNil(m) {
100-
handlerProviders = append(handlerProviders, toHandlerProvider(m))
111+
func wrapMiddlewareAppendPre(all []middlewareProvider, middlewares []any) []middlewareProvider {
112+
for _, m := range middlewares {
113+
if h, ok := m.(PreMiddlewareProvider); ok && h != nil {
114+
all = append(all, toHandlerProvider(middlewareProvider(h)))
115+
}
116+
}
117+
return all
118+
}
119+
120+
func wrapMiddlewareAppendNormal(all []middlewareProvider, middlewares []any) []middlewareProvider {
121+
for _, m := range middlewares {
122+
if _, ok := m.(PreMiddlewareProvider); !ok && !isNilOrFuncNil(m) {
123+
all = append(all, toHandlerProvider(m))
101124
}
102125
}
126+
return all
127+
}
128+
129+
func wrapMiddlewareAndHandler(useMiddlewares, curMiddlewares, h []any) (_ []middlewareProvider, _ http.HandlerFunc, hasPreMiddlewares bool) {
103130
if len(h) == 0 {
104131
panic("no endpoint handler provided")
105132
}
106-
for i, m := range h {
107-
if !isNilOrFuncNil(m) {
108-
handlerProviders = append(handlerProviders, toHandlerProvider(m))
109-
} else if i == len(h)-1 {
110-
panic("endpoint handler can't be nil")
111-
}
133+
if isNilOrFuncNil(h[len(h)-1]) {
134+
panic("endpoint handler can't be nil")
112135
}
136+
137+
handlerProviders := make([]middlewareProvider, 0, len(useMiddlewares)+len(curMiddlewares)+len(h)+1)
138+
handlerProviders = wrapMiddlewareAppendPre(handlerProviders, useMiddlewares)
139+
handlerProviders = wrapMiddlewareAppendPre(handlerProviders, curMiddlewares)
140+
handlerProviders = wrapMiddlewareAppendPre(handlerProviders, h)
141+
hasPreMiddlewares = len(handlerProviders) > 0
142+
handlerProviders = wrapMiddlewareAppendNormal(handlerProviders, useMiddlewares)
143+
handlerProviders = wrapMiddlewareAppendNormal(handlerProviders, curMiddlewares)
144+
handlerProviders = wrapMiddlewareAppendNormal(handlerProviders, h)
145+
113146
middlewares := handlerProviders[:len(handlerProviders)-1]
114147
handlerFunc := handlerProviders[len(handlerProviders)-1](nil).ServeHTTP
115148
mockPoint := RouterMockPoint(MockAfterMiddlewares)
116149
if mockPoint != nil {
117150
middlewares = append(middlewares, mockPoint)
118151
}
119-
return middlewares, handlerFunc
152+
return middlewares, handlerFunc, hasPreMiddlewares
120153
}
121154

122155
// Methods adds the same handlers for multiple http "methods" (separated by ",").
123156
// If any method is invalid, the lower level router will panic.
124157
func (r *Router) Methods(methods, pattern string, h ...any) {
125-
middlewares, handlerFunc := wrapMiddlewareAndHandler(r.curMiddlewares, h)
158+
middlewares, handlerFunc, _ := wrapMiddlewareAndHandler(r.afterRouting, r.curMiddlewares, h)
126159
fullPattern := r.getPattern(pattern)
127160
if strings.Contains(methods, ",") {
128161
methods := strings.SplitSeq(methods, ",")
@@ -134,15 +167,19 @@ func (r *Router) Methods(methods, pattern string, h ...any) {
134167
}
135168
}
136169

137-
// Mount attaches another Router along ./pattern/*
170+
// Mount attaches another Router along "/pattern/*"
138171
func (r *Router) Mount(pattern string, subRouter *Router) {
139-
subRouter.Use(r.curMiddlewares...)
140-
r.chiRouter.Mount(r.getPattern(pattern), subRouter.chiRouter)
172+
handlerProviders := make([]middlewareProvider, 0, len(r.afterRouting)+len(r.curMiddlewares))
173+
handlerProviders = wrapMiddlewareAppendPre(handlerProviders, r.afterRouting)
174+
handlerProviders = wrapMiddlewareAppendPre(handlerProviders, r.curMiddlewares)
175+
handlerProviders = wrapMiddlewareAppendNormal(handlerProviders, r.afterRouting)
176+
handlerProviders = wrapMiddlewareAppendNormal(handlerProviders, r.curMiddlewares)
177+
r.chiRouter.With(handlerProviders...).Mount(r.getPattern(pattern), subRouter.chiRouter)
141178
}
142179

143180
// Any delegate requests for all methods
144181
func (r *Router) Any(pattern string, h ...any) {
145-
middlewares, handlerFunc := wrapMiddlewareAndHandler(r.curMiddlewares, h)
182+
middlewares, handlerFunc, _ := wrapMiddlewareAndHandler(r.afterRouting, r.curMiddlewares, h)
146183
r.chiRouter.With(middlewares...).HandleFunc(r.getPattern(pattern), handlerFunc)
147184
}
148185

@@ -178,12 +215,16 @@ func (r *Router) Patch(pattern string, h ...any) {
178215

179216
// ServeHTTP implements http.Handler
180217
func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
218+
// TODO: need to move it to the top-level common middleware, otherwise each "Mount" will cause it to be executed multiple times, which is inefficient.
181219
r.normalizeRequestPath(w, req, r.chiRouter)
182220
}
183221

184222
// NotFound defines a handler to respond whenever a route could not be found.
185223
func (r *Router) NotFound(h http.HandlerFunc) {
186-
r.chiRouter.NotFound(h)
224+
middlewares, handlerFunc, _ := wrapMiddlewareAndHandler(r.afterRouting, r.curMiddlewares, []any{h})
225+
r.chiRouter.NotFound(func(w http.ResponseWriter, r *http.Request) {
226+
executeMiddlewaresHandler(w, r, middlewares, handlerFunc)
227+
})
187228
}
188229

189230
func (r *Router) normalizeRequestPath(resp http.ResponseWriter, req *http.Request, next http.Handler) {

modules/web/router_path.go

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,7 @@ func (g *RouterPathGroup) ServeHTTP(resp http.ResponseWriter, req *http.Request)
2727
for _, m := range g.matchers {
2828
if m.matchPath(chiCtx, path) {
2929
chiCtx.RoutePatterns = append(chiCtx.RoutePatterns, m.pattern)
30-
handler := m.handlerFunc
31-
for i := len(m.middlewares) - 1; i >= 0; i-- {
32-
handler = m.middlewares[i](handler).ServeHTTP
33-
}
34-
handler(resp, req)
30+
executeMiddlewaresHandler(resp, req, m.middlewares, m.handlerFunc)
3531
return
3632
}
3733
}
@@ -67,7 +63,7 @@ type routerPathMatcher struct {
6763
pattern string
6864
re *regexp.Regexp
6965
params []routerPathParam
70-
middlewares []func(http.Handler) http.Handler
66+
middlewares []middlewareProvider
7167
handlerFunc http.HandlerFunc
7268
}
7369

@@ -111,7 +107,10 @@ func isValidMethod(name string) bool {
111107
}
112108

113109
func newRouterPathMatcher(methods string, patternRegexp *RouterPathGroupPattern, h ...any) *routerPathMatcher {
114-
middlewares, handlerFunc := wrapMiddlewareAndHandler(patternRegexp.middlewares, h)
110+
middlewares, handlerFunc, hasPreMiddlewares := wrapMiddlewareAndHandler(nil, patternRegexp.middlewares, h)
111+
if hasPreMiddlewares {
112+
panic("pre-middlewares are not supported in router path matcher")
113+
}
115114
p := &routerPathMatcher{methods: make(container.Set[string]), middlewares: middlewares, handlerFunc: handlerFunc}
116115
for method := range strings.SplitSeq(methods, ",") {
117116
method = strings.TrimSpace(method)

0 commit comments

Comments
 (0)