Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ on:
pull_request:
branches:
- master
workflow_dispatch:

name: ci

Expand Down
239 changes: 239 additions & 0 deletions providers/cognito/cognito.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,239 @@
package cognito

import (
"bytes"
"encoding/json"
"fmt"
"io"
"io/ioutil"
"net/http"

"github.com/markbates/goth"
"golang.org/x/oauth2"
)

// Provider is the implementation of `goth.Provider` for accessing AWS Cognito.
// New takes 3 parameters all from the Cognito console:
// - The client ID
// - The client secret
// - The base URL for your service, either a custom domain or cognito pool based URL
// You need to ensure that the source login URL is whitelisted as a login page in the client configuration in the cognito console.
// GOTH does not provide a full token logout, to do that you need to do it in your code.
// If you do not perform a full logout their existing token will be used on a login and the user won't be prompted to login until after expiry.
// To perform a logout
// - Destroy your session (or however else you handle the logout internally)
// - redirect to https://CUSTOM_DOMAIN.auth.us-east-1.amazoncognito.com/logout?client_id=clinet_id&logout_uri=http://localhost:8080/
// (or whatever your login/start page is).
// - Note that this page needs to be white-labeled as a logout page in the cognito console as well.

// This is based upon the implementation for okta

type Provider struct {
ClientKey string
Secret string
CallbackURL string
HTTPClient *http.Client
config *oauth2.Config
providerName string
issuerURL string
profileURL string
}

// New creates a new AWS Cognito provider and sets up important connection details.
// You should always call `cognito.New` to get a new provider. Never try to
// create one manually.
func New(clientID, secret, baseUrl, callbackURL string, scopes ...string) *Provider {
issuerURL := baseUrl + "/oauth2/default"
authURL := baseUrl + "/oauth2/authorize"
tokenURL := baseUrl + "/oauth2/token"
profileURL := baseUrl + "/oauth2/userInfo"
return NewCustomisedURL(clientID, secret, callbackURL, authURL, tokenURL, issuerURL, profileURL, scopes...)
}

// NewCustomisedURL is similar to New(...) but can be used to set custom URLs to connect to
func NewCustomisedURL(clientID, secret, callbackURL, authURL, tokenURL, issuerURL, profileURL string, scopes ...string) *Provider {
p := &Provider{
ClientKey: clientID,
Secret: secret,
CallbackURL: callbackURL,
providerName: "cognito",
issuerURL: issuerURL,
profileURL: profileURL,
}
p.config = newConfig(p, authURL, tokenURL, scopes)
return p
}

// Name is the name used to retrieve this provider later.
func (p *Provider) Name() string {
return p.providerName
}

// SetName is to update the name of the provider (needed in case of multiple providers of 1 type)
func (p *Provider) SetName(name string) {
p.providerName = name
}

func (p *Provider) Client() *http.Client {
return goth.HTTPClientWithFallBack(p.HTTPClient)
}

// Debug is a no-op for the aws package.
func (p *Provider) Debug(debug bool) {
if debug {
fmt.Println("WARNING: Debug request for goth/providers/cognito but no debug is available")
}
}

// BeginAuth asks AWS for an authentication end-point.
func (p *Provider) BeginAuth(state string) (goth.Session, error) {
return &Session{
AuthURL: p.config.AuthCodeURL(state),
}, nil
}

