Skip to content

Commit 66464e2

Browse files
Add basic test for hooks
1 parent 1df37c2 commit 66464e2

File tree

4 files changed

+101
-35
lines changed

4 files changed

+101
-35
lines changed

go/logic/hooks.go

Lines changed: 19 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ package logic
77

88
import (
99
"fmt"
10+
"io"
1011
"os"
1112
"os/exec"
1213
"path/filepath"
@@ -42,10 +43,6 @@ func NewHooksExecutor(migrationContext *base.MigrationContext) *HooksExecutor {
4243
}
4344
}
4445

45-
func (this *HooksExecutor) initHooks() error {
46-
return nil
47-
}
48-
4946
func (this *HooksExecutor) applyEnvironmentVariables(extraVariables ...string) []string {
5047
env := os.Environ()
5148
env = append(env, fmt.Sprintf("GH_OST_DATABASE_NAME=%s", this.migrationContext.DatabaseName))
@@ -76,13 +73,13 @@ func (this *HooksExecutor) applyEnvironmentVariables(extraVariables ...string) [
7673
}
7774

7875
// executeHook executes a command, and sets relevant environment variables
79-
// combined output & error are printed to gh-ost's standard error.
80-
func (this *HooksExecutor) executeHook(hook string, extraVariables ...string) error {
76+
// combined output & error are printed to the provided io.Writer.
77+
func (this *HooksExecutor) executeHook(out io.Writer, hook string, extraVariables ...string) error {
8178
cmd := exec.Command(hook)
8279
cmd.Env = this.applyEnvironmentVariables(extraVariables...)
8380

8481
combinedOutput, err := cmd.CombinedOutput()
85-
fmt.Fprintln(os.Stderr, string(combinedOutput))
82+
fmt.Fprintln(out, string(combinedOutput))
8683
return log.Errore(err)
8784
}
8885

@@ -95,69 +92,69 @@ func (this *HooksExecutor) detectHooks(baseName string) (hooks []string, err err
9592
return hooks, err
9693
}
9794

98-
func (this *HooksExecutor) executeHooks(baseName string, extraVariables ...string) error {
95+
func (this *HooksExecutor) executeHooks(out io.Writer, baseName string, extraVariables ...string) error {
9996
hooks, err := this.detectHooks(baseName)
10097
if err != nil {
10198
return err
10299
}
103100
for _, hook := range hooks {
104101
log.Infof("executing %+v hook: %+v", baseName, hook)
105-
if err := this.executeHook(hook, extraVariables...); err != nil {
102+
if err := this.executeHook(out, hook, extraVariables...); err != nil {
106103
return err
107104
}
108105
}
109106
return nil
110107
}
111108

112109
func (this *HooksExecutor) onStartup() error {
113-
return this.executeHooks(onStartup)
110+
return this.executeHooks(os.Stderr, onStartup)
114111
}
115112

116113
func (this *HooksExecutor) onValidated() error {
117-
return this.executeHooks(onValidated)
114+
return this.executeHooks(os.Stderr, onValidated)
118115
}
119116

120117
func (this *HooksExecutor) onRowCountComplete() error {
121-
return this.executeHooks(onRowCountComplete)
118+
return this.executeHooks(os.Stderr, onRowCountComplete)
122119
}
123120
func (this *HooksExecutor) onBeforeRowCopy() error {
124-
return this.executeHooks(onBeforeRowCopy)
121+
return this.executeHooks(os.Stderr, onBeforeRowCopy)
125122
}
126123

127124
func (this *HooksExecutor) onRowCopyComplete() error {
128-
return this.executeHooks(onRowCopyComplete)
125+
return this.executeHooks(os.Stderr, onRowCopyComplete)
129126
}
130127

131128
func (this *HooksExecutor) onBeginPostponed() error {
132-
return this.executeHooks(onBeginPostponed)
129+
return this.executeHooks(os.Stderr, onBeginPostponed)
133130
}
134131

135132
func (this *HooksExecutor) onBeforeCutOver() error {
136-
return this.executeHooks(onBeforeCutOver)
133+
return this.executeHooks(os.Stderr, onBeforeCutOver)
137134
}
138135

139136
func (this *HooksExecutor) onInteractiveCommand(command string) error {
140137
v := fmt.Sprintf("GH_OST_COMMAND='%s'", command)
141-
return this.executeHooks(onInteractiveCommand, v)
138+
return this.executeHooks(os.Stderr, onInteractiveCommand, v)
142139
}
143140

144141
func (this *HooksExecutor) onSuccess() error {
145-
return this.executeHooks(onSuccess)
142+
return this.executeHooks(os.Stderr, onSuccess)
146143
}
147144

148145
func (this *HooksExecutor) onFailure() error {
149-
return this.executeHooks(onFailure)
146+
return this.executeHooks(os.Stderr, onFailure)
150147
}
151148

152149
func (this *HooksExecutor) onStatus(statusMessage string) error {
153150
v := fmt.Sprintf("GH_OST_STATUS='%s'", statusMessage)
154-
return this.executeHooks(onStatus, v)
151+
return this.executeHooks(os.Stderr, onStatus, v)
155152
}
156153

157154
func (this *HooksExecutor) onStopReplication() error {
158-
return this.executeHooks(onStopReplication)
155+
return this.executeHooks(os.Stderr, onStopReplication)
159156
}
160157

161158
func (this *HooksExecutor) onStartReplication() error {
162-
return this.executeHooks(onStartReplication)
159+
return this.executeHooks(os.Stderr, onStartReplication)
163160
}

go/logic/hooks_test.go

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
/*
2+
Copyright 2022 GitHub Inc.
3+
See https://github.com/github/gh-ost/blob/master/LICENSE
4+
*/
5+
6+
package logic
7+
8+
import (
9+
"bufio"
10+
"bytes"
11+
"fmt"
12+
"os"
13+
"path/filepath"
14+
"strings"
15+
"testing"
16+
17+
"github.com/openark/golib/tests"
18+
19+
"github.com/github/gh-ost/go/base"
20+
)
21+
22+
func TestHooksExecutorExecuteHooks(t *testing.T) {
23+
migrationContext := base.NewMigrationContext()
24+
migrationContext.AlterStatement = "ENGINE=InnoDB"
25+
migrationContext.DatabaseName = "test"
26+
migrationContext.OriginalTableName = "tablename"
27+
hooksExecutor := NewHooksExecutor(migrationContext)
28+
29+
t.Run("does-not-exist", func(t *testing.T) {
30+
migrationContext.HooksPath = "/does/not/exist"
31+
tests.S(t).ExpectNil(hooksExecutor.executeHooks(os.Stderr, "test-hook"))
32+
})
33+
34+
t.Run("failed", func(t *testing.T) {
35+
var err error
36+
if migrationContext.HooksPath, err = os.MkdirTemp("", "TestHooksExecutorExecuteHooks-failed"); err != nil {
37+
panic(err)
38+
}
39+
defer os.RemoveAll(migrationContext.HooksPath)
40+
41+
// write hook that fails with 'exit 1'
42+
hookFile := filepath.Join(migrationContext.HooksPath, "failed-hook")
43+
os.WriteFile(hookFile, []byte("#!/bin/sh\nexit 1"), 0777)
44+
tests.S(t).ExpectNotNil(hooksExecutor.executeHooks(os.Stderr, "failed-hook"))
45+
})
46+
47+
t.Run("success", func(t *testing.T) {
48+
var err error
49+
if migrationContext.HooksPath, err = os.MkdirTemp("", "TestHooksExecutorExecuteHooks-success"); err != nil {
50+
panic(err)
51+
}
52+
defer os.RemoveAll(migrationContext.HooksPath)
53+
54+
// write hook that prints the environment with 'env'
55+
hookFile := filepath.Join(migrationContext.HooksPath, "success-hook")
56+
os.WriteFile(hookFile, []byte("#!/bin/sh\nenv"), 0777)
57+
58+
// check output
59+
var buf bytes.Buffer
60+
tests.S(t).ExpectNil(hooksExecutor.executeHooks(&buf, "success-hook", "TEST="+t.Name()))
61+
scanner := bufio.NewScanner(&buf)
62+
for scanner.Scan() {
63+
split := strings.SplitN(scanner.Text(), "=", 2)
64+
switch split[0] {
65+
case "GH_OST_DATABASE_NAME":
66+
tests.S(t).ExpectEquals(split[1], migrationContext.DatabaseName)
67+
case "GH_OST_DDL":
68+
tests.S(t).ExpectEquals(split[1], migrationContext.AlterStatement)
69+
case "GH_OST_TABLE_NAME":
70+
tests.S(t).ExpectEquals(split[1], migrationContext.OriginalTableName)
71+
case "GH_OST_OLD_TABLE_NAME":
72+
tests.S(t).ExpectEquals(split[1], fmt.Sprintf("_%s_del", migrationContext.OriginalTableName))
73+
case "GH_OST_GHOST_TABLE_NAME":
74+
tests.S(t).ExpectEquals(split[1], fmt.Sprintf("_%s_gho", migrationContext.OriginalTableName))
75+
case "TEST":
76+
tests.S(t).ExpectEquals(split[1], t.Name())
77+
}
78+
}
79+
})
80+
}

go/logic/migrator.go

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ type Migrator struct {
9898
func NewMigrator(context *base.MigrationContext, appVersion string) *Migrator {
9999
migrator := &Migrator{
100100
appVersion: appVersion,
101+
hooksExecutor: NewHooksExecutor(context),
101102
migrationContext: context,
102103
parser: sql.NewAlterTableParser(),
103104
ghostTableMigrated: make(chan bool),
@@ -113,15 +114,6 @@ func NewMigrator(context *base.MigrationContext, appVersion string) *Migrator {
113114
return migrator
114115
}
115116

116-
// initiateHooksExecutor
117-
func (this *Migrator) initiateHooksExecutor() (err error) {
118-
this.hooksExecutor = NewHooksExecutor(this.migrationContext)
119-
if err := this.hooksExecutor.initHooks(); err != nil {
120-
return err
121-
}
122-
return nil
123-
}
124-
125117
// sleepWhileTrue sleeps indefinitely until the given function returns 'false'
126118
// (or fails with error)
127119
func (this *Migrator) sleepWhileTrue(operation func() (bool, error)) error {
@@ -342,9 +334,6 @@ func (this *Migrator) Migrate() (err error) {
342334

343335
go this.listenOnPanicAbort()
344336

345-
if err := this.initiateHooksExecutor(); err != nil {
346-
return err
347-
}
348337
if err := this.hooksExecutor.onStartup(); err != nil {
349338
return err
350339
}

script/test

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,4 @@ script/build
1414
cd .gopath/src/github.com/github/gh-ost
1515

1616
echo "Running unit tests"
17-
go test ./go/...
17+
go test -v -covermode=atomic ./go/...

0 commit comments

Comments
 (0)