diff --git a/components/gitpod-db/go/personal_access_token.go b/components/gitpod-db/go/personal_access_token.go index 3a41de5baae0df..b93937ca5f1a6a 100644 --- a/components/gitpod-db/go/personal_access_token.go +++ b/components/gitpod-db/go/personal_access_token.go @@ -38,14 +38,25 @@ func (d *PersonalAccessToken) TableName() string { func GetPersonalAccessTokenForUser(ctx context.Context, conn *gorm.DB, tokenID uuid.UUID, userID uuid.UUID) (PersonalAccessToken, error) { var token PersonalAccessToken - db := conn.WithContext(ctx) + if tokenID == uuid.Nil { + return PersonalAccessToken{}, fmt.Errorf("Token ID is a required argument to get personal access token for user") + } - db = db.Where("id = ?", tokenID).Where("userId = ?", userID).Where("deleted = ?", 0).First(&token) - if db.Error != nil { - if errors.Is(db.Error, gorm.ErrRecordNotFound) { + if userID == uuid.Nil { + return PersonalAccessToken{}, fmt.Errorf("User ID is a required argument to get personal access token for user") + } + + tx := conn. + WithContext(ctx). + Where("id = ?", tokenID). + Where("userId = ?", userID). + Where("deleted = ?", 0). + First(&token) + if tx.Error != nil { + if errors.Is(tx.Error, gorm.ErrRecordNotFound) { return PersonalAccessToken{}, fmt.Errorf("Token with ID %s does not exist: %w", tokenID, ErrorNotFound) } - return PersonalAccessToken{}, fmt.Errorf("Failed to retrieve token: %v", db.Error) + return PersonalAccessToken{}, fmt.Errorf("Failed to retrieve token: %v", tx.Error) } return token, nil diff --git a/components/gitpod-db/go/personal_access_token_test.go b/components/gitpod-db/go/personal_access_token_test.go index 803e31235e227a..f7aed73b411608 100644 --- a/components/gitpod-db/go/personal_access_token_test.go +++ b/components/gitpod-db/go/personal_access_token_test.go @@ -29,6 +29,16 @@ func TestPersonalAccessToken_Get(t *testing.T) { dbtest.CreatePersonalAccessTokenRecords(t, conn, tokenEntries...) + t.Run("nil token ID is rejected", func(t *testing.T) { + _, err := db.GetPersonalAccessTokenForUser(context.Background(), conn, uuid.Nil, token.UserID) + require.Error(t, err) + }) + + t.Run("nil user ID is rejected", func(t *testing.T) { + _, err := db.GetPersonalAccessTokenForUser(context.Background(), conn, token.ID, uuid.Nil) + require.Error(t, err) + }) + t.Run("not matching user", func(t *testing.T) { _, err := db.GetPersonalAccessTokenForUser(context.Background(), conn, token.ID, token2.UserID) require.Error(t, err, db.ErrorNotFound)