Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 51 additions & 8 deletions xormigrate.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,25 +11,35 @@ const (
initSchemaMigrationId = "SCHEMA_INIT"
)

// MigrateFunc is the func signature for migratinx.
// MigrateFunc is the func signature for migrating.
type MigrateFunc func(*xorm.Engine) error

// RollbackFunc is the func signature for rollbackinx.
// RollbackFunc is the func signature for rollbacking.
type RollbackFunc func(*xorm.Engine) error

// InitSchemaFunc is the func signature for initializing the schema.
type InitSchemaFunc func(*xorm.Engine) error

// MigrateFunc is the func signature for migrating.
type MigrateFuncSession func(*xorm.Session) error

// RollbackFunc is the func signature for rollbacking.
type RollbackFuncSession func(*xorm.Session) error

// Migration represents a database migration (a modification to be made on the database).
type Migration struct {
// ID is the migration identifier. Usually a timestamp like "201601021504".
ID string `xorm:"id"`
// Description is the migration description, which is optionally printed out when the migration is ran.
Description string
// Migrate is a function that will br executed while running this migration.
// Migrate is a function that will be executed while running this migration.
Migrate MigrateFunc `xorm:"-"`
// Rollback will be executed on rollback. Can be nil.
Rollback RollbackFunc `xorm:"-"`
// MigrateSession is a function that will be executed while running this migration, using xorm.Session.
MigrateSession MigrateFuncSession `xorm:"-"`
// RollbackSession will be executed on rollback, using xorm.Session. Can be nil.
RollbackSession RollbackFuncSession `xorm:"-"`
// Long marks the migration an non-required migration that will likely take a long time. Must use Xormigrate.AllowLong() to be enabled.
Long bool `xorm:"-"`
}
Expand Down Expand Up @@ -251,14 +261,25 @@ func (x *Xormigrate) RollbackMigration(m *Migration) error {
}

func (x *Xormigrate) rollbackMigration(m *Migration) error {
if m.Rollback == nil {
if m.Rollback == nil && m.RollbackSession == nil {
return ErrRollbackImpossible
}
if len(m.Description) > 0 {
logger.Errorf("Rolling back migration: %s", m.Description)
}
if err := m.Rollback(x.db); err != nil {
return err
if m.Rollback != nil {
if err := m.Rollback(x.db); err != nil {
return err
}
} else {
sess := x.db.NewSession()
if err := m.RollbackSession(sess); err != nil {
rollbackSession(sess)
return err
}
if err := sess.Commit(); err != nil {
return err
}
}
if _, err := x.db.In("id", m.ID).Delete(&Migration{}); err != nil {
return err
Expand All @@ -268,7 +289,12 @@ func (x *Xormigrate) rollbackMigration(m *Migration) error {

func (x *Xormigrate) runInitSchema() error {
logger.Info("Initializing Schema")
sess := x.db.NewSession()
if err := x.initSchema(x.db); err != nil {
rollbackSession(sess)
return err
}
if err := sess.Commit(); err != nil {
return err
}
if err := x.insertMigration(initSchemaMigrationId); err != nil {
Expand All @@ -293,8 +319,19 @@ func (x *Xormigrate) runMigration(migration *Migration) error {
if len(migration.Description) > 0 {
logger.Info(migration.Description)
}
if err := migration.Migrate(x.db); err != nil {
return fmt.Errorf("migration %s failed: %s", migration.ID, err.Error())
if migration.Migrate != nil {
if err := migration.Migrate(x.db); err != nil {
return fmt.Errorf("migration %s failed: %s", migration.ID, err.Error())
}
} else {
sess := x.db.NewSession()
if err := migration.MigrateSession(sess); err != nil {
rollbackSession(sess)
return fmt.Errorf("migration %s failed: %s", migration.ID, err.Error())
}
if err := sess.Commit(); err != nil {
return err
}
}

if err := x.insertMigration(migration.ID); err != nil {
Expand Down Expand Up @@ -339,3 +376,9 @@ func (x *Xormigrate) insertMigration(id string) error {
_, err := x.db.Insert(&Migration{ID: id})
return err
}

func rollbackSession(sess *xorm.Session) {
if err := sess.Rollback(); err != nil {
logger.Errorf("Failed to rollback session: %v", err)
}
}
58 changes: 58 additions & 0 deletions xormigrate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,28 @@ var migrations = []*Migration{
},
}

var migrationsSession = []*Migration{
{
ID: "201608301400",
Description: "Add Person",
MigrateSession: func(tx *xorm.Session) error {
return tx.Sync(&Person{})
},
RollbackSession: func(tx *xorm.Session) error {
return tx.DropTable(&Person{})
},
},
{
ID: "201608301430",
MigrateSession: func(tx *xorm.Session) error {
return tx.Sync2(&Pet{})
},
RollbackSession: func(tx *xorm.Session) error {
return tx.DropTable(&Pet{})
},
},
}

var extendedMigrations = append(migrations, &Migration{
ID: "201807221927",
Migrate: func(tx *xorm.Engine) error {
Expand Down Expand Up @@ -355,6 +377,42 @@ func TestAllowLong(t *testing.T) {
})
}

func TestMigrationSession(t *testing.T) {
forEachDatabase(t, func(db *xorm.Engine) {
m := New(db, migrationsSession)

err := m.Migrate()
assert.NoError(t, err)
has, err := db.IsTableExist(&Person{})
assert.NoError(t, err)
assert.True(t, has)
has, err = db.IsTableExist(&Pet{})
assert.NoError(t, err)
assert.True(t, has)
assert.Equal(t, int64(2), tableCount(t, db))

err = m.RollbackLast()
assert.NoError(t, err)
has, err = db.IsTableExist(&Person{})
assert.NoError(t, err)
assert.True(t, has)
has, err = db.Exist(&Pet{})
assert.Error(t, err)
assert.False(t, has)
assert.Equal(t, int64(1), tableCount(t, db))

err = m.RollbackLast()
assert.NoError(t, err)
has, err = db.IsTableExist(&Person{})
assert.NoError(t, err)
assert.False(t, has)
has, err = db.IsTableExist(&Pet{})
assert.NoError(t, err)
assert.False(t, has)
assert.Equal(t, int64(0), tableCount(t, db))
})
}

func tableCount(t *testing.T, db *xorm.Engine) (count int64) {
count, err := db.Count(&Migration{})
assert.NoError(t, err)
Expand Down