Skip to content

Commit 455911c

Browse files
committed
Core: call callbacks on invalidated Client
Previously, invalidated Clients were removed from the Client cache without calling any of the registered ClientCallbacks. Now any callbacks registered with ClientCallbackOnCacheRemoval be invoked when a client is removed from cache.
1 parent 5630d8e commit 455911c

File tree

5 files changed

+323
-78
lines changed

5 files changed

+323
-78
lines changed

controllers/vaultdynamicsecret_controller.go

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,8 @@ func (r *VaultDynamicSecretReconciler) Reconcile(ctx context.Context, req ctrl.R
252252
} else if !vault.IsLeaseNotFoundError(err) {
253253
r.Recorder.Eventf(o, corev1.EventTypeWarning, consts.ReasonSecretLeaseRenewalError,
254254
"Could not renew lease, lease_id=%s, err=%s", leaseID, err)
255-
} else if !vault.IsForbiddenError(err) {
255+
} else if vault.IsForbiddenError(err) {
256+
logger.V(consts.LogLevelWarning).Info("Tainting client", "err", err)
256257
vClient.Taint()
257258
}
258259
syncReason = consts.ReasonSecretLeaseRenewalError
@@ -275,6 +276,10 @@ func (r *VaultDynamicSecretReconciler) Reconcile(ctx context.Context, req ctrl.R
275276
secretLease, staticCredsUpdated, err := r.syncSecret(ctx, vClient, o, transOption)
276277
if err != nil {
277278
r.SyncRegistry.Add(req.NamespacedName)
279+
if vault.IsForbiddenError(err) {
280+
logger.V(consts.LogLevelWarning).Info("Tainting client", "err", err)
281+
vClient.Taint()
282+
}
278283
entry, _ := r.BackOffRegistry.Get(req.NamespacedName)
279284
horizon := entry.NextBackOff()
280285
r.Recorder.Eventf(o, corev1.EventTypeWarning, consts.ReasonSecretSyncError,
@@ -504,7 +509,7 @@ func (r *VaultDynamicSecretReconciler) SetupWithManager(mgr ctrl.Manager, opts c
504509

505510
r.ClientFactory.RegisterClientCallbackHandler(
506511
vault.ClientCallbackHandler{
507-
On: vault.ClientCallbackOnLifetimeWatcherDone,
512+
On: vault.ClientCallbackOnLifetimeWatcherDone | vault.ClientCallbackOnCacheRemoval,
508513
Callback: r.vaultClientCallback,
509514
},
510515
)

internal/vault/client.go

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ import (
2828

2929
type ClientOptions struct {
3030
SkipRenewal bool
31-
WatcherDoneCh chan<- Client
31+
WatcherDoneCh chan<- *ClientCallbackHandlerRequest
3232
}
3333

3434
func defaultClientOptions() *ClientOptions {
@@ -129,11 +129,10 @@ func NewClientFromStorageEntry(ctx context.Context, client ctrlclient.Client, en
129129
return nil, fmt.Errorf("restored client's cacheKey %s does not match expected %s", cacheKey, entry.CacheKey)
130130
}
131131

132-
if err := c.Validate(); err != nil {
133-
return nil, err
134-
}
132+
c.Taint()
133+
defer c.Untaint()
135134

136-
if _, err := c.Read(ctx, NewReadRequest("auth/token/lookup-self", nil)); err != nil {
135+
if err := c.Validate(ctx); err != nil {
137136
return nil, err
138137
}
139138

@@ -154,7 +153,7 @@ type Client interface {
154153
Restore(context.Context, *api.Secret) error
155154
GetTokenSecret() *api.Secret
156155
CheckExpiry(int64) (bool, error)
157-
Validate() error
156+
Validate(ctx context.Context) error
158157
GetVaultAuthObj() *secretsv1beta1.VaultAuth
159158
GetVaultConnectionObj() *secretsv1beta1.VaultConnection
160159
GetCredentialProvider() provider.CredentialProviderBase
@@ -184,7 +183,7 @@ type defaultClient struct {
184183
inClosing bool
185184
closed bool
186185
lastWatcherErr error
187-
watcherDoneCh chan<- Client
186+
watcherDoneCh chan<- *ClientCallbackHandlerRequest
188187
tainted bool
189188
once sync.Once
190189
mu sync.RWMutex
@@ -220,7 +219,7 @@ func (c *defaultClient) Taint() {
220219
// Validate the client, returning an error for any validation failures.
221220
// Typically, an invalid Client would be discarded and replaced with a new
222221
// instance.
223-
func (c *defaultClient) Validate() error {
222+
func (c *defaultClient) Validate(ctx context.Context) error {
224223
c.mu.RLock()
225224
defer c.mu.RUnlock()
226225

@@ -245,6 +244,16 @@ func (c *defaultClient) Validate() error {
245244
return errors.New("client token expired")
246245
}
247246

247+
if c.client == nil {
248+
return errors.New("client not set")
249+
}
250+
251+
if c.tainted {
252+
if _, err := c.Read(ctx, NewReadRequest("auth/token/lookup-self", nil)); err != nil {
253+
return fmt.Errorf("tainted client is invalid: %w", err)
254+
}
255+
}
256+
248257
return nil
249258
}
250259

@@ -492,7 +501,10 @@ func (c *defaultClient) startLifetimeWatcher(ctx context.Context) error {
492501
if c.watcherDoneCh != nil {
493502
if !c.inClosing {
494503
logger.V(consts.LogLevelTrace).Info("Writing to watcherDone channel")
495-
c.watcherDoneCh <- c
504+
c.watcherDoneCh <- &ClientCallbackHandlerRequest{
505+
Client: c,
506+
On: ClientCallbackOnLifetimeWatcherDone,
507+
}
496508
} else {
497509
logger.V(consts.LogLevelTrace).Info("In closing, not writing to watcherDone channel")
498510
}
@@ -759,12 +771,22 @@ func (c *defaultClient) init(ctx context.Context, client ctrlclient.Client,
759771
}
760772

761773
func (c *defaultClient) observeTime(ts time.Time, operation string) {
774+
if c.connObj == nil {
775+
// should not happen on a properly initialized Client
776+
return
777+
}
778+
762779
clientOperationTimes.WithLabelValues(operation, ctrlclient.ObjectKeyFromObject(c.connObj).String()).Observe(
763780
time.Since(ts).Seconds(),
764781
)
765782
}
766783

767784
func (c *defaultClient) incrementOperationCounter(operation string, err error) {
785+
if c.connObj == nil {
786+
// should not happen on a properly initialized Client
787+
return
788+
}
789+
768790
vaultConn := ctrlclient.ObjectKeyFromObject(c.connObj).String()
769791
clientOperations.WithLabelValues(operation, vaultConn).Inc()
770792
if err != nil {

internal/vault/client_factory.go

Lines changed: 104 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -28,15 +28,34 @@ import (
2828
)
2929

3030
// ClientCallbackOn is an enumeration of possible client callback events.
31-
type ClientCallbackOn int
31+
type ClientCallbackOn uint32
3232

3333
const (
34+
NamePrefixVCC = "vso-cc-"
35+
3436
// ClientCallbackOnLifetimeWatcherDone is a ClientCallbackOn that handles client
3537
// lifetime watcher done events.
36-
ClientCallbackOnLifetimeWatcherDone ClientCallbackOn = iota
37-
NamePrefixVCC = "vso-cc-"
38+
ClientCallbackOnLifetimeWatcherDone ClientCallbackOn = 1 << iota
39+
// ClientCallbackOnCacheRemoval is a ClientCallbackOn that handles client cache removal events.
40+
ClientCallbackOnCacheRemoval
3841
)
3942

43+
func (o ClientCallbackOn) String() string {
44+
switch o {
45+
case ClientCallbackOnLifetimeWatcherDone:
46+
return "LifetimeWatcherDone"
47+
case ClientCallbackOnCacheRemoval:
48+
return "CacheRemoval"
49+
default:
50+
return "Unknown"
51+
}
52+
}
53+
54+
type ClientCallbackHandlerRequest struct {
55+
On ClientCallbackOn
56+
Client Client
57+
}
58+
4059
// ClientCallback is a function type that takes a context, a Client, and an error as parameters.
4160
// It is used in the context of a ClientCallbackHandler.
4261
type ClientCallback func(ctx context.Context, c Client)
@@ -99,7 +118,7 @@ type cachingClientFactory struct {
99118
pruneStorageOnEvict bool
100119
ctrlClient ctrlclient.Client
101120
clientCallbacks []ClientCallbackHandler
102-
callbackHandlerCh chan Client
121+
callbackHandlerCh chan *ClientCallbackHandlerRequest
103122
mu sync.RWMutex
104123
onceDoWatcher sync.Once
105124
callbackHandlerCancel context.CancelFunc
@@ -182,7 +201,10 @@ func (m *cachingClientFactory) prune(ctx context.Context, client ctrlclient.Clie
182201
if !skipCallbacks {
183202
for _, c := range pruned {
184203
// the callback handler will remove the client from the storage
185-
m.callbackHandlerCh <- c
204+
m.callbackHandlerCh <- &ClientCallbackHandlerRequest{
205+
On: ClientCallbackOnCacheRemoval,
206+
Client: c,
207+
}
186208
}
187209
} else {
188210
// for all cache entries pruned, remove the corresponding storage entries.
@@ -404,38 +426,31 @@ func (m *cachingClientFactory) Get(ctx context.Context, client ctrlclient.Client
404426
if ok {
405427
// return the Client from the cache if it is still Valid
406428
tainted = c.Tainted()
407-
logger.V(consts.LogLevelTrace).Info("Got client from cache", "clientID", c.ID(), "tainted", tainted)
408-
if tainted {
409-
// if the Client is tainted, we need to validate its token.
410-
if _, err := c.Read(ctx, NewReadRequest("auth/token/lookup-self", nil)); err == nil {
411-
defer c.Untaint()
412-
tainted = false
413-
if err := c.Validate(); err == nil {
414-
return namespacedClient(c)
415-
}
429+
logger.V(consts.LogLevelTrace).Info("Got client from cache",
430+
"clientID", c.ID(), "tainted", tainted)
431+
if err := c.Validate(ctx); err != nil {
432+
logger.V(consts.LogLevelDebug).Error(err, "Invalid client",
433+
"tainted", tainted)
434+
m.callbackHandlerCh <- &ClientCallbackHandlerRequest{
435+
On: ClientCallbackOnCacheRemoval,
436+
Client: c,
416437
}
417-
} else if err := c.Validate(); err == nil {
438+
} else {
439+
c.Untaint()
418440
return namespacedClient(c)
419441
}
420-
421-
logger.V(consts.LogLevelDebug).Error(err, "Invalid client",
422-
"tainted", tainted)
423-
424-
// remove the parent Client from the cache in order to prune any of its clones.
425-
m.cache.Remove(cacheKey)
426442
} else {
427443
logger.V(consts.LogLevelTrace).Info("Client not found in cache", "cacheKey", fmt.Sprintf("%#v", cacheKey))
428-
}
429-
430-
if !ok && m.storageEnabled() {
431-
// try and restore from Client storage cache, if properly configured to do so.
432-
restored, err := m.restoreClientFromCacheKey(ctx, client, cacheKey)
433-
if restored != nil {
434-
return namespacedClient(restored)
435-
}
444+
if m.storageEnabled() {
445+
// try and restore from Client storage cache, if properly configured to do so.
446+
restored, err := m.restoreClientFromCacheKey(ctx, client, cacheKey)
447+
if restored != nil {
448+
return namespacedClient(restored)
449+
}
436450

437-
if !IsStorageEntryNotFoundErr(err) {
438-
logger.Error(err, "Failed to restore client from storage")
451+
if !IsStorageEntryNotFoundErr(err) {
452+
logger.Error(err, "Failed to restore client from storage")
453+
}
439454
}
440455
}
441456

@@ -642,7 +657,7 @@ func (m *cachingClientFactory) storageEncryptionClient(ctx context.Context, clie
642657
// ensure that the cached Vault Client is not expired, and if it is then call storageEncryptionClient() again.
643658
// This operation should be safe since we are setting m.clientCacheKeyEncrypt to empty string,
644659
// so there should be no risk of causing a maximum recursion error.
645-
if reason := c.Validate(); reason != nil {
660+
if reason := c.Validate(ctx); reason != nil {
646661
m.logger.V(consts.LogLevelWarning).Info("Restored Vault client is invalid, recreating it",
647662
"cacheKey", m.clientCacheKeyEncrypt, "reason", reason)
648663

@@ -677,7 +692,7 @@ func (m *cachingClientFactory) startClientCallbackHandler(ctx context.Context) {
677692

678693
go func() {
679694
if m.callbackHandlerCh == nil {
680-
m.callbackHandlerCh = make(chan Client)
695+
m.callbackHandlerCh = make(chan *ClientCallbackHandlerRequest)
681696
}
682697
defer func() {
683698
close(m.callbackHandlerCh)
@@ -689,16 +704,20 @@ func (m *cachingClientFactory) startClientCallbackHandler(ctx context.Context) {
689704
case <-callbackCtx.Done():
690705
logger.Info("Client callback handler done")
691706
return
692-
case c, stillOpen := <-m.callbackHandlerCh:
707+
case req, stillOpen := <-m.callbackHandlerCh:
693708
if !stillOpen {
694709
logger.Info("Client callback handler channel closed")
695710
return
696711
}
697-
if c.IsClone() {
712+
if req == nil {
698713
continue
699714
}
700715

701-
cacheKey, err := c.GetCacheKey()
716+
if req.Client.IsClone() {
717+
continue
718+
}
719+
720+
cacheKey, err := req.Client.GetCacheKey()
702721
if err != nil {
703722
logger.Error(err, "Invalid client, client callbacks not executed",
704723
"cacheKey", cacheKey)
@@ -708,30 +727,62 @@ func (m *cachingClientFactory) startClientCallbackHandler(ctx context.Context) {
708727
// remove the client from the cache, it will be recreated when a reconciler
709728
// requests it.
710729
logger.V(consts.LogLevelDebug).Info("Removing client from cache", "cacheKey", cacheKey)
711-
m.cache.Remove(cacheKey)
712-
if m.storageEnabled() {
713-
if _, err := m.pruneStorage(ctx, m.ctrlClient, cacheKey); err != nil {
714-
logger.Info("Warning: failed to prune storage", "cacheKey", cacheKey)
730+
if req.On&ClientCallbackOnLifetimeWatcherDone != 0 {
731+
m.cache.Remove(cacheKey)
732+
if m.storageEnabled() {
733+
if _, err := m.pruneStorage(ctx, m.ctrlClient, cacheKey); err != nil {
734+
logger.Info("Warning: failed to prune storage", "cacheKey", cacheKey)
735+
}
715736
}
716737
}
717738

718-
for idx, cbReq := range m.clientCallbacks {
719-
if cbReq.On != ClientCallbackOnLifetimeWatcherDone {
720-
continue
721-
}
722-
723-
logger.Info("Calling client callback on lifetime watcher done",
724-
"index", idx, "cacheKey", cacheKey, "clientID", c.ID())
725-
// call in a go routine to avoid blocking the channel
726-
go func(cbReq ClientCallbackHandler) {
727-
cbReq.Callback(ctx, c)
728-
}(cbReq)
729-
}
739+
m.callClientCallbacks(ctx, req.Client, req.On, false)
730740
}
731741
}
732742
}()
733743
}
734744

745+
// callClientCallbacks calls all registered client callbacks for the specified
746+
// event. If wait is true, it will block until all callbacks have been executed.
747+
// Note: wait is only for testing purposes.
748+
func (m *cachingClientFactory) callClientCallbacks(ctx context.Context, c Client, on ClientCallbackOn, wait bool) {
749+
logger := log.FromContext(ctx).WithName("callClientCallbacks")
750+
751+
var cbs []ClientCallbackHandler
752+
for _, cbReq := range m.clientCallbacks {
753+
x := on & cbReq.On
754+
if x != 0 {
755+
cbs = append(cbs, cbReq)
756+
continue
757+
}
758+
}
759+
760+
if len(cbs) == 0 {
761+
return
762+
}
763+
764+
var wg sync.WaitGroup
765+
if wait {
766+
wg.Add(len(cbs))
767+
}
768+
769+
for idx, cbReq := range cbs {
770+
logger.Info("Calling client callback",
771+
"index", idx, "clientID", c.ID(), "on", on)
772+
// call in a go routine to avoid blocking the channel
773+
go func(cbReq ClientCallbackHandler) {
774+
if wait {
775+
defer wg.Done()
776+
}
777+
cbReq.Callback(ctx, c)
778+
}(cbReq)
779+
}
780+
781+
if wait {
782+
wg.Wait()
783+
}
784+
}
785+
735786
// NewCachingClientFactory returns a CachingClientFactory with ClientCache initialized.
736787
// The ClientCache's onEvictCallback is registered with the factory's onClientEvict(),
737788
// to ensure any evictions are handled by the factory (this is very important).
@@ -741,7 +792,7 @@ func NewCachingClientFactory(ctx context.Context, client ctrlclient.Client, cach
741792
recorder: config.Recorder,
742793
persist: config.Persist,
743794
ctrlClient: client,
744-
callbackHandlerCh: make(chan Client),
795+
callbackHandlerCh: make(chan *ClientCallbackHandlerRequest),
745796
encryptionRequired: config.StorageConfig.EnforceEncryption,
746797
clientLocks: make(map[ClientCacheKey]*sync.RWMutex, config.ClientCacheSize),
747798
logger: zap.New().WithName("clientCacheFactory").WithValues(

0 commit comments

Comments
 (0)