Skip to content

Commit af4714d

Browse files
committed
fix
1 parent f3bdcc5 commit af4714d

12 files changed

Lines changed: 245 additions & 389 deletions

File tree

modules/web/handler.go

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,17 @@ func hasResponseBeenWritten(argsIn []reflect.Value) bool {
117117
return false
118118
}
119119

120-
func wrapHandlerProvider[T http.Handler](hp func(next http.Handler) T, funcInfo *routing.FuncInfo) func(next http.Handler) http.Handler {
120+
type middlewareProvider = func(next http.Handler) http.Handler
121+
122+
func executeMiddlewaresHandler(w http.ResponseWriter, r *http.Request, middlewares []middlewareProvider, endpoint http.HandlerFunc) {
123+
handler := endpoint
124+
for i := len(middlewares) - 1; i >= 0; i-- {
125+
handler = middlewares[i](handler).ServeHTTP
126+
}
127+
handler(w, r)
128+
}
129+
130+
func wrapHandlerProvider[T http.Handler](hp func(next http.Handler) T, funcInfo *routing.FuncInfo) middlewareProvider {
121131
return func(next http.Handler) http.Handler {
122132
h := hp(next) // this handle could be dynamically generated, so we can't use it for debug info
123133
return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) {
@@ -129,14 +139,14 @@ func wrapHandlerProvider[T http.Handler](hp func(next http.Handler) T, funcInfo
129139

130140
// toHandlerProvider converts a handler to a handler provider
131141
// A handler provider is a function that takes a "next" http.Handler, it can be used as a middleware
132-
func toHandlerProvider(handler any) func(next http.Handler) http.Handler {
142+
func toHandlerProvider(handler any) middlewareProvider {
133143
funcInfo := routing.GetFuncInfo(handler)
134144
fn := reflect.ValueOf(handler)
135145
if fn.Type().Kind() != reflect.Func {
136146
panic(fmt.Sprintf("handler must be a function, but got %s", fn.Type()))
137147
}
138148

139-
if hp, ok := handler.(func(next http.Handler) http.Handler); ok {
149+
if hp, ok := handler.(middlewareProvider); ok {
140150
return wrapHandlerProvider(hp, funcInfo)
141151
} else if hp, ok := handler.(func(http.Handler) http.HandlerFunc); ok {
142152
return wrapHandlerProvider(hp, funcInfo)

modules/web/router.go

Lines changed: 37 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,11 @@ import (
1818
"github.com/go-chi/chi/v5"
1919
)
2020

21+
// PreMiddlewareProvider is a special middleware provider which will be executed before other middlewares on the same "router" level.
22+
// then a route can set middleware options at the place where it is declared, and the options will be available for other middlewares which are added before the declaration.
23+
// It works for "dynamic" middlewares (added by Group, Methods, Any). It doesn't work for "static" middlewares (added Use), and it doesn't work across "Mount" levels.
24+
type PreMiddlewareProvider func(next http.Handler) http.Handler
25+
2126
// Bind binding an obj to a handler's context data
2227
func Bind[T any](_ T) http.HandlerFunc {
2328
return func(resp http.ResponseWriter, req *http.Request) {
@@ -93,36 +98,52 @@ func isNilOrFuncNil(v any) bool {
9398
return r.Kind() == reflect.Func && r.IsNil()
9499
}
95100

96-
func wrapMiddlewareAndHandler(curMiddlewares, h []any) ([]func(http.Handler) http.Handler, http.HandlerFunc) {
97-
handlerProviders := make([]func(http.Handler) http.Handler, 0, len(curMiddlewares)+len(h)+1)
98-
for _, m := range curMiddlewares {
99-
if !isNilOrFuncNil(m) {
100-
handlerProviders = append(handlerProviders, toHandlerProvider(m))
101+
func wrapMiddlewareAppendPre(all []middlewareProvider, middlewares []any) []middlewareProvider {
102+
for _, m := range middlewares {
103+
if h, ok := m.(PreMiddlewareProvider); ok && h != nil {
104+
all = append(all, toHandlerProvider(middlewareProvider(h)))
105+
}
106+
}
107+
return all
108+
}
109+
110+
func wrapMiddlewareAppendNormal(all []middlewareProvider, middlewares []any) []middlewareProvider {
111+
for _, m := range middlewares {
112+
if _, ok := m.(PreMiddlewareProvider); !ok && !isNilOrFuncNil(m) {
113+
all = append(all, toHandlerProvider(m))
101114
}
102115
}
116+
return all
117+
}
118+
119+
func wrapMiddlewareAndHandler(curMiddlewares, h []any) (_ []middlewareProvider, _ http.HandlerFunc, hasPreMiddlewares bool) {
103120
if len(h) == 0 {
104121
panic("no endpoint handler provided")
105122
}
106-
for i, m := range h {
107-
if !isNilOrFuncNil(m) {
108-
handlerProviders = append(handlerProviders, toHandlerProvider(m))
109-
} else if i == len(h)-1 {
110-
panic("endpoint handler can't be nil")
111-
}
123+
if isNilOrFuncNil(h[len(h)-1]) {
124+
panic("endpoint handler can't be nil")
112125
}
126+
127+
handlerProviders := make([]middlewareProvider, 0, len(curMiddlewares)+len(h)+1)
128+
handlerProviders = wrapMiddlewareAppendPre(handlerProviders, curMiddlewares)
129+
handlerProviders = wrapMiddlewareAppendPre(handlerProviders, h)
130+
hasPreMiddlewares = len(handlerProviders) > 0
131+
handlerProviders = wrapMiddlewareAppendNormal(handlerProviders, curMiddlewares)
132+
handlerProviders = wrapMiddlewareAppendNormal(handlerProviders, h)
133+
113134
middlewares := handlerProviders[:len(handlerProviders)-1]
114135
handlerFunc := handlerProviders[len(handlerProviders)-1](nil).ServeHTTP
115136
mockPoint := RouterMockPoint(MockAfterMiddlewares)
116137
if mockPoint != nil {
117138
middlewares = append(middlewares, mockPoint)
118139
}
119-
return middlewares, handlerFunc
140+
return middlewares, handlerFunc, hasPreMiddlewares
120141
}
121142

122143
// Methods adds the same handlers for multiple http "methods" (separated by ",").
123144
// If any method is invalid, the lower level router will panic.
124145
func (r *Router) Methods(methods, pattern string, h ...any) {
125-
middlewares, handlerFunc := wrapMiddlewareAndHandler(r.curMiddlewares, h)
146+
middlewares, handlerFunc, _ := wrapMiddlewareAndHandler(r.curMiddlewares, h)
126147
fullPattern := r.getPattern(pattern)
127148
if strings.Contains(methods, ",") {
128149
methods := strings.SplitSeq(methods, ",")
@@ -134,15 +155,15 @@ func (r *Router) Methods(methods, pattern string, h ...any) {
134155
}
135156
}
136157

137-
// Mount attaches another Router along ./pattern/*
158+
// Mount attaches another Router along "/pattern/*"
138159
func (r *Router) Mount(pattern string, subRouter *Router) {
139160
subRouter.Use(r.curMiddlewares...)
140161
r.chiRouter.Mount(r.getPattern(pattern), subRouter.chiRouter)
141162
}
142163

143164
// Any delegate requests for all methods
144165
func (r *Router) Any(pattern string, h ...any) {
145-
middlewares, handlerFunc := wrapMiddlewareAndHandler(r.curMiddlewares, h)
166+
middlewares, handlerFunc, _ := wrapMiddlewareAndHandler(r.curMiddlewares, h)
146167
r.chiRouter.With(middlewares...).HandleFunc(r.getPattern(pattern), handlerFunc)
147168
}
148169

@@ -178,6 +199,7 @@ func (r *Router) Patch(pattern string, h ...any) {
178199

179200
// ServeHTTP implements http.Handler
180201
func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
202+
// TODO: need to move it to the top-level common middleware, otherwise each "Mount" will cause it to be executed multiple times, which is inefficient.
181203
r.normalizeRequestPath(w, req, r.chiRouter)
182204
}
183205

modules/web/router_path.go

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,7 @@ func (g *RouterPathGroup) ServeHTTP(resp http.ResponseWriter, req *http.Request)
2727
for _, m := range g.matchers {
2828
if m.matchPath(chiCtx, path) {
2929
chiCtx.RoutePatterns = append(chiCtx.RoutePatterns, m.pattern)
30-
handler := m.handlerFunc
31-
for i := len(m.middlewares) - 1; i >= 0; i-- {
32-
handler = m.middlewares[i](handler).ServeHTTP
33-
}
34-
handler(resp, req)
30+
executeMiddlewaresHandler(resp, req, m.middlewares, m.handlerFunc)
3531
return
3632
}
3733
}
@@ -67,7 +63,7 @@ type routerPathMatcher struct {
6763
pattern string
6864
re *regexp.Regexp
6965
params []routerPathParam
70-
middlewares []func(http.Handler) http.Handler
66+
middlewares []middlewareProvider
7167
handlerFunc http.HandlerFunc
7268
}
7369

@@ -111,7 +107,10 @@ func isValidMethod(name string) bool {
111107
}
112108

113109
func newRouterPathMatcher(methods string, patternRegexp *RouterPathGroupPattern, h ...any) *routerPathMatcher {
114-
middlewares, handlerFunc := wrapMiddlewareAndHandler(patternRegexp.middlewares, h)
110+
middlewares, handlerFunc, hasPreMiddlewares := wrapMiddlewareAndHandler(patternRegexp.middlewares, h)
111+
if hasPreMiddlewares {
112+
panic("pre-middlewares are not supported in router path matcher")
113+
}
115114
p := &routerPathMatcher{methods: make(container.Set[string]), middlewares: middlewares, handlerFunc: handlerFunc}
116115
for method := range strings.SplitSeq(methods, ",") {
117116
method = strings.TrimSpace(method)

modules/web/router_test.go

Lines changed: 94 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,71 @@ func chiURLParamsToMap(chiCtx *chi.Context) map[string]string {
3030
return util.Iif(len(m) == 0, nil, m)
3131
}
3232

33+
type testResult struct {
34+
method string
35+
pathParams map[string]string
36+
handlerMarks []string
37+
chiRoutePattern *string
38+
}
39+
40+
type testRecorder struct {
41+
res testResult
42+
}
43+
44+
func (r *testRecorder) reset() {
45+
r.res = testResult{}
46+
}
47+
48+
func (r *testRecorder) handle(optMark ...string) func(resp http.ResponseWriter, req *http.Request) {
49+
mark := util.OptionalArg(optMark, "")
50+
return func(resp http.ResponseWriter, req *http.Request) {
51+
chiCtx := chi.RouteContext(req.Context())
52+
r.res.method = req.Method
53+
r.res.pathParams = chiURLParamsToMap(chiCtx)
54+
r.res.chiRoutePattern = new(chiCtx.RoutePattern())
55+
if mark != "" {
56+
r.res.handlerMarks = append(r.res.handlerMarks, mark)
57+
}
58+
}
59+
}
60+
61+
func (r *testRecorder) provider(optMark ...string) func(next http.Handler) http.Handler {
62+
return func(next http.Handler) http.Handler {
63+
return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) {
64+
r.handle(optMark...)(resp, req)
65+
next.ServeHTTP(resp, req)
66+
})
67+
}
68+
}
69+
70+
func (r *testRecorder) stop(optMark ...string) func(resp http.ResponseWriter, req *http.Request) {
71+
mark := util.OptionalArg(optMark, "")
72+
return func(resp http.ResponseWriter, req *http.Request) {
73+
if stop := req.FormValue("stop"); stop != "" && (mark == "" || mark == stop) {
74+
r.handle(stop)(resp, req)
75+
resp.WriteHeader(http.StatusOK)
76+
} else if mark != "" {
77+
r.res.handlerMarks = append(r.res.handlerMarks, mark)
78+
}
79+
}
80+
}
81+
82+
func (r *testRecorder) test(t *testing.T, rt *Router, methodPath string, expected testResult) {
83+
r.reset()
84+
methodPathFields := strings.Fields(methodPath)
85+
req, err := http.NewRequest(methodPathFields[0], methodPathFields[1], nil)
86+
assert.NoError(t, err)
87+
88+
buff := &bytes.Buffer{}
89+
httpRecorder := httptest.NewRecorder()
90+
httpRecorder.Body = buff
91+
rt.ServeHTTP(httpRecorder, req)
92+
if expected.chiRoutePattern == nil {
93+
r.res.chiRoutePattern = nil
94+
}
95+
assert.Equal(t, expected, r.res)
96+
}
97+
3398
func TestPathProcessor(t *testing.T) {
3499
testProcess := func(pattern, uri string, expectedPathParams map[string]string) {
35100
chiCtx := chi.NewRouteContext()
@@ -51,42 +116,10 @@ func TestPathProcessor(t *testing.T) {
51116
}
52117

53118
func TestRouter(t *testing.T) {
54-
buff := &bytes.Buffer{}
55-
recorder := httptest.NewRecorder()
56-
recorder.Body = buff
57-
58-
type resultStruct struct {
59-
method string
60-
pathParams map[string]string
61-
handlerMarks []string
62-
chiRoutePattern *string
63-
}
64-
65-
var res resultStruct
66-
h := func(optMark ...string) func(resp http.ResponseWriter, req *http.Request) {
67-
mark := util.OptionalArg(optMark, "")
68-
return func(resp http.ResponseWriter, req *http.Request) {
69-
chiCtx := chi.RouteContext(req.Context())
70-
res.method = req.Method
71-
res.pathParams = chiURLParamsToMap(chiCtx)
72-
res.chiRoutePattern = new(chiCtx.RoutePattern())
73-
if mark != "" {
74-
res.handlerMarks = append(res.handlerMarks, mark)
75-
}
76-
}
77-
}
78-
79-
stopMark := func(optMark ...string) func(resp http.ResponseWriter, req *http.Request) {
80-
mark := util.OptionalArg(optMark, "")
81-
return func(resp http.ResponseWriter, req *http.Request) {
82-
if stop := req.FormValue("stop"); stop != "" && (mark == "" || mark == stop) {
83-
h(stop)(resp, req)
84-
resp.WriteHeader(http.StatusOK)
85-
} else if mark != "" {
86-
res.handlerMarks = append(res.handlerMarks, mark)
87-
}
88-
}
89-
}
119+
type resultStruct = testResult
120+
resRecorder := &testRecorder{}
121+
h := resRecorder.handle
122+
stopMark := resRecorder.stop
90123

91124
r := NewRouter()
92125
r.NotFound(h("not-found:/"))
@@ -123,15 +156,7 @@ func TestRouter(t *testing.T) {
123156

124157
testRoute := func(t *testing.T, methodPath string, expected resultStruct) {
125158
t.Run(methodPath, func(t *testing.T) {
126-
res = resultStruct{}
127-
methodPathFields := strings.Fields(methodPath)
128-
req, err := http.NewRequest(methodPathFields[0], methodPathFields[1], nil)
129-
assert.NoError(t, err)
130-
r.ServeHTTP(recorder, req)
131-
if expected.chiRoutePattern == nil {
132-
res.chiRoutePattern = nil
133-
}
134-
assert.Equal(t, expected, res)
159+
resRecorder.test(t, r, methodPath, expected)
135160
})
136161
}
137162

@@ -273,3 +298,27 @@ func TestRouteNormalizePath(t *testing.T) {
273298
testPath("/v2/", paths{EscapedPath: "/v2", RawPath: "/v2", Path: "/v2"})
274299
testPath("/v2/%2f", paths{EscapedPath: "/v2/%2f", RawPath: "/v2/%2f", Path: "/v2//"})
275300
}
301+
302+
func TestPreMiddlewareProvider(t *testing.T) {
303+
resRecorder := &testRecorder{}
304+
h := resRecorder.handle
305+
p := resRecorder.provider
306+
307+
r := NewRouter()
308+
r.Use(h("static"))
309+
r.Get("/a/1", h("mid"), PreMiddlewareProvider(p("pre")), h("end1"))
310+
311+
sub := NewRouter()
312+
sub.Use(h("sub"))
313+
sub.Get("/2", h("mid"), PreMiddlewareProvider(p("pre")), h("end2"))
314+
r.Mount("/a", sub)
315+
316+
resRecorder.test(t, r, "GET /a/1", testResult{
317+
method: "GET",
318+
handlerMarks: []string{"static", "pre", "mid", "end1"},
319+
})
320+
resRecorder.test(t, r, "GET /a/2", testResult{
321+
method: "GET",
322+
handlerMarks: []string{"static", "sub", "pre", "mid", "end2"},
323+
})
324+
}

routers/api/v1/api.go

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,6 @@ import (
7777
repo_model "code.gitea.io/gitea/models/repo"
7878
"code.gitea.io/gitea/models/unit"
7979
user_model "code.gitea.io/gitea/models/user"
80-
"code.gitea.io/gitea/modules/graceful"
8180
"code.gitea.io/gitea/modules/log"
8281
"code.gitea.io/gitea/modules/setting"
8382
api "code.gitea.io/gitea/modules/structs"
@@ -756,13 +755,9 @@ func buildAuthGroup() *auth.Group {
756755
&auth.Basic{}, // FIXME: this should be removed once we don't allow basic auth in API
757756
)
758757
if setting.Service.EnableReverseProxyAuthAPI {
759-
group.Add(&auth.ReverseProxy{})
758+
group.Add(&auth.ReverseProxy{}) // TODO: does it still make sense to support reverse proxy auth in API?
760759
}
761-
762-
if setting.IsWindows && auth_model.IsSSPIEnabled(graceful.GetManager().ShutdownContext()) {
763-
group.Add(&auth.SSPI{}) // it MUST be the last, see the comment of SSPI
764-
}
765-
760+
// others: API doesn't support SSPI auth because the caller should use token
766761
return group
767762
}
768763

0 commit comments

Comments
 (0)