Skip to content

Refactor Find Sources and fix bug when view a user who belongs to an unactive auth source #27798

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Nov 3, 2023
2 changes: 1 addition & 1 deletion cmd/admin_auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func runListAuth(c *cli.Context) error {
return err
}

authSources, err := auth_model.Sources(ctx)
authSources, err := auth_model.FindSources(ctx, auth_model.FindSourcesOptions{})
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion models/activities/statistic.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ func GetStatistic(ctx context.Context) (stats Statistic) {
stats.Counter.Follow, _ = e.Count(new(user_model.Follow))
stats.Counter.Mirror, _ = e.Count(new(repo_model.Mirror))
stats.Counter.Release, _ = e.Count(new(repo_model.Release))
stats.Counter.AuthSource = auth.CountSources(ctx)
stats.Counter.AuthSource = auth.CountSources(ctx, auth.FindSourcesOptions{})
stats.Counter.Webhook, _ = e.Count(new(webhook.Webhook))
stats.Counter.Milestone, _ = e.Count(new(issues_model.Milestone))
stats.Counter.Label, _ = e.Count(new(issues_model.Label))
Expand Down
9 changes: 0 additions & 9 deletions models/auth/oauth2.go
Original file line number Diff line number Diff line change
Expand Up @@ -631,15 +631,6 @@ func (err ErrOAuthApplicationNotFound) Unwrap() error {
return util.ErrNotExist
}

// GetActiveOAuth2ProviderSources returns all actived LoginOAuth2 sources
func GetActiveOAuth2ProviderSources(ctx context.Context) ([]*Source, error) {
sources := make([]*Source, 0, 1)
if err := db.GetEngine(ctx).Where("is_active = ? and type = ?", true, OAuth2).Find(&sources); err != nil {
return nil, err
}
return sources, nil
}

// GetActiveOAuth2SourceByName returns a OAuth2 AuthSource based on the given name
func GetActiveOAuth2SourceByName(ctx context.Context, name string) (*Source, error) {
authSource := new(Source)
Expand Down
49 changes: 21 additions & 28 deletions models/auth/source.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"code.gitea.io/gitea/modules/timeutil"
"code.gitea.io/gitea/modules/util"

"xorm.io/builder"
"xorm.io/xorm"
"xorm.io/xorm/convert"
)
Expand Down Expand Up @@ -240,37 +241,26 @@ func CreateSource(ctx context.Context, source *Source) error {
return err
}

// Sources returns a slice of all login sources found in DB.
func Sources(ctx context.Context) ([]*Source, error) {
auths := make([]*Source, 0, 6)
return auths, db.GetEngine(ctx).Find(&auths)
type FindSourcesOptions struct {
IsActive util.OptionalBool
LoginType Type
}

// SourcesByType returns all sources of the specified type
func SourcesByType(ctx context.Context, loginType Type) ([]*Source, error) {
sources := make([]*Source, 0, 1)
if err := db.GetEngine(ctx).Where("type = ?", loginType).Find(&sources); err != nil {
return nil, err
func (opts FindSourcesOptions) ToConds() builder.Cond {
conds := builder.NewCond()
if !opts.IsActive.IsNone() {
conds = conds.And(builder.Eq{"is_active": opts.IsActive.IsTrue()})
}
return sources, nil
}

// AllActiveSources returns all active sources
func AllActiveSources(ctx context.Context) ([]*Source, error) {
sources := make([]*Source, 0, 5)
if err := db.GetEngine(ctx).Where("is_active = ?", true).Find(&sources); err != nil {
return nil, err
if opts.LoginType != NoType {
conds = conds.And(builder.Eq{"`type`": opts.LoginType})
}
return sources, nil
return conds
}

// ActiveSources returns all active sources of the specified type
func ActiveSources(ctx context.Context, tp Type) ([]*Source, error) {
sources := make([]*Source, 0, 1)
if err := db.GetEngine(ctx).Where("is_active = ? and type = ?", true, tp).Find(&sources); err != nil {
return nil, err
}
return sources, nil
// FindSources returns a slice of login sources found in DB according to given conditions.
func FindSources(ctx context.Context, opts FindSourcesOptions) ([]*Source, error) {
auths := make([]*Source, 0, 6)
return auths, db.GetEngine(ctx).Where(opts.ToConds()).Find(&auths)
}

// IsSSPIEnabled returns true if there is at least one activated login
Expand All @@ -279,7 +269,10 @@ func IsSSPIEnabled(ctx context.Context) bool {
if !db.HasEngine {
return false
}
sources, err := ActiveSources(ctx, SSPI)
sources, err := FindSources(ctx, FindSourcesOptions{
IsActive: util.OptionalBoolTrue,
LoginType: SSPI,
})
if err != nil {
log.Error("ActiveSources: %v", err)
return false
Expand Down Expand Up @@ -354,8 +347,8 @@ func UpdateSource(ctx context.Context, source *Source) error {
}

// CountSources returns number of login sources.
func CountSources(ctx context.Context) int64 {
count, _ := db.GetEngine(ctx).Count(new(Source))
func CountSources(ctx context.Context, opts FindSourcesOptions) int64 {
count, _ := db.GetEngine(ctx).Where(opts.ToConds()).Count(new(Source))
return count
}

Expand Down
14 changes: 7 additions & 7 deletions routers/web/admin/auths.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,13 @@ func Authentications(ctx *context.Context) {
ctx.Data["PageIsAdminAuthentications"] = true

var err error
ctx.Data["Sources"], err = auth.Sources(ctx)
ctx.Data["Sources"], err = auth.FindSources(ctx, auth.FindSourcesOptions{})
if err != nil {
ctx.ServerError("auth.Sources", err)
return
}

ctx.Data["Total"] = auth.CountSources(ctx)
ctx.Data["Total"] = auth.CountSources(ctx, auth.FindSourcesOptions{})
ctx.HTML(http.StatusOK, tplAuths)
}

Expand Down Expand Up @@ -99,7 +99,7 @@ func NewAuthSource(ctx *context.Context) {
ctx.Data["AuthSources"] = authSources
ctx.Data["SecurityProtocols"] = securityProtocols
ctx.Data["SMTPAuths"] = smtp.Authenticators
oauth2providers := oauth2.GetOAuth2Providers()
oauth2providers := oauth2.GetSupportedOAuth2Providers()
ctx.Data["OAuth2Providers"] = oauth2providers

ctx.Data["SSPIAutoCreateUsers"] = true
Expand Down Expand Up @@ -242,7 +242,7 @@ func NewAuthSourcePost(ctx *context.Context) {
ctx.Data["AuthSources"] = authSources
ctx.Data["SecurityProtocols"] = securityProtocols
ctx.Data["SMTPAuths"] = smtp.Authenticators
oauth2providers := oauth2.GetOAuth2Providers()
oauth2providers := oauth2.GetSupportedOAuth2Providers()
ctx.Data["OAuth2Providers"] = oauth2providers

ctx.Data["SSPIAutoCreateUsers"] = true
Expand Down Expand Up @@ -284,7 +284,7 @@ func NewAuthSourcePost(ctx *context.Context) {
ctx.RenderWithErr(err.Error(), tplAuthNew, form)
return
}
existing, err := auth.SourcesByType(ctx, auth.SSPI)
existing, err := auth.FindSources(ctx, auth.FindSourcesOptions{LoginType: auth.SSPI})
if err != nil || len(existing) > 0 {
ctx.Data["Err_Type"] = true
ctx.RenderWithErr(ctx.Tr("admin.auths.login_source_of_type_exist"), tplAuthNew, form)
Expand Down Expand Up @@ -334,7 +334,7 @@ func EditAuthSource(ctx *context.Context) {

ctx.Data["SecurityProtocols"] = securityProtocols
ctx.Data["SMTPAuths"] = smtp.Authenticators
oauth2providers := oauth2.GetOAuth2Providers()
oauth2providers := oauth2.GetSupportedOAuth2Providers()
ctx.Data["OAuth2Providers"] = oauth2providers

source, err := auth.GetSourceByID(ctx, ctx.ParamsInt64(":authid"))
Expand Down Expand Up @@ -368,7 +368,7 @@ func EditAuthSourcePost(ctx *context.Context) {
ctx.Data["PageIsAdminAuthentications"] = true

ctx.Data["SMTPAuths"] = smtp.Authenticators
oauth2providers := oauth2.GetOAuth2Providers()
oauth2providers := oauth2.GetSupportedOAuth2Providers()
ctx.Data["OAuth2Providers"] = oauth2providers

source, err := auth.GetSourceByID(ctx, ctx.ParamsInt64(":authid"))
Expand Down
10 changes: 7 additions & 3 deletions routers/web/admin/users.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,9 @@ func NewUser(ctx *context.Context) {

ctx.Data["login_type"] = "0-0"

sources, err := auth.Sources(ctx)
sources, err := auth.FindSources(ctx, auth.FindSourcesOptions{
IsActive: util.OptionalBoolTrue,
})
if err != nil {
ctx.ServerError("auth.Sources", err)
return
Expand All @@ -109,7 +111,9 @@ func NewUserPost(ctx *context.Context) {
ctx.Data["DefaultUserVisibilityMode"] = setting.Service.DefaultUserVisibilityMode
ctx.Data["AllowedUserVisibilityModes"] = setting.Service.AllowedUserVisibilityModesSlice.ToVisibleTypeSlice()

sources, err := auth.Sources(ctx)
sources, err := auth.FindSources(ctx, auth.FindSourcesOptions{
IsActive: util.OptionalBoolTrue,
})
if err != nil {
ctx.ServerError("auth.Sources", err)
return
Expand Down Expand Up @@ -230,7 +234,7 @@ func prepareUserInfo(ctx *context.Context) *user_model.User {
ctx.Data["LoginSource"] = &auth.Source{}
}

sources, err := auth.Sources(ctx)
sources, err := auth.FindSources(ctx, auth.FindSourcesOptions{})
if err != nil {
ctx.ServerError("auth.Sources", err)
return nil
Expand Down
12 changes: 4 additions & 8 deletions routers/web/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,12 +160,11 @@ func SignIn(ctx *context.Context) {
return
}

orderedOAuth2Names, oauth2Providers, err := oauth2.GetActiveOAuth2Providers(ctx)
oauth2Providers, err := oauth2.GetOAuth2Providers(ctx, util.OptionalBoolTrue)
if err != nil {
ctx.ServerError("UserSignIn", err)
return
}
ctx.Data["OrderedOAuth2Names"] = orderedOAuth2Names
ctx.Data["OAuth2Providers"] = oauth2Providers
ctx.Data["Title"] = ctx.Tr("sign_in")
ctx.Data["SignInLink"] = setting.AppSubURL + "/user/login"
Expand All @@ -184,12 +183,11 @@ func SignIn(ctx *context.Context) {
func SignInPost(ctx *context.Context) {
ctx.Data["Title"] = ctx.Tr("sign_in")

orderedOAuth2Names, oauth2Providers, err := oauth2.GetActiveOAuth2Providers(ctx)
oauth2Providers, err := oauth2.GetOAuth2Providers(ctx, util.OptionalBoolTrue)
if err != nil {
ctx.ServerError("UserSignIn", err)
return
}
ctx.Data["OrderedOAuth2Names"] = orderedOAuth2Names
ctx.Data["OAuth2Providers"] = oauth2Providers
ctx.Data["Title"] = ctx.Tr("sign_in")
ctx.Data["SignInLink"] = setting.AppSubURL + "/user/login"
Expand Down Expand Up @@ -408,13 +406,12 @@ func SignUp(ctx *context.Context) {

ctx.Data["SignUpLink"] = setting.AppSubURL + "/user/sign_up"

orderedOAuth2Names, oauth2Providers, err := oauth2.GetActiveOAuth2Providers(ctx)
oauth2Providers, err := oauth2.GetOAuth2Providers(ctx, util.OptionalBoolTrue)
if err != nil {
ctx.ServerError("UserSignUp", err)
return
}

ctx.Data["OrderedOAuth2Names"] = orderedOAuth2Names
ctx.Data["OAuth2Providers"] = oauth2Providers
context.SetCaptchaData(ctx)

Expand All @@ -438,13 +435,12 @@ func SignUpPost(ctx *context.Context) {

ctx.Data["SignUpLink"] = setting.AppSubURL + "/user/sign_up"

orderedOAuth2Names, oauth2Providers, err := oauth2.GetActiveOAuth2Providers(ctx)
oauth2Providers, err := oauth2.GetOAuth2Providers(ctx, util.OptionalBoolTrue)
if err != nil {
ctx.ServerError("UserSignUp", err)
return
}

ctx.Data["OrderedOAuth2Names"] = orderedOAuth2Names
ctx.Data["OAuth2Providers"] = oauth2Providers
context.SetCaptchaData(ctx)

Expand Down
26 changes: 24 additions & 2 deletions routers/web/user/setting/security/security.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@ package security

import (
"net/http"
"sort"

auth_model "code.gitea.io/gitea/models/auth"
user_model "code.gitea.io/gitea/models/user"
"code.gitea.io/gitea/modules/base"
"code.gitea.io/gitea/modules/context"
"code.gitea.io/gitea/modules/setting"
"code.gitea.io/gitea/modules/util"
"code.gitea.io/gitea/services/auth/source/oauth2"
)

Expand Down Expand Up @@ -105,11 +107,31 @@ func loadSecurityData(ctx *context.Context) {
}
ctx.Data["AccountLinks"] = sources

orderedOAuth2Names, oauth2Providers, err := oauth2.GetActiveOAuth2Providers(ctx)
authSources, err := auth_model.FindSources(ctx, auth_model.FindSourcesOptions{
IsActive: util.OptionalBoolNone,
LoginType: auth_model.OAuth2,
})
if err != nil {
ctx.ServerError("GetActiveOAuth2Providers", err)
ctx.ServerError("FindSources", err)
return
}

var orderedOAuth2Names []string
oauth2Providers := make(map[string]oauth2.Provider)
for _, source := range authSources {
provider, err := oauth2.CreateProviderFromSource(source)
if err != nil {
ctx.ServerError("CreateProviderFromSource", err)
return
}
oauth2Providers[source.Name] = provider
if source.IsActive {
orderedOAuth2Names = append(orderedOAuth2Names, source.Name)
}
}

sort.Strings(orderedOAuth2Names)

ctx.Data["OrderedOAuth2Names"] = orderedOAuth2Names
ctx.Data["OAuth2Providers"] = oauth2Providers

Expand Down
5 changes: 4 additions & 1 deletion services/auth/signin.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"code.gitea.io/gitea/models/db"
user_model "code.gitea.io/gitea/models/user"
"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/util"
"code.gitea.io/gitea/services/auth/source/oauth2"
"code.gitea.io/gitea/services/auth/source/smtp"

Expand Down Expand Up @@ -85,7 +86,9 @@ func UserSignIn(ctx context.Context, username, password string) (*user_model.Use
}
}

sources, err := auth.AllActiveSources(ctx)
sources, err := auth.FindSources(ctx, auth.FindSourcesOptions{
IsActive: util.OptionalBoolTrue,
})
if err != nil {
return nil, nil, err
}
Expand Down
9 changes: 8 additions & 1 deletion services/auth/source/oauth2/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"code.gitea.io/gitea/models/auth"
"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/setting"
"code.gitea.io/gitea/modules/util"

"github.com/google/uuid"
"github.com/gorilla/sessions"
Expand Down Expand Up @@ -63,7 +64,13 @@ func ResetOAuth2(ctx context.Context) error {

// initOAuth2Sources is used to load and register all active OAuth2 providers
func initOAuth2Sources(ctx context.Context) error {
authSources, _ := auth.GetActiveOAuth2ProviderSources(ctx)
authSources, err := auth.FindSources(ctx, auth.FindSourcesOptions{
IsActive: util.OptionalBoolTrue,
LoginType: auth.OAuth2,
})
if err != nil {
return err
}
for _, source := range authSources {
oauth2Source, ok := source.Cfg.(*Source)
if !ok {
Expand Down
Loading