Skip to content

Adding support for PostgreSQL as database #260

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 15 additions & 9 deletions cmd/gonic/gonic.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package main

import (
"context"
"encoding/base64"
"errors"
"expvar"
"flag"
Expand All @@ -26,6 +27,7 @@ import (

"github.com/google/shlex"
"github.com/gorilla/securecookie"
_ "github.com/jinzhu/gorm/dialects/postgres"
_ "github.com/jinzhu/gorm/dialects/sqlite"
"github.com/sentriz/gormstore"
"golang.org/x/sync/errgroup"
Expand Down Expand Up @@ -65,8 +67,7 @@ func main() {
flag.Var(&confMusicPaths, "music-path", "path to music")

confPlaylistsPath := flag.String("playlists-path", "", "path to your list of new or existing m3u playlists that gonic can manage")

confDBPath := flag.String("db-path", "gonic.db", "path to database (optional)")
confDBURI := flag.String("db-uri", "", "db URI")

confScanIntervalMins := flag.Uint("scan-interval", 0, "interval (in minutes) to automatically scan music (optional)")
confScanAtStart := flag.Bool("scan-at-start-enabled", false, "whether to perform an initial scan at startup (optional)")
Expand All @@ -92,6 +93,7 @@ func main() {
confExpvar := flag.Bool("expvar", false, "enable the /debug/vars endpoint (optional)")

deprecatedConfGenreSplit := flag.String("genre-split", "", "(deprecated, see multi-value settings)")
deprecatedConfDBPath := flag.String("db-path", "gonic.db", "(deprecated, see db-uri)")

flag.Parse()
flagconf.ParseEnv()
Expand Down Expand Up @@ -136,15 +138,18 @@ func main() {
log.Fatalf("couldn't create covers cache path: %v\n", err)
}

dbc, err := db.New(*confDBPath, db.DefaultOptions())
if *confDBURI == "" {
*confDBURI = "sqlite3://" + *deprecatedConfDBPath
}

dbc, err := db.New(*confDBURI)
if err != nil {
log.Fatalf("error opening database: %v\n", err)
}
defer dbc.Close()

err = dbc.Migrate(db.MigrationContext{
Production: true,
DBPath: *confDBPath,
OriginalMusicPath: confMusicPaths[0].path,
PlaylistsPath: *confPlaylistsPath,
PodcastsPath: *confPodcastPath,
Expand Down Expand Up @@ -225,17 +230,18 @@ func main() {
jukebx = jukebox.New()
}

sessKey, err := dbc.GetSetting("session_key")
encSessKey, err := dbc.GetSetting("session_key")
if err != nil {
log.Panicf("error getting session key: %v\n", err)
}
if sessKey == "" {
sessKey = string(securecookie.GenerateRandomKey(32))
if err := dbc.SetSetting("session_key", sessKey); err != nil {
sessKey, err := base64.StdEncoding.DecodeString(encSessKey)
if err != nil || len(sessKey) == 0 {
sessKey = securecookie.GenerateRandomKey(32)
if err := dbc.SetSetting("session_key", base64.StdEncoding.EncodeToString(sessKey)); err != nil {
log.Panicf("error setting session key: %v\n", err)
}
}
sessDB := gormstore.New(dbc.DB, []byte(sessKey))
sessDB := gormstore.New(dbc.DB, []byte(encSessKey))
sessDB.SessionOpts.HttpOnly = true
sessDB.SessionOpts.SameSite = http.SameSiteLaxMode

Expand Down
106 changes: 32 additions & 74 deletions db/db.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package db

import (
"context"
"errors"
"fmt"
"log"
Expand All @@ -13,55 +12,55 @@ import (
"time"

"github.com/jinzhu/gorm"
"github.com/mattn/go-sqlite3"
_ "github.com/jinzhu/gorm/dialects/postgres"
_ "github.com/jinzhu/gorm/dialects/sqlite"

// TODO: remove this dep
"go.senan.xyz/gonic/server/ctrlsubsonic/specid"
)

func DefaultOptions() url.Values {
return url.Values{
// with this, multiple connections share a single data and schema cache.
// see https://www.sqlite.org/sharedcache.html
"cache": {"shared"},
// with this, the db sleeps for a little while when locked. can prevent
// a SQLITE_BUSY. see https://www.sqlite.org/c3ref/busy_timeout.html
"_busy_timeout": {"30000"},
"_journal_mode": {"WAL"},
"_foreign_keys": {"true"},
}
type DB struct {
*gorm.DB
}

func mockOptions() url.Values {
return url.Values{
"_foreign_keys": {"true"},
func New(uri string) (*DB, error) {
if uri == "" {
return nil, fmt.Errorf("empty db uri")
}
}

type DB struct {
*gorm.DB
}
url, err := url.Parse(uri)
if err != nil {
return nil, fmt.Errorf("parse uri: %w", err)
}

func New(path string, options url.Values) (*DB, error) {
// https://github.com/mattn/go-sqlite3#connection-string
url := url.URL{
Scheme: "file",
Opaque: path,
gormURL := strings.TrimPrefix(url.String(), url.Scheme+"://")

//nolint:goconst
switch url.Scheme {
case "sqlite3":
q := url.Query()
q.Set("cache", "shared")
q.Set("_busy_timeout", "30000")
q.Set("_journal_mode", "WAL")
q.Set("_foreign_keys", "true")
url.RawQuery = q.Encode()
case "postgres":
// the postgres driver expects the schema prefix to be on the URL
gormURL = url.String()
default:
return nil, fmt.Errorf("unknown db scheme")
}
url.RawQuery = options.Encode()
db, err := gorm.Open("sqlite3", url.String())

db, err := gorm.Open(url.Scheme, gormURL)
if err != nil {
return nil, fmt.Errorf("with gorm: %w", err)
}

db.SetLogger(log.New(os.Stdout, "gorm ", 0))
db.DB().SetMaxOpenConns(1)
return &DB{DB: db}, nil
}

func NewMock() (*DB, error) {
return New(":memory:", mockOptions())
}

func (db *DB) InsertBulkLeftMany(table string, head []string, left int, col []int) error {
if len(col) == 0 {
return nil
Expand All @@ -72,10 +71,11 @@ func (db *DB) InsertBulkLeftMany(table string, head []string, left int, col []in
rows = append(rows, "(?, ?)")
values = append(values, left, c)
}
q := fmt.Sprintf("INSERT OR IGNORE INTO %q (%s) VALUES %s",
q := fmt.Sprintf("INSERT INTO %q (%s) VALUES %s ON CONFLICT (%s) DO NOTHING",
table,
strings.Join(head, ", "),
strings.Join(rows, ", "),
strings.Join(head, ", "),
)
return db.Exec(q, values...).Error
}
Expand Down Expand Up @@ -611,45 +611,3 @@ func join[T fmt.Stringer](in []T, sep string) string {
}
return strings.Join(strs, sep)
}

func Dump(ctx context.Context, db *gorm.DB, to string) error {
dest, err := New(to, url.Values{})
if err != nil {
return fmt.Errorf("create dest db: %w", err)
}
defer dest.Close()

connSrc, err := db.DB().Conn(ctx)
if err != nil {
return fmt.Errorf("getting src raw conn: %w", err)
}
defer connSrc.Close()

connDest, err := dest.DB.DB().Conn(ctx)
if err != nil {
return fmt.Errorf("getting dest raw conn: %w", err)
}
defer connDest.Close()

err = connDest.Raw(func(connDest interface{}) error {
return connSrc.Raw(func(connSrc interface{}) error {
connDestq := connDest.(*sqlite3.SQLiteConn)
connSrcq := connSrc.(*sqlite3.SQLiteConn)
bk, err := connDestq.Backup("main", connSrcq, "main")
if err != nil {
return fmt.Errorf("create backup db: %w", err)
}
for done, _ := bk.Step(-1); !done; { //nolint: revive
}
if err := bk.Finish(); err != nil {
return fmt.Errorf("finishing dump: %w", err)
}
return nil
})
})
if err != nil {
return fmt.Errorf("backing up: %w", err)
}

return nil
}
2 changes: 1 addition & 1 deletion db/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ func TestGetSetting(t *testing.T) {
key := SettingKey(randKey())
value := "howdy"

testDB, err := NewMock()
testDB, err := New("sqlite3://:memory:")
if err != nil {
t.Fatalf("error creating db: %v", err)
}
Expand Down
58 changes: 35 additions & 23 deletions db/migrations.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
package db

import (
"context"
"errors"
"fmt"
"log"
Expand All @@ -20,7 +19,6 @@ import (

type MigrationContext struct {
Production bool
DBPath string
OriginalMusicPath string
PlaylistsPath string
PodcastsPath string
Expand Down Expand Up @@ -59,7 +57,6 @@ func (db *DB) Migrate(ctx MigrationContext) error {
construct(ctx, "202206101425", migrateUser),
construct(ctx, "202207251148", migrateStarRating),
construct(ctx, "202211111057", migratePlaylistsQueuesToFullID),
constructNoTx(ctx, "202212272312", backupDBPre016),
construct(ctx, "202304221528", migratePlaylistsToM3U),
construct(ctx, "202305301718", migratePlayCountToLength),
construct(ctx, "202307281628", migrateAlbumArtistsMany2Many),
Expand Down Expand Up @@ -106,14 +103,14 @@ func constructNoTx(ctx MigrationContext, id string, f func(*gorm.DB, MigrationCo
func migrateInitSchema(tx *gorm.DB, _ MigrationContext) error {
return tx.AutoMigrate(
Genre{},
Artist{},
Album{},
Track{},
TrackGenre{},
AlbumGenre{},
Track{},
Artist{},
User{},
Setting{},
Play{},
Album{},
PlayQueue{},
).
Error
Expand Down Expand Up @@ -179,12 +176,18 @@ func migrateAddGenre(tx *gorm.DB, _ MigrationContext) error {

func migrateUpdateTranscodePrefIDX(tx *gorm.DB, _ MigrationContext) error {
var hasIDX int
tx.
Select("1").
Table("sqlite_master").
Where("type = ?", "index").
Where("name = ?", "idx_user_id_client").
Count(&hasIDX)
if tx.Dialect().GetName() == "sqlite3" {
tx.Select("1").
Table("sqlite_master").
Where("type = ?", "index").
Where("name = ?", "idx_user_id_client").
Count(&hasIDX)
} else if tx.Dialect().GetName() == "postgres" {
tx.Select("1").
Table("pg_indexes").
Where("indexname = ?", "idx_user_id_client").
Count(&hasIDX)
}
if hasIDX == 1 {
// index already exists
return nil
Expand Down Expand Up @@ -461,9 +464,15 @@ func migratePlaylistsQueuesToFullID(tx *gorm.DB, _ MigrationContext) error {
if err := step.Error; err != nil {
return fmt.Errorf("step migrate play_queues to full id: %w", err)
}
step = tx.Exec(`
if tx.Dialect().GetName() == "postgres" {
step = tx.Exec(`
UPDATE play_queues SET newcurrent=('tr-' || current)::varchar[200];
`)
} else {
step = tx.Exec(`
UPDATE play_queues SET newcurrent=('tr-' || CAST(current AS varchar(10)));
`)
}
if err := step.Error; err != nil {
return fmt.Errorf("step migrate play_queues to full id: %w", err)
}
Expand Down Expand Up @@ -590,7 +599,7 @@ func migrateAlbumArtistsMany2Many(tx *gorm.DB, _ MigrationContext) error {
return fmt.Errorf("step insert from albums: %w", err)
}

step = tx.Exec(`DROP INDEX idx_albums_tag_artist_id`)
step = tx.Exec(`DROP INDEX IF EXISTS idx_albums_tag_artist_id`)
if err := step.Error; err != nil {
return fmt.Errorf("step drop index: %w", err)
}
Expand Down Expand Up @@ -729,13 +738,6 @@ func migratePlaylistsPaths(tx *gorm.DB, ctx MigrationContext) error {
return nil
}

func backupDBPre016(tx *gorm.DB, ctx MigrationContext) error {
if !ctx.Production {
return nil
}
return Dump(context.Background(), tx, fmt.Sprintf("%s.%d.bak", ctx.DBPath, time.Now().Unix()))
}

func migrateAlbumTagArtistString(tx *gorm.DB, _ MigrationContext) error {
return tx.AutoMigrate(Album{}).Error
}
Expand Down Expand Up @@ -770,12 +772,22 @@ func migrateArtistAppearances(tx *gorm.DB, _ MigrationContext) error {
return fmt.Errorf("step transfer album artists: %w", err)
}

step = tx.Exec(`
if tx.Dialect().GetName() == "sqlite3" {
step = tx.Exec(`
INSERT OR IGNORE INTO artist_appearances (artist_id, album_id)
SELECT track_artists.artist_id, tracks.album_id
FROM track_artists
JOIN tracks ON tracks.id=track_artists.track_id
`)
} else {
step = tx.Exec(`
INSERT INTO artist_appearances (artist_id, album_id)
SELECT track_artists.artist_id, tracks.album_id
FROM track_artists
JOIN tracks ON tracks.id=track_artists.track_id
ON CONFLICT DO NOTHING
`)
}
if err := step.Error; err != nil {
return fmt.Errorf("step transfer album artists: %w", err)
}
Expand All @@ -795,7 +807,7 @@ func migrateTemporaryDisplayAlbumArtist(tx *gorm.DB, _ MigrationContext) error {
return tx.Exec(`
UPDATE albums
SET tag_album_artist=(
SELECT group_concat(artists.name, ', ')
SELECT string_agg(artists.name, ', ')
FROM artists
JOIN album_artists ON album_artists.artist_id=artists.id AND album_artists.album_id=albums.id
GROUP BY album_artists.album_id
Expand Down
Loading