Skip to content

Commit 5caa89c

Browse files
committed
fix restore command to support documented URL patterns
Signed-off-by: Youngmin Koo <[email protected]>
1 parent cbc934e commit 5caa89c

File tree

2 files changed

+25
-8
lines changed

2 files changed

+25
-8
lines changed

cmd/restore.go

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ func restoreCmd(passedExecs execs, cmdConfig *cmdConfiguration) (*cobra.Command,
2828
PreRun: func(cmd *cobra.Command, args []string) {
2929
bindFlags(cmd, v)
3030
},
31-
Args: cobra.MinimumNArgs(1),
31+
Args: cobra.RangeArgs(0, 1),
3232
RunE: func(cmd *cobra.Command, args []string) error {
3333
cmdConfig.logger.Debug("starting restore")
3434
ctx := context.Background()
@@ -40,8 +40,20 @@ func restoreCmd(passedExecs execs, cmdConfig *cmdConfiguration) (*cobra.Command,
4040
}()
4141
ctx = util.ContextWithTracer(ctx, tracer)
4242
_, startupSpan := tracer.Start(ctx, "startup")
43-
targetFile := args[0]
44-
target := v.GetString("target")
43+
44+
// Get target from args[0], --target flag, or DB_RESTORE_TARGET environment variable
45+
var target string
46+
if len(args) > 0 {
47+
target = args[0]
48+
} else {
49+
target = v.GetString("target")
50+
}
51+
if target == "" {
52+
return fmt.Errorf("target must be specified as argument, --target flag, or DB_RESTORE_TARGET environment variable")
53+
}
54+
55+
// Always pass empty targetFile to use the full path from the URL
56+
targetFile := ""
4557
// get databases namesand mappings
4658
databasesMap := make(map[string]string)
4759
databases := strings.TrimSpace(v.GetString("database"))
@@ -144,9 +156,6 @@ func restoreCmd(passedExecs execs, cmdConfig *cmdConfiguration) (*cobra.Command,
144156

145157
flags := cmd.Flags()
146158
flags.String("target", "", "full URL target to the backup that you wish to restore")
147-
if err := cmd.MarkFlagRequired("target"); err != nil {
148-
return nil, err
149-
}
150159

151160
// compression
152161
flags.String("compression", defaultCompression, "Compression to use. Supported are: `gzip`, `bzip2`, `none`")

pkg/storage/s3/s3.go

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,15 @@ func (s *S3) Pull(ctx context.Context, source, target string, logger *log.Entry)
7373
return 0, fmt.Errorf("failed to get AWS client: %v", err)
7474
}
7575

76-
bucket, path := s.url.Hostname(), path.Join(s.url.Path, source)
76+
bucket := s.url.Hostname()
77+
// If source is empty, use the path from URL directly (for restore command)
78+
// Otherwise, append source to the URL path (for dump command)
79+
var objectPath string
80+
if source == "" {
81+
objectPath = strings.TrimPrefix(s.url.Path, "/")
82+
} else {
83+
objectPath = strings.TrimPrefix(path.Join(s.url.Path, source), "/")
84+
}
7785

7886
// Create a downloader with the session and default options
7987
downloader := manager.NewDownloader(client)
@@ -88,7 +96,7 @@ func (s *S3) Pull(ctx context.Context, source, target string, logger *log.Entry)
8896
// Write the contents of S3 Object to the file
8997
n, err := downloader.Download(context.TODO(), f, &s3.GetObjectInput{
9098
Bucket: aws.String(bucket),
91-
Key: aws.String(path),
99+
Key: aws.String(objectPath),
92100
})
93101
if err != nil {
94102
return 0, fmt.Errorf("failed to download file, %v", err)

0 commit comments

Comments
 (0)