Skip to content

Commit c283978

Browse files
authored
Merge pull request #29 from github/cache-locking
Implement some primitive cache locking.
2 parents 41ceecb + 0725f13 commit c283978

File tree

4 files changed

+57
-0
lines changed

4 files changed

+57
-0
lines changed

internal/cachedirectory/cachedirectory.go

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,10 +103,43 @@ func (cacheDirectory *CacheDirectory) CheckOrCreateVersionFile(pull bool, versio
103103
return errors.New(errorPushNonCache)
104104
}
105105

106+
func (cacheDirectory *CacheDirectory) Lock() error {
107+
file, err := os.Create(cacheDirectory.lockFilePath())
108+
if err != nil {
109+
return errors.Wrap(err, "Error locking cache directory.")
110+
}
111+
defer file.Close()
112+
// If the cache directory is already locked, it's not really a huge issue since the purpose of the lock is mostly to check whether a `pull` operation was interrupted before pushing.
113+
return nil
114+
}
115+
116+
func (cacheDirectory *CacheDirectory) Unlock() error {
117+
err := os.Remove(cacheDirectory.lockFilePath())
118+
if err != nil {
119+
return errors.Wrap(err, "Error unlocking cache directory.")
120+
}
121+
return nil
122+
}
123+
124+
func (cacheDirectory *CacheDirectory) CheckLock() error {
125+
_, err := os.Stat(cacheDirectory.lockFilePath())
126+
if err == nil {
127+
return errors.New("The cache directory is locked, likely due to a `pull` command being interrupted. Please run `pull` again to ensure all required data is downloaded.")
128+
}
129+
if os.IsNotExist(err) {
130+
return nil
131+
}
132+
return errors.Wrap(err, "Error checking if cache directory is locked.")
133+
}
134+
106135
func (cacheDirectory *CacheDirectory) versionFilePath() string {
107136
return path.Join(cacheDirectory.path, ".codeql-actions-sync-version")
108137
}
109138

139+
func (cacheDirectory *CacheDirectory) lockFilePath() string {
140+
return path.Join(cacheDirectory.path, ".codeql-actions-sync-lock")
141+
}
142+
110143
func (cacheDirectory *CacheDirectory) GitPath() string {
111144
return path.Join(cacheDirectory.path, "git")
112145
}

internal/cachedirectory/cachedirectory_test.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,3 +88,14 @@ func TestCreateCacheDirectoryWithTrailingSlash(t *testing.T) {
8888
err := cacheDirectory.CheckOrCreateVersionFile(true, aVersion)
8989
require.NoError(t, err)
9090
}
91+
92+
func TestLocking(t *testing.T) {
93+
temporaryDirectory := test.CreateTemporaryDirectory(t)
94+
cacheDirectory := NewCacheDirectory(path.Join(temporaryDirectory, "cache"))
95+
require.NoError(t, cacheDirectory.CheckOrCreateVersionFile(true, aVersion))
96+
require.NoError(t, cacheDirectory.Lock())
97+
require.NoError(t, cacheDirectory.Lock())
98+
require.Error(t, cacheDirectory.CheckLock())
99+
require.NoError(t, cacheDirectory.Unlock())
100+
require.NoError(t, cacheDirectory.CheckLock())
101+
}

internal/pull/pull.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,10 @@ func Pull(ctx context.Context, cacheDirectory cachedirectory.CacheDirectory, sou
237237
if err != nil {
238238
return err
239239
}
240+
err = cacheDirectory.Lock()
241+
if err != nil {
242+
return err
243+
}
240244

241245
var tokenClient *http.Client
242246
if sourceToken != "" {
@@ -266,6 +270,11 @@ func Pull(ctx context.Context, cacheDirectory cachedirectory.CacheDirectory, sou
266270
if err != nil {
267271
return err
268272
}
273+
274+
err = cacheDirectory.Unlock()
275+
if err != nil {
276+
return err
277+
}
269278
log.Print("Finished pulling the CodeQL Action repository and bundles!")
270279
return nil
271280
}

internal/push/push.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,10 @@ func Push(ctx context.Context, cacheDirectory cachedirectory.CacheDirectory, des
291291
if err != nil {
292292
return err
293293
}
294+
err = cacheDirectory.CheckLock()
295+
if err != nil {
296+
return err
297+
}
294298

295299
destinationURL = strings.TrimRight(destinationURL, "/")
296300
tokenSource := oauth2.StaticTokenSource(

0 commit comments

Comments
 (0)