// FetchUser will go to aws and access basic information about the user.
func (p *Provider) FetchUser(session goth.Session) (goth.User, error) {
sess := session.(*Session)
user := goth.User{
AccessToken: sess.AccessToken,
Provider: p.Name(),
RefreshToken: sess.RefreshToken,
ExpiresAt: sess.ExpiresAt,
UserID: sess.UserID,
}

if user.AccessToken == "" {
// data is not yet retrieved since accessToken is still empty
return user, fmt.Errorf("%s cannot get user information without accessToken", p.providerName)
}

req, err := http.NewRequest("GET", p.profileURL, nil)
if err != nil {
return user, err
}
req.Header.Set("Authorization", "Bearer "+sess.AccessToken)
response, err := p.Client().Do(req)
if err != nil {
if response != nil {
_ = response.Body.Close()
}
return user, err
}
defer func(Body io.ReadCloser) {
_ = Body.Close()
}(response.Body)

if response.StatusCode != http.StatusOK {
return user, fmt.Errorf("%s responded with a %d trying to fetch user information", p.providerName, response.StatusCode)
}

bits, err := ioutil.ReadAll(response.Body)
if err != nil {
return user, err
}

err = json.NewDecoder(bytes.NewReader(bits)).Decode(&user.RawData)
if err != nil {
return user, err
}

err = userFromReader(bytes.NewReader(bits), &user)

return user, err
}

func newConfig(provider *Provider, authURL, tokenURL string, scopes []string) *oauth2.Config {
c := &oauth2.Config{
ClientID: provider.ClientKey,
ClientSecret: provider.Secret,
RedirectURL: provider.CallbackURL,
Endpoint: oauth2.Endpoint{
AuthURL: authURL,
TokenURL: tokenURL,
},
Scopes: []string{},
}

if len(scopes) > 0 {
for _, scope := range scopes {
c.Scopes = append(c.Scopes, scope)
}
}
return c
}

// userFromReader
// These are the standard cognito attributes
// from: https://docs.aws.amazon.com/cognito/latest/developerguide/user-pool-settings-attributes.html
// all attributes are optional
// it is possible for there to be custom attributes in cognito, but they don't seem to be passed as in the claims
// all the standard claims are mapped into the raw data
func userFromReader(r io.Reader, user *goth.User) error {
u := struct {
ID string `json:"sub"`
Address string `json:"address"`
Birthdate string `json:"birthdate"`
Email string `json:"email"`
EmailVerified string `json:"email_verified"`
FirstName string `json:"given_name"`
LastName string `json:"family_name"`
MiddleName string `json:"middle_name"`
Name string `json:"name"`
NickName string `json:"nickname"`
Locale string `json:"locale"`
PhoneNumber string `json:"phone_number"`
PictureURL string `json:"picture"`
ProfileURL string `json:"profile"`
Username string `json:"preferred_username"`
UpdatedAt string `json:"updated_at"`
WebSite string `json:"website"`
Zoneinfo string `json:"zoneinfo"`
}{}

err := json.NewDecoder(r).Decode(&u)
if err != nil {
return err
}

// Ensure all standard claims are in the raw data
rd := make(map[string]interface{})
rd["Address"] = u.Address
rd["Birthdate"] = u.Birthdate
rd["Locale"] = u.Locale
rd["MiddleName"] = u.MiddleName
rd["PhoneNumber"] = u.PhoneNumber
rd["PictureURL"] = u.PictureURL
rd["ProfileURL"] = u.ProfileURL
rd["UpdatedAt"] = u.UpdatedAt
rd["Username"] = u.Username
rd["WebSite"] = u.WebSite
rd["EmailVerified"] = u.EmailVerified

user.UserID = u.ID
user.Email = u.Email
user.Name = u.Name
user.NickName = u.NickName
user.FirstName = u.FirstName
user.LastName = u.LastName
user.AvatarURL = u.PictureURL
user.RawData = rd

return nil
}

// RefreshTokenAvailable refresh token is provided by auth provider or not
func (p *Provider) RefreshTokenAvailable() bool {
return true
}

// RefreshToken get new access token based on the refresh token
func (p *Provider) RefreshToken(refreshToken string) (*oauth2.Token, error) {
token := &oauth2.Token{RefreshToken: refreshToken}
ts := p.config.TokenSource(goth.ContextForClient(p.Client()), token)
newToken, err := ts.Token()
if err != nil {
return nil, err
}
return newToken, err
}
67 changes: 67 additions & 0 deletions providers/cognito/cognito_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package cognito

