Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 136 additions & 0 deletions providers/reddit/reddit.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
package reddit

import (
"encoding/json"
"fmt"
"github.com/markbates/goth"
"golang.org/x/oauth2"
"io"
"net/http"
"time"
)

const (
authURL = "https://www.reddit.com/api/v1/authorize"
)

type Provider struct {
providerName string
duration string
config oauth2.Config
client http.Client
// TODO: userURL should be a constant
userURL string
}

func New(clientID string, clientSecret string, redirectURI string, duration string, tokenEndpoint string, userURL string, scopes ...string) Provider {
return Provider{
providerName: "reddit",
duration: duration,
config: oauth2.Config{
ClientID: clientID,
ClientSecret: clientSecret,
Endpoint: oauth2.Endpoint{
AuthURL: authURL,
TokenURL: tokenEndpoint,
AuthStyle: 0,
},
RedirectURL: redirectURI,
Scopes: scopes,
},
client: http.Client{},
userURL: userURL,
}
}

func (p *Provider) Name() string {
return p.providerName
}

func (p *Provider) SetName(name string) {
p.providerName = name
}

func (p *Provider) UnmarshalSession(s string) (goth.Session, error) {
session := &Session{}
err := json.Unmarshal([]byte(s), session)
if err != nil {
return nil, err
}

return session, nil
}

func (p *Provider) Debug(b bool) {}

func (p *Provider) RefreshToken(refreshToken string) (*oauth2.Token, error) {
return nil, nil
}

func (p *Provider) RefreshTokenAvailable() bool {
return true
}

func (p *Provider) BeginAuth(state string) (goth.Session, error) {
authCodeOption := oauth2.SetAuthURLParam("duration", p.duration)
return &Session{AuthURL: p.config.AuthCodeURL(state, authCodeOption)}, nil
}

type redditResponse struct {
Id string `json:"id"`
Name string `json:"name"`
}

func (p *Provider) FetchUser(s goth.Session) (goth.User, error) {
session := s.(*Session)
request, err := http.NewRequest("GET", p.userURL, nil)
if err != nil {
return goth.User{}, err
}

bearer := "Bearer " + session.AccessToken
request.Header.Add("Authorization", bearer)

res, err := p.client.Do(request)
if err != nil {
return goth.User{}, err
}

defer res.Body.Close()

if res.StatusCode != http.StatusOK {
if res.StatusCode == http.StatusForbidden {
return goth.User{}, fmt.Errorf("%s responded with a %s because you did not provide the identity scope which is required to fetch user profile", p.providerName, res.Status)
}
return goth.User{}, fmt.Errorf("%s responded with a %d trying to fetch user profile", p.providerName, res.StatusCode)
}

bits, err := io.ReadAll(res.Body)
if err != nil {
return goth.User{}, err
}

var r redditResponse

err = json.Unmarshal(bits, &r)
if err != nil {
return goth.User{}, err
}

gothUser := goth.User{
RawData: nil,
Provider: p.Name(),
Name: r.Name,
UserID: r.Id,
AccessToken: session.AccessToken,
RefreshToken: session.RefreshToken,
ExpiresAt: time.Time{},
}

err = json.Unmarshal(bits, &gothUser.RawData)
if err != nil {
return goth.User{}, err
}

return gothUser, nil
}
88 changes: 88 additions & 0 deletions providers/reddit/reddit_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
package reddit

import (
"encoding/json"
"github.com/markbates/goth"
"golang.org/x/oauth2"
"net/http"
"net/http/httptest"
"reflect"
"testing"
"time"
)

var response = redditResponse{
Id: "invader21",
Name: "JohnDoe",
}

func TestProvider(t *testing.T) {
t.Run("create a new provider", func(t *testing.T) {
got := New("client id", "client secret", "redirect uri", "duration", "example.com", "userURL", "scope1", "scope2", "scope 3")
want := Provider{
providerName: "reddit",
duration: "duration",
config: oauth2.Config{
ClientID: "client id",
ClientSecret: "client secret",
Endpoint: oauth2.Endpoint{
AuthURL: authURL,
TokenURL: "example.com",
AuthStyle: 0,
},
RedirectURL: "redirect uri",
Scopes: []string{"scope1", "scope2", "scope 3"},
},
userURL: "userURL",
}

if !reflect.DeepEqual(got, want) {
t.Errorf("\033[31;1;4mgot\033[0m %+v, \n\t \033[31;1;4mwant\033[0m %+v", got, want)
}
})

t.Run("fetch reddit user that created the given session", func(t *testing.T) {
redditServer := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
b, err := json.Marshal(response)
if err != nil {
t.Fatal(err)
}
writer.Header().Add("Content-Type", "application/json")
writer.Write(b)
}))

defer redditServer.Close()

userURL := redditServer.URL
p := New("client id", "client secret", "redirect uri", "duration", "example.com", userURL, "scope1", "scope2", "scope 3")
s := &Session{
AuthURL: "",
AccessToken: "i am a token",
TokenType: "bearer",
RefreshToken: "your refresh token",
Expiry: time.Time{},
}

