Skip to content

Commit ca2d7c1

Browse files
committed
feat: Support for tenantID in azuread provider
Signed-off-by: Pedro Parra Ortega <[email protected]> fix: revert non desired changes Signed-off-by: Pedro Parra Ortega <[email protected]>
1 parent 0c63ed9 commit ca2d7c1

File tree

3 files changed

+47
-22
lines changed

3 files changed

+47
-22
lines changed

gothic/gothic.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,14 @@ import (
1515
"encoding/base64"
1616
"errors"
1717
"fmt"
18+
"github.com/go-chi/chi/v5"
1819
"io"
1920
"io/ioutil"
2021
"net/http"
2122
"net/url"
2223
"os"
2324
"strings"
2425

25-
"github.com/go-chi/chi/v5"
2626
"github.com/gorilla/mux"
2727
"github.com/gorilla/sessions"
2828
"github.com/markbates/goth"

providers/azuread/azuread.go

Lines changed: 45 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -16,41 +16,50 @@ import (
1616
)
1717

1818
const (
19-
authURL string = "https://login.microsoftonline.com/common/oauth2/authorize"
20-
tokenURL string = "https://login.microsoftonline.com/common/oauth2/token"
19+
authURLTemplate string = "https://login.microsoftonline.com/%s/oauth2/authorize"
20+
tokenURLTemplate string = "https://login.microsoftonline.com/%s/oauth2/token"
2121
endpointProfile string = "https://graph.windows.net/me?api-version=1.6"
2222
graphAPIResource string = "https://graph.windows.net/"
23+
commonTenant string = "common"
2324
)
2425

2526
// New creates a new AzureAD provider, and sets up important connection details.
2627
// You should always call `AzureAD.New` to get a new Provider. Never try to create
2728
// one manually.
28-
func New(clientKey, secret, callbackURL string, resources []string, scopes ...string) *Provider {
29+
func New(clientKey, secret, callbackURL string, opts ProviderOpts) *Provider {
2930
p := &Provider{
3031
ClientKey: clientKey,
3132
Secret: secret,
3233
CallbackURL: callbackURL,
3334
providerName: "azuread",
3435
}
3536

36-
p.resources = make([]string, 0, 1+len(resources))
37+
p.resources = make([]string, 0, 1+len(opts.Resources))
3738
p.resources = append(p.resources, graphAPIResource)
38-
p.resources = append(p.resources, resources...)
39+
p.resources = append(p.resources, opts.Resources...)
3940

40-
p.config = newConfig(p, scopes)
41+
p.config = newConfig(p, opts)
4142
return p
4243
}
4344

4445
// Provider is the implementation of `goth.Provider` for accessing AzureAD.
45-
type Provider struct {
46-
ClientKey string
47-
Secret string
48-
CallbackURL string
49-
HTTPClient *http.Client
50-
config *oauth2.Config
51-
providerName string
52-
resources []string
53-
}
46+
type (
47+
Provider struct {
48+
ClientKey string
49+
Secret string
50+
CallbackURL string
51+
HTTPClient *http.Client
52+
config *oauth2.Config
53+
providerName string
54+
resources []string
55+
}
56+
57+
ProviderOpts struct {
58+
Resources []string
59+
Scopes []string
60+
TenantID string
61+
}
62+
)
5463

5564
// Name is the name used to retrieve this provider later.
5665
func (p *Provider) Name() string {
@@ -132,20 +141,20 @@ func (p *Provider) RefreshToken(refreshToken string) (*oauth2.Token, error) {
132141
return newToken, err
133142
}
134143

135-
func newConfig(provider *Provider, scopes []string) *oauth2.Config {
144+
func newConfig(provider *Provider, opts ProviderOpts) *oauth2.Config {
136145
c := &oauth2.Config{
137146
ClientID: provider.ClientKey,
138147
ClientSecret: provider.Secret,
139148
RedirectURL: provider.CallbackURL,
140149
Endpoint: oauth2.Endpoint{
141-
AuthURL: authURL,
142-
TokenURL: tokenURL,
150+
AuthURL: authURL(opts.TenantID),
151+
TokenURL: tokenURL(opts.TenantID),
143152
},
144153
Scopes: []string{},
145154
}
146155

147-
if len(scopes) > 0 {
148-
for _, scope := range scopes {
156+
if len(opts.Scopes) > 0 {
157+
for _, scope := range opts.Scopes {
149158
c.Scopes = append(c.Scopes, scope)
150159
}
151160
} else {
@@ -185,3 +194,19 @@ func userFromReader(r io.Reader, user *goth.User) error {
185194
func authorizationHeader(session *Session) (string, string) {
186195
return "Authorization", fmt.Sprintf("Bearer %s", session.AccessToken)
187196
}
197+
198+
func authURL(tenantID string) string {
199+
if tenantID != "" {
200+
return fmt.Sprintf(authURLTemplate, tenantID)
201+
} else {
202+
return fmt.Sprintf(authURLTemplate, commonTenant)
203+
}
204+
}
205+
206+
func tokenURL(tenantID string) string {
207+
if tenantID != "" {
208+
return fmt.Sprintf(tokenURLTemplate, tenantID)
209+
} else {
210+
return fmt.Sprintf(tokenURLTemplate, commonTenant)
211+
}
212+
}

providers/azuread/azuread_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,5 +51,5 @@ func Test_SessionFromJSON(t *testing.T) {
5151
}
5252

5353
func azureadProvider() *azuread.Provider {
54-
return azuread.New(os.Getenv("AZUREAD_KEY"), os.Getenv("AZUREAD_SECRET"), "/foo", nil)
54+
return azuread.New(os.Getenv("AZUREAD_KEY"), os.Getenv("AZUREAD_SECRET"), "/foo", azuread.ProviderOpts{})
5555
}

0 commit comments

Comments
 (0)