From 500b9d734508a2c8ab09457ac9895b895bc86470 Mon Sep 17 00:00:00 2001
From: Anmol Sethi <hi@nhooyr.io>
Date: Tue, 25 Feb 2020 22:20:19 -0500
Subject: [PATCH] Add OriginPatterns to AcceptOptions

Closes #194
---
 accept.go       | 73 +++++++++++++++++++++++++++++++------------------
 accept_test.go  | 31 +++++++++++++++++----
 example_test.go | 12 +-------
 3 files changed, 74 insertions(+), 42 deletions(-)

diff --git a/accept.go b/accept.go
index 479138fc..47e20b52 100644
--- a/accept.go
+++ b/accept.go
@@ -9,10 +9,11 @@ import (
 	"errors"
 	"fmt"
 	"io"
+	"log"
 	"net/http"
 	"net/textproto"
 	"net/url"
-	"strconv"
+	"path/filepath"
 	"strings"
 
 	"nhooyr.io/websocket/internal/errd"
@@ -25,18 +26,27 @@ type AcceptOptions struct {
 	// reject it, close the connection when c.Subprotocol() == "".
 	Subprotocols []string
 
-	// InsecureSkipVerify disables Accept's origin verification behaviour. By default,
-	// the connection will only be accepted if the request origin is equal to the request
-	// host.
+	// InsecureSkipVerify is used to disable Accept's origin verification behaviour.
 	//
-	// This is only required if you want javascript served from a different domain
-	// to access your WebSocket server.
+	// Deprecated: Use OriginPatterns with a match all pattern of * instead to control
+	// origin authorization yourself.
+	InsecureSkipVerify bool
+
+	// OriginPatterns lists the host patterns for authorized origins.
+	// The request host is always authorized.
+	// Use this to enable cross origin WebSockets.
+	//
+	// i.e javascript running on example.com wants to access a WebSocket server at chat.example.com.
+	// In such a case, example.com is the origin and chat.example.com is the request host.
+	// One would set this field to []string{"example.com"} to authorize example.com to connect.
 	//
-	// See https://stackoverflow.com/a/37837709/4283659
+	// Each pattern is matched case insensitively against the request origin host
+	// with filepath.Match.
+	// See https://golang.org/pkg/path/filepath/#Match
 	//
 	// Please ensure you understand the ramifications of enabling this.
 	// If used incorrectly your WebSocket server will be open to CSRF attacks.
-	InsecureSkipVerify bool
+	OriginPatterns []string
 
 	// CompressionMode controls the compression mode.
 	// Defaults to CompressionNoContextTakeover.
@@ -77,8 +87,12 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con
 	}
 
 	if !opts.InsecureSkipVerify {
-		err = authenticateOrigin(r)
+		err = authenticateOrigin(r, opts.OriginPatterns)
 		if err != nil {
+			if errors.Is(err, filepath.ErrBadPattern) {
+				log.Printf("websocket: %v", err)
+				err = errors.New(http.StatusText(http.StatusForbidden))
+			}
 			http.Error(w, err.Error(), http.StatusForbidden)
 			return nil, err
 		}
@@ -165,18 +179,35 @@ func verifyClientRequest(w http.ResponseWriter, r *http.Request) (errCode int, _
 	return 0, nil
 }
 
-func authenticateOrigin(r *http.Request) error {
+func authenticateOrigin(r *http.Request, originHosts []string) error {
 	origin := r.Header.Get("Origin")
-	if origin != "" {
-		u, err := url.Parse(origin)
+	if origin == "" {
+		return nil
+	}
+
+	u, err := url.Parse(origin)
+	if err != nil {
+		return fmt.Errorf("failed to parse Origin header %q: %w", origin, err)
+	}
+
+	if strings.EqualFold(r.Host, u.Host) {
+		return nil
+	}
+
+	for _, hostPattern := range originHosts {
+		matched, err := match(hostPattern, u.Host)
 		if err != nil {
-			return fmt.Errorf("failed to parse Origin header %q: %w", origin, err)
+			return fmt.Errorf("failed to parse filepath pattern %q: %w", hostPattern, err)
 		}
-		if !strings.EqualFold(u.Host, r.Host) {
-			return fmt.Errorf("request Origin %q is not authorized for Host %q", origin, r.Host)
+		if matched {
+			return nil
 		}
 	}
-	return nil
+	return fmt.Errorf("request Origin %q is not authorized for Host %q", origin, r.Host)
+}
+
+func match(pattern, s string) (bool, error) {
+	return filepath.Match(strings.ToLower(pattern), strings.ToLower(s))
 }
 
 func selectSubprotocol(r *http.Request, subprotocols []string) string {
@@ -235,16 +266,6 @@ func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode Compressi
 	return copts, nil
 }
 
-// parseExtensionParameter parses the value in the extension parameter p.
-func parseExtensionParameter(p string) (int, bool) {
-	ps := strings.Split(p, "=")
-	if len(ps) == 1 {
-		return 0, false
-	}
-	i, e := strconv.Atoi(strings.Trim(ps[1], `"`))
-	return i, e == nil
-}
-
 func acceptWebkitDeflate(w http.ResponseWriter, ext websocketExtension, mode CompressionMode) (*compressionOptions, error) {
 	copts := mode.opts()
 	// The peer must explicitly request it.
diff --git a/accept_test.go b/accept_test.go
index 49667799..40a7b40c 100644
--- a/accept_test.go
+++ b/accept_test.go
@@ -244,10 +244,11 @@ func Test_authenticateOrigin(t *testing.T) {
 	t.Parallel()
 
 	testCases := []struct {
-		name    string
-		origin  string
-		host    string
-		success bool
+		name           string
+		origin         string
+		host           string
+		originPatterns []string
+		success        bool
 	}{
 		{
 			name:    "none",
@@ -278,6 +279,26 @@ func Test_authenticateOrigin(t *testing.T) {
 			host:    "example.com",
 			success: true,
 		},
+		{
+			name:   "originPatterns",
+			origin: "https://two.examplE.com",
+			host:   "example.com",
+			originPatterns: []string{
+				"*.example.com",
+				"bar.com",
+			},
+			success: true,
+		},
+		{
+			name:   "originPatternsUnauthorized",
+			origin: "https://two.examplE.com",
+			host:   "example.com",
+			originPatterns: []string{
+				"exam3.com",
+				"bar.com",
+			},
+			success: false,
+		},
 	}
 
 	for _, tc := range testCases {
@@ -288,7 +309,7 @@ func Test_authenticateOrigin(t *testing.T) {
 			r := httptest.NewRequest("GET", "http://"+tc.host+"/", nil)
 			r.Header.Set("Origin", tc.origin)
 
-			err := authenticateOrigin(r)
+			err := authenticateOrigin(r, tc.originPatterns)
 			if tc.success {
 				assert.Success(t, err)
 			} else {
diff --git a/example_test.go b/example_test.go
index 666914d2..c56e53f3 100644
--- a/example_test.go
+++ b/example_test.go
@@ -6,7 +6,6 @@ import (
 	"context"
 	"log"
 	"net/http"
-	"net/url"
 	"time"
 
 	"nhooyr.io/websocket"
@@ -121,17 +120,8 @@ func Example_writeOnly() {
 // from the origin example.com.
 func Example_crossOrigin() {
 	fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-		origin := r.Header.Get("Origin")
-		if origin != "" {
-			u, err := url.Parse(origin)
-			if err != nil || u.Host != "example.com" {
-				http.Error(w, "bad origin header", http.StatusForbidden)
-				return
-			}
-		}
-
 		c, err := websocket.Accept(w, r, &websocket.AcceptOptions{
-			InsecureSkipVerify: true,
+			OriginPatterns: []string{"example.com"},
 		})
 		if err != nil {
 			log.Println(err)