Skip to content

Commit a2b6447

Browse files
committed
Improve HTTPS detection
1 parent 950fff3 commit a2b6447

2 files changed

Lines changed: 55 additions & 1 deletion

File tree

common/utils.go

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,45 @@ func TruncateValidationErrorMessage(msg string) string {
1919

2020
func GetFullURL(req *http.Request) string {
2121
scheme := "http"
22-
if req.TLS != nil || req.Header.Get("X-Forwarded-Proto") == "https" {
22+
if req.TLS != nil || isHTTPS(req.Header) {
2323
scheme = "https"
2424
}
2525
return fmt.Sprintf("%s://%s%s", scheme, req.Host, req.URL.String())
2626
}
2727

28+
func isHTTPS(header http.Header) bool {
29+
for _, key := range []string{
30+
"X-Forwarded-Proto",
31+
"X-Forwarded-Protocol",
32+
"X-Forwarded-Scheme",
33+
"X-Url-Scheme",
34+
"X-Scheme",
35+
} {
36+
if v := header.Get(key); v != "" {
37+
scheme, _, _ := strings.Cut(v, ",")
38+
if strings.TrimSpace(strings.ToLower(scheme)) == "https" {
39+
return true
40+
}
41+
}
42+
}
43+
if v := header.Get("Forwarded"); v != "" {
44+
for _, element := range strings.Split(v, ",") {
45+
for _, param := range strings.Split(element, ";") {
46+
param = strings.TrimSpace(param)
47+
if k, val, ok := strings.Cut(param, "="); ok && strings.ToLower(strings.TrimSpace(k)) == "proto" {
48+
if strings.ToLower(strings.Trim(strings.TrimSpace(val), "\"")) == "https" {
49+
return true
50+
}
51+
}
52+
}
53+
}
54+
}
55+
if strings.EqualFold(header.Get("Front-End-Https"), "on") || strings.EqualFold(header.Get("X-Forwarded-Ssl"), "on") {
56+
return true
57+
}
58+
return false
59+
}
60+
2861
func ParseContentLength(contentLength string) int64 {
2962
if contentLength != "" {
3063
if size, err := strconv.ParseInt(contentLength, 10, 64); err == nil {

common/utils_test.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,27 @@ func TestUtils(t *testing.T) {
1818
assert.Equal(t, "https://example.com/test?q=1", GetFullURL(req))
1919
})
2020

21+
t.Run("isHTTPS", func(t *testing.T) {
22+
check := func(header, value string) bool {
23+
h := http.Header{}
24+
h.Set(header, value)
25+
return isHTTPS(h)
26+
}
27+
assert.True(t, check("X-Forwarded-Proto", "https"))
28+
assert.True(t, check("X-Forwarded-Proto", "https, http"))
29+
assert.False(t, check("X-Forwarded-Proto", "http"))
30+
assert.True(t, check("X-Forwarded-Protocol", "https"))
31+
assert.True(t, check("X-Forwarded-Scheme", "https"))
32+
assert.True(t, check("X-Url-Scheme", "https"))
33+
assert.True(t, check("X-Scheme", "https"))
34+
assert.True(t, check("Forwarded", "for=192.0.2.1;proto=https;host=example.com"))
35+
assert.True(t, check("Forwarded", "proto=\"https\""))
36+
assert.False(t, check("Forwarded", "for=192.0.2.1;proto=http"))
37+
assert.True(t, check("Front-End-Https", "on"))
38+
assert.True(t, check("X-Forwarded-Ssl", "on"))
39+
assert.False(t, isHTTPS(http.Header{}))
40+
})
41+
2142
t.Run("ParseContentLength", func(t *testing.T) {
2243
assert.Equal(t, int64(-1), ParseContentLength(""))
2344
assert.Equal(t, int64(-1), ParseContentLength("invalid"))

0 commit comments

Comments
 (0)