Skip to content

Commit 33688df

Browse files
committed
move Access-Control-Allow-Origin to httpp.Server
1 parent c6f91d6 commit 33688df

File tree

12 files changed

+241
-232
lines changed

12 files changed

+241
-232
lines changed

internal/api/api.go

Lines changed: 3 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,8 @@ import (
66
"fmt"
77
"net"
88
"net/http"
9-
"net/url"
109
"os"
1110
"reflect"
12-
"regexp"
1311
"sort"
1412
"strings"
1513
"sync"
@@ -78,71 +76,6 @@ func recordingsOfPath(
7876
return ret
7977
}
8078

81-
var errOriginNotAllowed = errors.New("origin not allowed")
82-
83-
func isOriginAllowed(origin string, allowOrigins []string) (string, error) {
84-
if len(allowOrigins) == 0 {
85-
return "", errOriginNotAllowed
86-
}
87-
88-
for _, o := range allowOrigins {
89-
if o == "*" {
90-
return o, nil
91-
}
92-
}
93-
94-
if origin == "" {
95-
return "", errOriginNotAllowed
96-
}
97-
98-
originURL, err := url.Parse(origin)
99-
if err != nil || originURL.Scheme == "" {
100-
return "", errOriginNotAllowed
101-
}
102-
103-
if originURL.Port() == "" && originURL.Scheme != "" {
104-
switch originURL.Scheme {
105-
case "http":
106-
originURL.Host = net.JoinHostPort(originURL.Host, "80")
107-
case "https":
108-
originURL.Host = net.JoinHostPort(originURL.Host, "443")
109-
}
110-
}
111-
112-
for _, o := range allowOrigins {
113-
allowedURL, errAllowed := url.Parse(o)
114-
if errAllowed != nil {
115-
continue
116-
}
117-
118-
if allowedURL.Port() == "" {
119-
switch allowedURL.Scheme {
120-
case "http":
121-
allowedURL.Host = net.JoinHostPort(allowedURL.Host, "80")
122-
case "https":
123-
allowedURL.Host = net.JoinHostPort(allowedURL.Host, "443")
124-
}
125-
}
126-
127-
if allowedURL.Scheme == originURL.Scheme &&
128-
allowedURL.Host == originURL.Host &&
129-
allowedURL.Port() == originURL.Port() {
130-
return origin, nil
131-
}
132-
133-
if strings.Contains(allowedURL.Host, "*") {
134-
pattern := strings.ReplaceAll(allowedURL.Host, "*.", "(.*\\.)?")
135-
pattern = strings.ReplaceAll(pattern, "*", ".*")
136-
matched, errMatched := regexp.MatchString("^"+pattern+"$", originURL.Host)
137-
if errMatched == nil && matched {
138-
return origin, nil
139-
}
140-
}
141-
}
142-
143-
return "", errOriginNotAllowed
144-
}
145-
14679
type apiAuthManager interface {
14780
Authenticate(req *auth.Request) *auth.Error
14881
RefreshJWTJWKS()
@@ -186,7 +119,7 @@ func (a *API) Initialize() error {
186119
router := gin.New()
187120
router.SetTrustedProxies(a.TrustedProxies.ToTrustedProxies()) //nolint:errcheck
188121

189-
router.Use(a.middlewareOrigin)
122+
router.Use(a.middlewarePreflightRequests)
190123
router.Use(a.middlewareAuth)
191124

192125
group := router.Group("/v3")
@@ -262,6 +195,7 @@ func (a *API) Initialize() error {
262195

263196
a.httpServer = &httpp.Server{
264197
Address: a.Address,
198+
AllowOrigins: a.AllowOrigins,
265199
ReadTimeout: time.Duration(a.ReadTimeout),
266200
WriteTimeout: time.Duration(a.WriteTimeout),
267201
Encryption: a.Encryption,
@@ -301,16 +235,7 @@ func (a *API) writeError(ctx *gin.Context, status int, err error) {
301235
})
302236
}
303237

304-
func (a *API) middlewareOrigin(ctx *gin.Context) {
305-
origin, err := isOriginAllowed(ctx.Request.Header.Get("Origin"), a.AllowOrigins)
306-
if err != nil {
307-
return
308-
}
309-
310-
ctx.Header("Access-Control-Allow-Origin", origin)
311-
ctx.Header("Access-Control-Allow-Credentials", "true")
312-
313-
// preflight requests
238+
func (a *API) middlewarePreflightRequests(ctx *gin.Context) {
314239
if ctx.Request.Method == http.MethodOptions &&
315240
ctx.Request.Header.Get("Access-Control-Request-Method") != "" {
316241
ctx.Header("Access-Control-Allow-Methods", "OPTIONS, GET, POST, PATCH, DELETE")

internal/api/api_test.go

Lines changed: 0 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ package api
33
import (
44
"bytes"
55
"encoding/json"
6-
"errors"
76
"fmt"
87
"io"
98
"net/http"
@@ -119,98 +118,6 @@ func TestPreflightRequest(t *testing.T) {
119118
require.Equal(t, byts, []byte{})
120119
}
121120

122-
func TestMiddlewareOrigin(t *testing.T) {
123-
allowOrigins := []string{}
124-
origin := ""
125-
allowedOrigin, err := isOriginAllowed(origin, allowOrigins)
126-
if err == nil {
127-
t.Fatalf("expected error for empty origin, got nil")
128-
}
129-
if allowedOrigin != "" {
130-
t.Fatalf("expected empty allowed origin, got %s", allowedOrigin)
131-
}
132-
133-
allowOrigins = []string{"http://example.com"}
134-
allowedOrigin, err = isOriginAllowed(origin, allowOrigins)
135-
if err == nil {
136-
t.Fatalf("expected error for empty origin with allowed origins, got nil")
137-
}
138-
if allowedOrigin != "" {
139-
t.Fatalf("unexpected allowed origin: %s", allowedOrigin)
140-
}
141-
142-
allowOrigins = []string{"*"}
143-
allowedOrigin, err = isOriginAllowed(origin, allowOrigins)
144-
if err != nil {
145-
t.Fatalf("unexpected error for wildcard origin: %v", err)
146-
}
147-
if allowedOrigin != "*" {
148-
t.Fatalf("unexpected allowed origin: %s", allowedOrigin)
149-
}
150-
151-
origin = "http://example.com"
152-
allowedOrigin, err = isOriginAllowed(origin, allowOrigins)
153-
if err != nil {
154-
t.Fatalf("unexpected error for matching wildcard: %v", err)
155-
}
156-
if allowedOrigin != "*" {
157-
t.Fatalf("unexpected allowed origin: %s", allowedOrigin)
158-
}
159-
160-
allowOrigins = []string{"http://example.com", "https://example.org"}
161-
allowedOrigin, err = isOriginAllowed(origin, allowOrigins)
162-
if err != nil {
163-
t.Fatalf("unexpected error for matching origin: %v", err)
164-
}
165-
if allowedOrigin != origin {
166-
t.Fatalf("expected empty allowed origin, got %s", allowedOrigin)
167-
}
168-
169-
allowedOrigin, err = isOriginAllowed(origin, allowOrigins)
170-
if err != nil {
171-
t.Fatalf("unexpected error for matching origin: %v", err)
172-
}
173-
if allowedOrigin != origin {
174-
t.Fatalf("unexpected allowed origin: %s", allowedOrigin)
175-
}
176-
177-
origin = "https://example.org"
178-
allowedOrigin, err = isOriginAllowed(origin, allowOrigins)
179-
if err != nil {
180-
t.Fatalf("unexpected error for matching origin: %v", err)
181-
}
182-
if allowedOrigin != origin {
183-
t.Fatalf("unexpected allowed origin: %s", allowedOrigin)
184-
}
185-
186-
allowedOrigin, err = isOriginAllowed("http://notallowed.com", allowOrigins)
187-
if !errors.Is(err, errOriginNotAllowed) {
188-
t.Fatalf("expected errOriginNotAllowed for disallowed origin, got %v", err)
189-
}
190-
if allowedOrigin != "" {
191-
t.Fatalf("expected empty allowed origin, got %s", allowedOrigin)
192-
}
193-
194-
allowOrigins = []string{"http://*.example.com"}
195-
origin = "http://test.example.com"
196-
allowedOrigin, err = isOriginAllowed(origin, allowOrigins)
197-
if err != nil {
198-
t.Fatalf("unexpected error for wildcard subdomain: %v", err)
199-
}
200-
if allowedOrigin != origin {
201-
t.Fatalf("unexpected allowed origin: %s", allowedOrigin)
202-
}
203-
204-
origin = "http://example.com"
205-
allowedOrigin, err = isOriginAllowed(origin, allowOrigins)
206-
if err != nil {
207-
t.Fatalf("unexpected error for exact subdomain match: %v", err)
208-
}
209-
if allowedOrigin != origin {
210-
t.Fatalf("unexpected allowed origin: %s", allowedOrigin)
211-
}
212-
}
213-
214121
func TestInfo(t *testing.T) {
215122
cnf := tempConf(t, "api: yes\n")
216123

internal/metrics/metrics.go

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -98,13 +98,14 @@ func (m *Metrics) Initialize() error {
9898
router := gin.New()
9999
router.SetTrustedProxies(m.TrustedProxies.ToTrustedProxies()) //nolint:errcheck
100100

101-
router.Use(m.middlewareOrigin)
101+
router.Use(m.middlewarePreflightRequests)
102102
router.Use(m.middlewareAuth)
103103

104104
router.GET("/metrics", m.onMetrics)
105105

106106
m.httpServer = &httpp.Server{
107107
Address: m.Address,
108+
AllowOrigins: []string{m.AllowOrigin},
108109
ReadTimeout: time.Duration(m.ReadTimeout),
109110
WriteTimeout: time.Duration(m.WriteTimeout),
110111
Encryption: m.Encryption,
@@ -134,11 +135,7 @@ func (m *Metrics) Log(level logger.Level, format string, args ...any) {
134135
m.Parent.Log(level, "[metrics] "+format, args...)
135136
}
136137

137-
func (m *Metrics) middlewareOrigin(ctx *gin.Context) {
138-
ctx.Header("Access-Control-Allow-Origin", m.AllowOrigin)
139-
ctx.Header("Access-Control-Allow-Credentials", "true")
140-
141-
// preflight requests
138+
func (m *Metrics) middlewarePreflightRequests(ctx *gin.Context) {
142139
if ctx.Request.Method == http.MethodOptions &&
143140
ctx.Request.Header.Get("Access-Control-Request-Method") != "" {
144141
ctx.Header("Access-Control-Allow-Methods", "OPTIONS, GET")

internal/playback/server.go

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,13 +41,14 @@ func (s *Server) Initialize() error {
4141
router := gin.New()
4242
router.SetTrustedProxies(s.TrustedProxies.ToTrustedProxies()) //nolint:errcheck
4343

44-
router.Use(s.middlewareOrigin)
44+
router.Use(s.middlewarePreflightRequests)
4545

4646
router.GET("/list", s.onList)
4747
router.GET("/get", s.onGet)
4848

4949
s.httpServer = &httpp.Server{
5050
Address: s.Address,
51+
AllowOrigins: []string{s.AllowOrigin},
5152
ReadTimeout: time.Duration(s.ReadTimeout),
5253
WriteTimeout: time.Duration(s.WriteTimeout),
5354
Encryption: s.Encryption,
@@ -100,11 +101,7 @@ func (s *Server) safeFindPathConf(name string) (*conf.Path, error) {
100101
return pathConf, err
101102
}
102103

103-
func (s *Server) middlewareOrigin(ctx *gin.Context) {
104-
ctx.Header("Access-Control-Allow-Origin", s.AllowOrigin)
105-
ctx.Header("Access-Control-Allow-Credentials", "true")
106-
107-
// preflight requests
104+
func (s *Server) middlewarePreflightRequests(ctx *gin.Context) {
108105
if ctx.Request.Method == http.MethodOptions &&
109106
ctx.Request.Header.Get("Access-Control-Request-Method") != "" {
110107
ctx.Header("Access-Control-Allow-Methods", "OPTIONS, GET")

internal/pprof/pprof.go

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,13 +44,14 @@ func (pp *PPROF) Initialize() error {
4444
router := gin.New()
4545
router.SetTrustedProxies(pp.TrustedProxies.ToTrustedProxies()) //nolint:errcheck
4646

47-
router.Use(pp.middlewareOrigin)
47+
router.Use(pp.middlewarePreflightRequests)
4848
router.Use(pp.middlewareAuth)
4949

5050
pprof.Register(router)
5151

5252
pp.httpServer = &httpp.Server{
5353
Address: pp.Address,
54+
AllowOrigins: []string{pp.AllowOrigin},
5455
ReadTimeout: time.Duration(pp.ReadTimeout),
5556
WriteTimeout: time.Duration(pp.WriteTimeout),
5657
Encryption: pp.Encryption,
@@ -80,11 +81,7 @@ func (pp *PPROF) Log(level logger.Level, format string, args ...any) {
8081
pp.Parent.Log(level, "[pprof] "+format, args...)
8182
}
8283

83-
func (pp *PPROF) middlewareOrigin(ctx *gin.Context) {
84-
ctx.Header("Access-Control-Allow-Origin", pp.AllowOrigin)
85-
ctx.Header("Access-Control-Allow-Credentials", "true")
86-
87-
// preflight requests
84+
func (pp *PPROF) middlewarePreflightRequests(ctx *gin.Context) {
8885
if ctx.Request.Method == http.MethodOptions &&
8986
ctx.Request.Header.Get("Access-Control-Request-Method") != "" {
9087
ctx.Header("Access-Control-Allow-Methods", "OPTIONS, GET")
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
package httpp
2+
3+
import (
4+
"net"
5+
"net/http"
6+
"strings"
7+
"testing"
8+
"time"
9+
10+
"github.com/bluenviron/mediamtx/internal/test"
11+
"github.com/stretchr/testify/require"
12+
)
13+
14+
func TestHandlerFilterRequests(t *testing.T) {
15+
s := &Server{
16+
Address: "localhost:4555",
17+
ReadTimeout: 10 * time.Second,
18+
WriteTimeout: 10 * time.Second,
19+
Parent: test.NilLogger,
20+
Handler: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
21+
w.WriteHeader(http.StatusOK)
22+
}),
23+
}
24+
err := s.Initialize()
25+
require.NoError(t, err)
26+
defer s.Close()
27+
28+
conn, err := net.Dial("tcp", "localhost:4555")
29+
require.NoError(t, err)
30+
defer conn.Close()
31+
32+
_, err = conn.Write([]byte("OPTIONS / HTTP/1.1\n" +
33+
"Host: localhost:8889\n\n"))
34+
require.NoError(t, err)
35+
36+
buf := make([]byte, 200)
37+
n, err := conn.Read(buf)
38+
require.NoError(t, err)
39+
40+
res := strings.Split(string(buf[:n]), "\r\n")
41+
require.Equal(t, "HTTP/1.1 200 OK", res[0])
42+
}

0 commit comments

Comments
 (0)