got, err := p.FetchUser(s)
if err != nil {
t.Errorf("did not expect an error: %s", err)
}

want := goth.User{
RawData: map[string]interface{}{
"id": "invader21",
"name": "JohnDoe",
},
Provider: "reddit",
Name: "JohnDoe",
UserID: "invader21",
AccessToken: "i am a token",
RefreshToken: "your refresh token",
ExpiresAt: time.Time{},
}

if !reflect.DeepEqual(got, want) {
t.Errorf("\033[31;1;4mgot\033[0m %+v, \n\t\t \033[31;1;4mwant\033[0m %+v", got, want)
}
})
}
46 changes: 46 additions & 0 deletions providers/reddit/session.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package reddit

import (
"context"
"encoding/json"
"errors"
"github.com/markbates/goth"
"golang.org/x/oauth2"
"time"
)

type Session struct {
AuthURL string
AccessToken string `json:"access_token"`
TokenType string `json:"token_type,omitempty"`
RefreshToken string `json:"refresh_token,omitempty"`
Expiry time.Time `json:"expiry,omitempty"`
}

func (s *Session) GetAuthURL() (string, error) {
return s.AuthURL, nil
}

func (s *Session) Marshal() string {
b, _ := json.Marshal(s)
return string(b)
}

func (s *Session) Authorize(provider goth.Provider, params goth.Params) (string, error) {
p := provider.(*Provider)
t, err := p.config.Exchange(context.WithValue(context.Background(), oauth2.HTTPClient, p.client), params.Get("code"))
if err != nil {
return "", err
}

if !t.Valid() {
return "", errors.New("invalid token received from provider")
}

s.AccessToken = t.AccessToken
s.TokenType = t.TokenType
s.RefreshToken = t.RefreshToken
s.Expiry = t.Expiry

return s.AccessToken, nil
}
122 changes: 122 additions & 0 deletions providers/reddit/session_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
package reddit

import (
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
"testing"
)

var validAuthResponseTestData = struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
Scope string `json:"scope"`
RefreshToken string `json:"refresh_token"`
}{
AccessToken: "i am a token",
TokenType: "type",
ExpiresIn: 120,
Scope: "identity",
RefreshToken: "your refresh token",
}

var invalidAuthResponseTestData = struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
Scope string `json:"scope"`
RefreshToken string `json:"refresh_token"`
}{
AccessToken: "",
TokenType: "type",
ExpiresIn: 120,
Scope: "identity",
RefreshToken: "Your refresh token",
}

func TestSession(t *testing.T) {
t.Run("gets the URL for the authentication end-point for the provider", func(t *testing.T) {
s := Session{AuthURL: "example.com"}
got, err := s.GetAuthURL()
if err != nil {
t.Fatal("should return a url string")
}

want := "example.com"

if got != want {
t.Errorf("got %q want %q", got, want)
}
})

t.Run("generates a string representation of the session", func(t *testing.T) {
s := Session{
AuthURL: "example",
}
got := s.Marshal()
want := `{"AuthURL":"example","access_token":"","expiry":"0001-01-01T00:00:00Z"}`

if got != want {
t.Errorf("got %q want %q", got, want)
}
})

t.Run("return an access token", func(t *testing.T) {

s := Session{AuthURL: "example.com"}
authServer := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
b, err := json.Marshal(validAuthResponseTestData)
if err != nil {
writer.WriteHeader(http.StatusInternalServerError)
return
}
writer.Header().Add("Content-Type", "application/json")
writer.WriteHeader(http.StatusOK)
writer.Write(b)
}))

tokenURL := authServer.URL

p := New("CLIENT_ID", "CLIENT_SECRET", "URI", "DURATION", tokenURL, "SCOPE_STRING1", "SCOPE_STRING2")
u := url.Values{}
u.Set("code", "12345678")

got, err := s.Authorize(&p, u)
if err != nil {
t.Fatal("did not expect an error: ", err)
}

want := validAuthResponseTestData.AccessToken

if got != want {
t.Errorf("got %q want %q", got, want)
}
})

t.Run("validates access token", func(t *testing.T) {
s := Session{AuthURL: "example.com"}
authServer := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
b, err := json.Marshal(invalidAuthResponseTestData)
if err != nil {
writer.WriteHeader(http.StatusInternalServerError)
return
}
writer.Header().Add("Content-Type", "application/json")
writer.WriteHeader(http.StatusOK)
writer.Write(b)
}))

tokenURL := authServer.URL

p := New("CLIENT_ID", "CLIENT_SECRET", "URI", "DURATION", tokenURL, "SCOPE_STRING1", "SCOPE_STRING2")
u := url.Values{}
u.Set("code", "12345678")

_, err := s.Authorize(&p, u)
if err == nil {
t.Errorf("expected an error but didn't get one")
}
})
}
Loading