Skip to content

Commit 279efff

Browse files
committed
fix
1 parent f3bdcc5 commit 279efff

23 files changed

Lines changed: 330 additions & 436 deletions

File tree

modules/web/handler.go

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,8 @@ func preCheckHandler(fn reflect.Value, argsIn []reflect.Value) {
7070

7171
func prepareHandleArgsIn(resp http.ResponseWriter, req *http.Request, fn reflect.Value, fnInfo *routing.FuncInfo) []reflect.Value {
7272
defer func() {
73-
if err := recover(); err != nil {
73+
if recovered := recover(); recovered != nil {
74+
err := fmt.Errorf("%v\n%s", recovered, log.Stack(2))
7475
log.Error("unable to prepare handler arguments for %s: %v", fnInfo.String(), err)
7576
panic(err)
7677
}
@@ -117,7 +118,17 @@ func hasResponseBeenWritten(argsIn []reflect.Value) bool {
117118
return false
118119
}
119120

120-
func wrapHandlerProvider[T http.Handler](hp func(next http.Handler) T, funcInfo *routing.FuncInfo) func(next http.Handler) http.Handler {
121+
type middlewareProvider = func(next http.Handler) http.Handler
122+
123+
func executeMiddlewaresHandler(w http.ResponseWriter, r *http.Request, middlewares []middlewareProvider, endpoint http.HandlerFunc) {
124+
handler := endpoint
125+
for i := len(middlewares) - 1; i >= 0; i-- {
126+
handler = middlewares[i](handler).ServeHTTP
127+
}
128+
handler(w, r)
129+
}
130+
131+
func wrapHandlerProvider[T http.Handler](hp func(next http.Handler) T, funcInfo *routing.FuncInfo) middlewareProvider {
121132
return func(next http.Handler) http.Handler {
122133
h := hp(next) // this handle could be dynamically generated, so we can't use it for debug info
123134
return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) {
@@ -129,14 +140,14 @@ func wrapHandlerProvider[T http.Handler](hp func(next http.Handler) T, funcInfo
129140

130141
// toHandlerProvider converts a handler to a handler provider
131142
// A handler provider is a function that takes a "next" http.Handler, it can be used as a middleware
132-
func toHandlerProvider(handler any) func(next http.Handler) http.Handler {
143+
func toHandlerProvider(handler any) middlewareProvider {
133144
funcInfo := routing.GetFuncInfo(handler)
134145
fn := reflect.ValueOf(handler)
135146
if fn.Type().Kind() != reflect.Func {
136147
panic(fmt.Sprintf("handler must be a function, but got %s", fn.Type()))
137148
}
138149

139-
if hp, ok := handler.(func(next http.Handler) http.Handler); ok {
150+
if hp, ok := handler.(middlewareProvider); ok {
140151
return wrapHandlerProvider(hp, funcInfo)
141152
} else if hp, ok := handler.(func(http.Handler) http.HandlerFunc); ok {
142153
return wrapHandlerProvider(hp, funcInfo)

modules/web/router.go

Lines changed: 65 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,13 @@ import (
1818
"github.com/go-chi/chi/v5"
1919
)
2020

21+
// PreMiddlewareProvider is a special middleware provider which will be executed
22+
// before other middlewares on the same "routing" level (AfterRouting/Group/Methods/Any, but not BeforeRouting).
23+
// A route can do something (e.g.: set middleware options) at the place where it is declared,
24+
// and the code will be executed before other middlewares which are added before the declaration.
25+
// Use cases: mark a route with some meta info, set some options for middlewares, etc.
26+
type PreMiddlewareProvider func(next http.Handler) http.Handler
27+
2128
// Bind binding an obj to a handler's context data
2229
func Bind[T any](_ T) http.HandlerFunc {
2330
return func(resp http.ResponseWriter, req *http.Request) {
@@ -41,7 +48,10 @@ func GetForm(dataStore reqctx.RequestDataStore) any {
4148

4249
// Router defines a route based on chi's router
4350
type Router struct {
44-
chiRouter *chi.Mux
51+
chiRouter *chi.Mux
52+
53+
afterRouting []any
54+
4555
curGroupPrefix string
4656
curMiddlewares []any
4757
}
@@ -52,16 +62,23 @@ func NewRouter() *Router {
5262
return &Router{chiRouter: r}
5363
}
5464

55-
// Use supports two middlewares
56-
func (r *Router) Use(middlewares ...any) {
65+
// BeforeRouting adds middlewares which will be executed before the request path gets routed
66+
// It should only be used for framework-level global middlewares when it needs to change request method & path.
67+
func (r *Router) BeforeRouting(middlewares ...any) {
5768
for _, m := range middlewares {
5869
if !isNilOrFuncNil(m) {
5970
r.chiRouter.Use(toHandlerProvider(m))
6071
}
6172
}
6273
}
6374

64-
// Group mounts a sub-Router along a `pattern` string.
75+
// AfterRouting adds middlewares which will be executed after the request path gets routed
76+
// It can see the routed path and resolved path parameters
77+
func (r *Router) AfterRouting(middlewares ...any) {
78+
r.afterRouting = append(r.afterRouting, middlewares...)
79+
}
80+
81+
// Group mounts a sub-router along a "pattern" string.
6582
func (r *Router) Group(pattern string, fn func(), middlewares ...any) {
6683
previousGroupPrefix := r.curGroupPrefix
6784
previousMiddlewares := r.curMiddlewares
@@ -93,36 +110,54 @@ func isNilOrFuncNil(v any) bool {
93110
return r.Kind() == reflect.Func && r.IsNil()
94111
}
95112

96-
func wrapMiddlewareAndHandler(curMiddlewares, h []any) ([]func(http.Handler) http.Handler, http.HandlerFunc) {
97-
handlerProviders := make([]func(http.Handler) http.Handler, 0, len(curMiddlewares)+len(h)+1)
98-
for _, m := range curMiddlewares {
99-
if !isNilOrFuncNil(m) {
100-
handlerProviders = append(handlerProviders, toHandlerProvider(m))
113+
func wrapMiddlewareAppendPre(all []middlewareProvider, middlewares []any) []middlewareProvider {
114+
for _, m := range middlewares {
115+
if h, ok := m.(PreMiddlewareProvider); ok && h != nil {
116+
all = append(all, toHandlerProvider(middlewareProvider(h)))
117+
}
118+
}
119+
return all
120+
}
121+
122+
func wrapMiddlewareAppendNormal(all []middlewareProvider, middlewares []any) []middlewareProvider {
123+
for _, m := range middlewares {
124+
if _, ok := m.(PreMiddlewareProvider); !ok && !isNilOrFuncNil(m) {
125+
all = append(all, toHandlerProvider(m))
101126
}
102127
}
128+
return all
129+
}
130+
131+
func wrapMiddlewareAndHandler(useMiddlewares, curMiddlewares, h []any) (_ []middlewareProvider, _ http.HandlerFunc, hasPreMiddlewares bool) {
103132
if len(h) == 0 {
104133
panic("no endpoint handler provided")
105134
}
106-
for i, m := range h {
107-
if !isNilOrFuncNil(m) {
108-
handlerProviders = append(handlerProviders, toHandlerProvider(m))
109-
} else if i == len(h)-1 {
110-
panic("endpoint handler can't be nil")
111-
}
135+
if isNilOrFuncNil(h[len(h)-1]) {
136+
panic("endpoint handler can't be nil")
112137
}
138+
139+
handlerProviders := make([]middlewareProvider, 0, len(useMiddlewares)+len(curMiddlewares)+len(h)+1)
140+
handlerProviders = wrapMiddlewareAppendPre(handlerProviders, useMiddlewares)
141+
handlerProviders = wrapMiddlewareAppendPre(handlerProviders, curMiddlewares)
142+
handlerProviders = wrapMiddlewareAppendPre(handlerProviders, h)
143+
hasPreMiddlewares = len(handlerProviders) > 0
144+
handlerProviders = wrapMiddlewareAppendNormal(handlerProviders, useMiddlewares)
145+
handlerProviders = wrapMiddlewareAppendNormal(handlerProviders, curMiddlewares)
146+
handlerProviders = wrapMiddlewareAppendNormal(handlerProviders, h)
147+
113148
middlewares := handlerProviders[:len(handlerProviders)-1]
114149
handlerFunc := handlerProviders[len(handlerProviders)-1](nil).ServeHTTP
115150
mockPoint := RouterMockPoint(MockAfterMiddlewares)
116151
if mockPoint != nil {
117152
middlewares = append(middlewares, mockPoint)
118153
}
119-
return middlewares, handlerFunc
154+
return middlewares, handlerFunc, hasPreMiddlewares
120155
}
121156

122157
// Methods adds the same handlers for multiple http "methods" (separated by ",").
123158
// If any method is invalid, the lower level router will panic.
124159
func (r *Router) Methods(methods, pattern string, h ...any) {
125-
middlewares, handlerFunc := wrapMiddlewareAndHandler(r.curMiddlewares, h)
160+
middlewares, handlerFunc, _ := wrapMiddlewareAndHandler(r.afterRouting, r.curMiddlewares, h)
126161
fullPattern := r.getPattern(pattern)
127162
if strings.Contains(methods, ",") {
128163
methods := strings.SplitSeq(methods, ",")
@@ -134,15 +169,19 @@ func (r *Router) Methods(methods, pattern string, h ...any) {
134169
}
135170
}
136171

137-
// Mount attaches another Router along ./pattern/*
172+
// Mount attaches another Router along "/pattern/*"
138173
func (r *Router) Mount(pattern string, subRouter *Router) {
139-
subRouter.Use(r.curMiddlewares...)
140-
r.chiRouter.Mount(r.getPattern(pattern), subRouter.chiRouter)
174+
handlerProviders := make([]middlewareProvider, 0, len(r.afterRouting)+len(r.curMiddlewares))
175+
handlerProviders = wrapMiddlewareAppendPre(handlerProviders, r.afterRouting)
176+
handlerProviders = wrapMiddlewareAppendPre(handlerProviders, r.curMiddlewares)
177+
handlerProviders = wrapMiddlewareAppendNormal(handlerProviders, r.afterRouting)
178+
handlerProviders = wrapMiddlewareAppendNormal(handlerProviders, r.curMiddlewares)
179+
r.chiRouter.With(handlerProviders...).Mount(r.getPattern(pattern), subRouter.chiRouter)
141180
}
142181

143182
// Any delegate requests for all methods
144183
func (r *Router) Any(pattern string, h ...any) {
145-
middlewares, handlerFunc := wrapMiddlewareAndHandler(r.curMiddlewares, h)
184+
middlewares, handlerFunc, _ := wrapMiddlewareAndHandler(r.afterRouting, r.curMiddlewares, h)
146185
r.chiRouter.With(middlewares...).HandleFunc(r.getPattern(pattern), handlerFunc)
147186
}
148187

@@ -178,12 +217,16 @@ func (r *Router) Patch(pattern string, h ...any) {
178217

179218
// ServeHTTP implements http.Handler
180219
func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
220+
// TODO: need to move it to the top-level common middleware, otherwise each "Mount" will cause it to be executed multiple times, which is inefficient.
181221
r.normalizeRequestPath(w, req, r.chiRouter)
182222
}
183223

184224
// NotFound defines a handler to respond whenever a route could not be found.
185225
func (r *Router) NotFound(h http.HandlerFunc) {
186-
r.chiRouter.NotFound(h)
226+
middlewares, handlerFunc, _ := wrapMiddlewareAndHandler(r.afterRouting, r.curMiddlewares, []any{h})
227+
r.chiRouter.NotFound(func(w http.ResponseWriter, r *http.Request) {
228+
executeMiddlewaresHandler(w, r, middlewares, handlerFunc)
229+
})
187230
}
188231

189232
func (r *Router) normalizeRequestPath(resp http.ResponseWriter, req *http.Request, next http.Handler) {

modules/web/router_path.go

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,7 @@ func (g *RouterPathGroup) ServeHTTP(resp http.ResponseWriter, req *http.Request)
2727
for _, m := range g.matchers {
2828
if m.matchPath(chiCtx, path) {
2929
chiCtx.RoutePatterns = append(chiCtx.RoutePatterns, m.pattern)
30-
handler := m.handlerFunc
31-
for i := len(m.middlewares) - 1; i >= 0; i-- {
32-
handler = m.middlewares[i](handler).ServeHTTP
33-
}
34-
handler(resp, req)
30+
executeMiddlewaresHandler(resp, req, m.middlewares, m.handlerFunc)
3531
return
3632
}
3733
}
@@ -67,7 +63,7 @@ type routerPathMatcher struct {
6763
pattern string
6864
re *regexp.Regexp
6965
params []routerPathParam
70-
middlewares []func(http.Handler) http.Handler
66+
middlewares []middlewareProvider
7167
handlerFunc http.HandlerFunc
7268
}
7369

@@ -111,7 +107,10 @@ func isValidMethod(name string) bool {
111107
}
112108

113109
func newRouterPathMatcher(methods string, patternRegexp *RouterPathGroupPattern, h ...any) *routerPathMatcher {
114-
middlewares, handlerFunc := wrapMiddlewareAndHandler(patternRegexp.middlewares, h)
110+
middlewares, handlerFunc, hasPreMiddlewares := wrapMiddlewareAndHandler(nil, patternRegexp.middlewares, h)
111+
if hasPreMiddlewares {
112+
panic("pre-middlewares are not supported in router path matcher")
113+
}
115114
p := &routerPathMatcher{methods: make(container.Set[string]), middlewares: middlewares, handlerFunc: handlerFunc}
116115
for method := range strings.SplitSeq(methods, ",") {
117116
method = strings.TrimSpace(method)

0 commit comments

Comments
 (0)