import (
"os"
"testing"

"github.com/markbates/goth"
"github.com/markbates/goth/providers/okta"
"github.com/stretchr/testify/assert"
)

func Test_New(t *testing.T) {
t.Parallel()
a := assert.New(t)
p := provider()

a.Equal(p.ClientKey, os.Getenv("COGNITO_ID"))
a.Equal(p.Secret, os.Getenv("COGNITO_SECRET"))
a.Equal(p.CallbackURL, "/foo")
}

func Test_NewCustomisedURL(t *testing.T) {
t.Parallel()
a := assert.New(t)
p := urlCustomisedURLProvider()
session, err := p.BeginAuth("test_state")
s := session.(*okta.Session)
a.NoError(err)
a.Contains(s.AuthURL, "http://authURL")
}

func Test_Implements_Provider(t *testing.T) {
t.Parallel()
a := assert.New(t)
a.Implements((*goth.Provider)(nil), provider())
}

func Test_BeginAuth(t *testing.T) {
t.Parallel()
a := assert.New(t)
p := provider()
session, err := p.BeginAuth("test_state")
s := session.(*okta.Session)
a.NoError(err)
a.Contains(s.AuthURL, os.Getenv("COGNITO_ISSUER_URL"))
}

func Test_SessionFromJSON(t *testing.T) {
t.Parallel()
a := assert.New(t)

p := provider()
session, err := p.UnmarshalSession(`{"AuthURL":"` + os.Getenv("COGNITO_ISSUER_URL") + `/oauth2/authorize", "AccessToken":"1234567890"}`)
a.NoError(err)

s := session.(*okta.Session)
a.Equal(s.AuthURL, os.Getenv("COGNITO_ISSUER_URL")+"/oauth2/authorize")
a.Equal(s.AccessToken, "1234567890")
}

func provider() *okta.Provider {
return okta.New(os.Getenv("COGNITO_ID"), os.Getenv("COGNITO_SECRET"), os.Getenv("COGNITO_ISSUER_URL"), "/foo")
}

func urlCustomisedURLProvider() *okta.Provider {
return okta.NewCustomisedURL(os.Getenv("CLIENT_ID"), os.Getenv("CLIENT_SECRET"), "/foo", "http://authURL", "http://tokenURL", "http://issuerURL", "http://profileURL")
}
64 changes: 64 additions & 0 deletions providers/cognito/session.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package cognito

import (
"encoding/json"
"errors"
"strings"
"time"

"github.com/markbates/goth"
)

// Session stores data during the auth process with AWS Cognito.
type Session struct {
AuthURL string
AccessToken string
RefreshToken string
ExpiresAt time.Time
UserID string
}

var _ goth.Session = &Session{}

// GetAuthURL will return the URL set by calling the `BeginAuth` function on the AWS Cognito provider.
func (s Session) GetAuthURL() (string, error) {
if s.AuthURL == "" {
return "", errors.New(goth.NoAuthUrlErrorMessage)
}
return s.AuthURL, nil
}

// Authorize the session with cognito and return the access token to be stored for future use.
func (s *Session) Authorize(provider goth.Provider, params goth.Params) (string, error) {
p := provider.(*Provider)
token, err := p.config.Exchange(goth.ContextForClient(p.Client()), params.Get("code"))
if err != nil {
return "", err
}

if !token.Valid() {
return "", errors.New("invalid token received from provider")
}

s.AccessToken = token.AccessToken
s.RefreshToken = token.RefreshToken
s.ExpiresAt = token.Expiry
return token.AccessToken, err
}

// Marshal the session into a string
func (s Session) Marshal() string {
b, _ := json.Marshal(s)
return string(b)
}

func (s Session) String() string {
return s.Marshal()
}

// UnmarshalSession wil unmarshal a JSON string into a session.
func (p *Provider) UnmarshalSession(data string) (goth.Session, error) {
s := &Session{}
err := json.NewDecoder(strings.NewReader(data)).Decode(s)
return s, err
}
Loading