Skip to content

Commit 1e0ae83

Browse files
config: apply DialContextFunc to OAuth2 token-fetch transport (#911)
newOauth2TokenSource builds its own http.Transport to fetch tokens but doesn't set DialContext on it. Any DialContextFunc passed via WithDialContextFunc is applied to the main request transport but silently skipped for the token endpoint. Set DialContext on the token transport the same way it is set on the main one. When dialContextFunc is nil the behaviour is unchanged since http.Transport falls back to its default dialer. Added TestOAuth2DialContextFunc to verify that WithDialContextFunc blocks the token endpoint fetch, not only the final request. Signed-off-by: Yuri Tseretyan <yuriy.tseretyan@grafana.com>
1 parent b51d01b commit 1e0ae83

2 files changed

Lines changed: 35 additions & 0 deletions

File tree

config/http_config.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -987,6 +987,7 @@ func (rt *oauth2RoundTripper) newOauth2TokenSource(req *http.Request, clientCred
987987
IdleConnTimeout: 10 * time.Second,
988988
TLSHandshakeTimeout: 10 * time.Second,
989989
ExpectContinueTimeout: 1 * time.Second,
990+
DialContext: rt.opts.dialContextFunc,
990991
}, nil
991992
}
992993

config/http_config_test.go

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1599,6 +1599,40 @@ func TestOAuth2UserAgent(t *testing.T) {
15991599
require.Equalf(t, "Bearer 12345", authorization, "Expected authorization header to be 'Bearer 12345', got '%s'", authorization)
16001600
}
16011601

1602+
func TestOAuth2DialContextFunc(t *testing.T) {
1603+
tokenServerInvoked := false
1604+
tokenTS := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
1605+
tokenServerInvoked = true
1606+
res, _ := json.Marshal(oauth2TestServerResponse{
1607+
AccessToken: "12345",
1608+
TokenType: "Bearer",
1609+
})
1610+
w.Header().Add("Content-Type", "application/json")
1611+
_, _ = w.Write(res)
1612+
}))
1613+
defer tokenTS.Close()
1614+
1615+
config := DefaultHTTPClientConfig
1616+
config.OAuth2 = &OAuth2{
1617+
ClientID: "1",
1618+
ClientSecret: "2",
1619+
TokenURL: tokenTS.URL + "/token",
1620+
}
1621+
1622+
dialFn := func(_ context.Context, _, _ string) (net.Conn, error) {
1623+
return nil, errors.New(ExpectedError)
1624+
}
1625+
1626+
rt, err := NewRoundTripperFromConfig(config, "test_oauth2_dialctx", WithDialContextFunc(dialFn))
1627+
require.NoError(t, err)
1628+
1629+
client := http.Client{Transport: rt}
1630+
_, err = client.Get(tokenTS.URL)
1631+
require.Error(t, err)
1632+
require.Containsf(t, err.Error(), ExpectedError, "expected error from DialContextFunc, got: %v", err)
1633+
require.Falsef(t, tokenServerInvoked, "OAuth2 token endpoint must not be reached when DialContextFunc blocks")
1634+
}
1635+
16021636
func TestHost(t *testing.T) {
16031637
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
16041638
require.Equalf(t, "localhost.localdomain", r.Host, "Expected Host header in request to be 'localhost.localdomain', got '%s'", r.Host)

0 commit comments

Comments
 (0)