Skip to content

Commit c9d2566

Browse files
authored
Make column duplication idempotent (#946)
This PR makes column duplication idempotent. Previously, if the same column was altered in multiple operations in the same migration was not working. The modified column could not be duplicated in the second operation because the `ALTER TABLE` statement has failed. I added `IF NOT EXISTS` to make the column creation idempotent. Now the column is duplicated once, and other changes are added on top of it like setting not null constraints, changing default values, etc. The only edge case this PR does not address is when the column type is changed in a later operation. This limitation can be worked around by changing the type first in an `alter_column` operation and then adding other changes to the migration. I do opted for this trade-off because it would increase the complexity of the duplicator but not much of an upside. This is working: ```yaml operations: - alter_column: name: mycol type: text - alter_column: name: mycol comment: nocomment ``` This is not supported: ```yaml operations: - alter_column: name: mycol comment: nocomment - alter_column: name: mycol type: text ```
1 parent b884976 commit c9d2566

File tree

4 files changed

+130
-7
lines changed

4 files changed

+130
-7
lines changed

pkg/backfill/backfill.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,9 @@ func (j *Job) AddTask(t *Task) {
7575
// CASE WHEN "review" = 'bad' THEN 'bad review' ELSE 'good review' END, it must be rewritten to:
7676
// CASE WHEN NEW."_pgroll_new_review" = 'bad' THEN 'bad review' ELSE 'good review' END.
7777
// Otherwise, the trigger will not work correctly because it will reference the old column name.
78-
tg.SQL = append(tg.SQL, rewriteTriggerSQL(trigger.SQL, findColumnName(tg.Columns, tg.PhysicalColumn), tg.PhysicalColumn))
78+
tg.SQL = append(tg.SQL,
79+
rewriteTriggerSQL(trigger.SQL, findColumnName(trigger.Columns, trigger.PhysicalColumn), trigger.PhysicalColumn),
80+
)
7981
j.triggers[trigger.Name] = tg
8082
} else {
8183
// If the trigger does not exist, create a new trigger config

pkg/migrations/duplicate.go

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ func (d *duplicator) Execute(ctx context.Context) error {
172172
func (d *duplicatorStmtBuilder) duplicateCheckConstraints(withoutConstraint []string, colNames ...string) []string {
173173
stmts := make([]string, 0, len(d.table.CheckConstraints))
174174
for _, cc := range d.table.CheckConstraints {
175-
if slices.Contains(withoutConstraint, cc.Name) {
175+
if slices.Contains(withoutConstraint, cc.Name) || IsDuplicatedName(cc.Name) {
176176
continue
177177
}
178178
if duplicatedConstraintColumns := d.duplicatedConstraintColumns(cc.Columns, colNames...); len(duplicatedConstraintColumns) > 0 {
@@ -214,7 +214,7 @@ func (d *duplicatorStmtBuilder) duplicateForeignKeyConstraints(withoutConstraint
214214
func (d *duplicatorStmtBuilder) duplicateIndexes(withoutConstraint []string, colNames ...string) []string {
215215
stmts := make([]string, 0, len(d.table.Indexes))
216216
for _, idx := range d.table.Indexes {
217-
if slices.Contains(withoutConstraint, idx.Name) {
217+
if slices.Contains(withoutConstraint, idx.Name) || IsDuplicatedName(idx.Name) {
218218
continue
219219
}
220220
if _, ok := d.table.UniqueConstraints[idx.Name]; ok && idx.Unique {
@@ -286,8 +286,8 @@ func (d *duplicatorStmtBuilder) duplicateColumn(
286286
withType string,
287287
) string {
288288
const (
289-
cAlterTableSQL = `ALTER TABLE %s ADD COLUMN %s %s`
290-
cAddCheckConstraintSQL = `ADD CONSTRAINT %s %s NOT VALID`
289+
cAlterTableSQL = `ALTER TABLE %s ADD COLUMN IF NOT EXISTS %s %s`
290+
cAddCheckConstraintSQL = `ALTER TABLE %s ADD CONSTRAINT %s %s NOT VALID`
291291
)
292292

293293
// Generate SQL to duplicate the column's name and type
@@ -299,10 +299,22 @@ func (d *duplicatorStmtBuilder) duplicateColumn(
299299
// Generate SQL to add an unchecked NOT NULL constraint if the original column
300300
// is NOT NULL. The constraint will be validated on migration completion.
301301
if !column.Nullable && !withoutNotNull {
302-
sql += fmt.Sprintf(", "+cAddCheckConstraintSQL,
303-
pq.QuoteIdentifier(DuplicationName(NotNullConstraintName(column.Name))),
302+
constraintName := DuplicationName(NotNullConstraintName(column.Name))
303+
if _, ok := d.table.CheckConstraints[constraintName]; ok {
304+
return sql // Skip if the constraint already exists
305+
}
306+
sql += fmt.Sprintf("; "+cAddCheckConstraintSQL,
307+
pq.QuoteIdentifier(d.table.Name),
308+
pq.QuoteIdentifier(constraintName),
304309
fmt.Sprintf("CHECK (%s IS NOT NULL)", pq.QuoteIdentifier(asName)),
305310
)
311+
if d.table.CheckConstraints == nil {
312+
d.table.CheckConstraints = make(map[string]*schema.CheckConstraint)
313+
}
314+
d.table.CheckConstraints[constraintName] = &schema.CheckConstraint{
315+
Name: constraintName,
316+
Columns: []string{asName},
317+
}
306318
}
307319

308320
return sql

pkg/migrations/op_alter_column.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,13 @@ func (o *OpAlterColumn) Start(ctx context.Context, l Logger, conn db.DB, s *sche
6363
// a variable for the new column. Save the old column name for use as the
6464
// physical column name. in the down trigger first.
6565
oldPhysicalColumn := column.Name
66+
columnType := column.Type
67+
if o.Type != nil {
68+
columnType = *o.Type
69+
}
6670
table.AddColumn(o.Column, &schema.Column{
6771
Name: TemporaryName(o.Column),
72+
Type: columnType,
6873
})
6974

7075
// Add a trigger to copy values from the new column to the old.

pkg/migrations/op_alter_column_test.go

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -505,6 +505,110 @@ func TestAlterColumnMultipleSubOperations(t *testing.T) {
505505
}, rows)
506506
},
507507
},
508+
{
509+
name: "can alter a column multiple times: add check and not null",
510+
migrations: []migrations.Migration{
511+
{
512+
Name: "01_create_table",
513+
Operations: migrations.Operations{
514+
&migrations.OpCreateTable{
515+
Name: "people",
516+
Columns: []migrations.Column{
517+
{
518+
Name: "id",
519+
Type: "serial",
520+
Pk: true,
521+
},
522+
{
523+
Name: "name",
524+
Type: "varchar(255)",
525+
Nullable: true,
526+
},
527+
},
528+
},
529+
},
530+
},
531+
{
532+
Name: "02_alter_column",
533+
Operations: migrations.Operations{
534+
&migrations.OpAlterColumn{
535+
Table: "people",
536+
Column: "name",
537+
Up: "INITCAP(name)",
538+
Down: "name",
539+
Check: &migrations.CheckConstraint{
540+
Name: "name_capitalized",
541+
Constraint: "name ~ '^[A-Z].*'",
542+
},
543+
},
544+
&migrations.OpAlterColumn{
545+
Table: "people",
546+
Column: "name",
547+
Up: "SELECT CASE WHEN name IS NULL THEN 'Jane' ELSE name END",
548+
Down: "name",
549+
Nullable: ptr(false),
550+
},
551+
},
552+
},
553+
},
554+
afterStart: func(t *testing.T, db *sql.DB, schema string) {
555+
// Can insert a row into the `people` table with a lowercase `name` value
556+
MustInsert(t, db, schema, "01_create_table", "people", map[string]string{
557+
"name": "alice",
558+
})
559+
// Can insert a row into the `people` table with a NULL `name` value
560+
MustInsert(t, db, schema, "01_create_table", "people", map[string]string{
561+
"id": "2",
562+
})
563+
// Cannot insert a row into the `people` table with an lowercase `name` value
564+
MustNotInsert(t, db, schema, "02_alter_column", "people", map[string]string{
565+
"name": "bob",
566+
}, testutils.CheckViolationErrorCode)
567+
// Can insert a row into the `people` table with a missing `name` value
568+
MustNotInsert(t, db, schema, "02_alter_column", "people", map[string]string{
569+
"id": "4",
570+
}, testutils.CheckViolationErrorCode)
571+
// Can insert a row into the `people` table with a capitalized `name` value
572+
MustInsert(t, db, schema, "01_create_table", "people", map[string]string{
573+
"id": "4",
574+
"name": "Carl",
575+
})
576+
577+
// The version of the `people` table in the new schema has the expected rows.
578+
rows := MustSelect(t, db, schema, "02_alter_column", "people")
579+
assert.Equal(t, []map[string]any{
580+
{"id": 1, "name": "Alice"},
581+
{"id": 2, "name": "Jane"},
582+
{"id": 4, "name": "Carl"},
583+
}, rows)
584+
},
585+
afterRollback: func(t *testing.T, db *sql.DB, schema string) {
586+
// Can insert a row into the `people` table with a lowercase `name` value
587+
MustInsert(t, db, schema, "01_create_table", "people", map[string]string{
588+
"name": "alice",
589+
})
590+
// Can insert a row into the `people` table with a NULL `name` value
591+
MustInsert(t, db, schema, "01_create_table", "people", map[string]string{
592+
"id": "100",
593+
})
594+
},
595+
afterComplete: func(t *testing.T, db *sql.DB, schema string) {
596+
// Inserting a row into the `people` table with a NULL `manages` value fails
597+
MustNotInsert(t, db, schema, "02_alter_column", "people", map[string]string{
598+
"name": "carl",
599+
}, testutils.CheckViolationErrorCode)
600+
601+
// The `people` table has the expected rows.
602+
rows := MustSelect(t, db, schema, "02_alter_column", "people")
603+
assert.Equal(t, []map[string]any{
604+
{"id": 1, "name": "Alice"},
605+
{"id": 2, "name": "Jane"},
606+
{"id": 4, "name": "Carl"},
607+
{"id": 3, "name": "Alice"},
608+
{"id": 100, "name": "Jane"},
609+
}, rows)
610+
},
611+
},
508612
})
509613
}
510614

0 commit comments

Comments
 (0)