Skip to content

Commit fe08bed

Browse files
committed
feat(distributed): acquire and auto-refresh worker NATS credentials
Workers fetched NATS credentials once at startup, which broke two cases under JWT auth: a worker that registered while still pending admin approval never received a minted JWT (it connected unauthenticated and gave up), and a long-running worker's 24h JWT expired with no way to renew it. Introduce workerregistry.NATSCredentialManager, built on idempotent re-registration (the frontend preserves the node row and mints a fresh JWT each call): - Acquire re-registers through admin approval until the node is approved and credentials are minted (or returns the first success when auth is not required, preserving anonymous-NATS behavior). - RefreshLoop re-registers before the JWT expires (~75% of its lifetime), updating the credentials served to the connection. - Both are bounded (default 100 attempts / consecutive failures) and return an error on exhaustion, so an unapprovable or unrenewable worker exits non-zero and surfaces the problem instead of hanging or drifting toward an expired credential. The messaging client gains WithUserJWTProvider, fetching credentials on each (re)connect so the connection transparently adopts a refreshed JWT when the server expires the old one. RegisterFull exposes the approval status and full response; Register delegates to it. Both the backend worker and the agent worker are wired to this: explicit env credentials are used as-is, minted credentials are acquired-with-wait and refreshed, and a permanent refresh failure shuts the worker down so it restarts and re-acquires. Tests cover Acquire (wait-through-pending, bounded give-up, context cancel), RefreshLoop (refresh-before-expiry, bounded failure, no-expiry exit) and jwtExpiry decoding. Docs updated in distributed-mode.md. Assisted-by: Claude:claude-opus-4-8 [Claude Code] Signed-off-by: Richard Palethorpe <io@richiejp.com>
1 parent 33f601d commit fe08bed

6 files changed

Lines changed: 542 additions & 35 deletions

File tree

core/cli/agent_worker.go

Lines changed: 52 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -90,15 +90,30 @@ func (cmd *AgentWorkerCMD) Run(ctx *cliContext.Context) error {
9090
registrationBody["token"] = cmd.RegistrationToken
9191
}
9292

93-
nodeID, apiToken, regNatsJWT, regNatsSeed, err := regClient.RegisterWithRetry(context.Background(), registrationBody, 10)
93+
// Context cancelled on shutdown — used by registration waits, heartbeat, and
94+
// other background goroutines.
95+
shutdownCtx, shutdownCancel := context.WithCancel(context.Background())
96+
defer shutdownCancel()
97+
98+
// Acquire credentials via (re)registration. When the bus requires auth and no
99+
// static fallback is configured, wait through admin approval until the
100+
// frontend mints credentials rather than starting unauthenticated.
101+
credMgr := workerregistry.NewNATSCredentialManager(
102+
func(ctx context.Context) (*workerregistry.RegisterResponse, error) {
103+
return regClient.RegisterFull(ctx, registrationBody)
104+
},
105+
cmd.NatsRequireAuth && cmd.NatsJWT == "" && cmd.NatsServiceJWT == "",
106+
)
107+
res, err := credMgr.Acquire(shutdownCtx)
94108
if err != nil {
95109
return fmt.Errorf("registration failed: %w", err)
96110
}
111+
nodeID := res.ID
97112
xlog.Info("Registered with frontend", "nodeID", nodeID, "frontend", cmd.RegisterTo)
98113

99114
// Use provisioned API token if none was set
100115
if cmd.APIToken == "" {
101-
cmd.APIToken = apiToken
116+
cmd.APIToken = res.APIToken
102117
}
103118

104119
// Start heartbeat
@@ -107,22 +122,35 @@ func (cmd *AgentWorkerCMD) Run(ctx *cliContext.Context) error {
107122
xlog.Warn("invalid heartbeat interval, using default 10s", "input", cmd.HeartbeatInterval, "error", err)
108123
}
109124
heartbeatInterval = cmp.Or(heartbeatInterval, 10*time.Second)
110-
// Context cancelled on shutdown — used by heartbeat and other background goroutines
111-
shutdownCtx, shutdownCancel := context.WithCancel(context.Background())
112-
defer shutdownCancel()
113125

114126
go regClient.HeartbeatLoop(shutdownCtx, nodeID, heartbeatInterval, func() map[string]any { return map[string]any{} })
115127

116-
// Connect to NATS
117-
natsJWT := cmp.Or(cmd.NatsJWT, regNatsJWT, cmd.NatsServiceJWT)
118-
natsSeed := cmp.Or(cmd.NatsUserSeed, regNatsSeed, cmd.NatsServiceSeed)
119-
if cmd.NatsRequireAuth && (natsJWT == "" || natsSeed == "") {
120-
return fmt.Errorf("NATS JWT+seed required: enable frontend minting or set LOCALAI_NATS_* env vars")
121-
}
128+
// Resolve NATS credentials with precedence: explicit env override, then
129+
// frontend-minted (auto-refreshed before expiry), then service fallback.
130+
// Each static source must supply JWT and seed together.
122131
natsTLS := messaging.TLSFiles{CA: cmd.NatsTLSCA, Cert: cmd.NatsTLSCert, Key: cmd.NatsTLSKey}
123132
var natsOpts []messaging.Option
124-
if natsJWT != "" && natsSeed != "" {
125-
natsOpts = append(natsOpts, messaging.WithUserJWT(natsJWT, natsSeed))
133+
switch {
134+
case cmd.NatsJWT != "" || cmd.NatsUserSeed != "":
135+
if (cmd.NatsJWT == "") != (cmd.NatsUserSeed == "") {
136+
return fmt.Errorf("LOCALAI_NATS_JWT and LOCALAI_NATS_USER_SEED must be set together")
137+
}
138+
natsOpts = append(natsOpts, messaging.WithUserJWT(cmd.NatsJWT, cmd.NatsUserSeed))
139+
case credMgr.HasCredentials():
140+
natsOpts = append(natsOpts, messaging.WithUserJWTProvider(credMgr.Provider()))
141+
go func() {
142+
if err := credMgr.RefreshLoop(shutdownCtx); err != nil {
143+
xlog.Error("NATS credential refresh permanently failed; shutting down agent worker", "error", err)
144+
shutdownCancel()
145+
}
146+
}()
147+
case cmd.NatsServiceJWT != "" || cmd.NatsServiceSeed != "":
148+
if (cmd.NatsServiceJWT == "") != (cmd.NatsServiceSeed == "") {
149+
return fmt.Errorf("LOCALAI_NATS_SERVICE_JWT and LOCALAI_NATS_SERVICE_SEED must be set together")
150+
}
151+
natsOpts = append(natsOpts, messaging.WithUserJWT(cmd.NatsServiceJWT, cmd.NatsServiceSeed))
152+
case cmd.NatsRequireAuth:
153+
return fmt.Errorf("NATS JWT+seed required: enable frontend minting or set LOCALAI_NATS_* env vars")
126154
}
127155
if natsTLS.Enabled() {
128156
natsOpts = append(natsOpts, messaging.WithTLS(natsTLS))
@@ -205,17 +233,25 @@ func (cmd *AgentWorkerCMD) Run(ctx *cliContext.Context) error {
205233

206234
xlog.Info("Agent worker ready, waiting for jobs", "subject", cmd.Subject, "queue", cmd.Queue)
207235

208-
// Wait for shutdown
236+
// Wait for an OS signal or an internal fatal condition (e.g. NATS
237+
// credentials became unrenewable), so the worker restarts and re-acquires
238+
// rather than lingering unable to serve.
209239
sigCh := make(chan os.Signal, 1)
210240
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
211-
<-sigCh
241+
var runErr error
242+
select {
243+
case <-sigCh:
244+
case <-shutdownCtx.Done():
245+
runErr = fmt.Errorf("agent worker shutting down: NATS credentials unavailable")
246+
xlog.Error("Internal shutdown requested", "error", runErr)
247+
}
212248

213249
xlog.Info("Shutting down agent worker")
214250
shutdownCancel() // stop heartbeat loop immediately
215251
dispatcher.Stop()
216252
mcpTools.CloseAllMCPSessions()
217253
regClient.GracefulDeregister(nodeID)
218-
return nil
254+
return runErr
219255
}
220256

221257
// handleMCPToolRequest handles a NATS request-reply for MCP tool execution.

core/cli/workerregistry/client.go

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -58,40 +58,53 @@ func (c *RegistrationClient) setAuth(req *http.Request) {
5858

5959
// RegisterResponse is the JSON body returned by /api/node/register.
6060
type RegisterResponse struct {
61-
ID string `json:"id"`
61+
ID string `json:"id"`
62+
Status string `json:"status,omitempty"` // "pending" until an admin approves the node
6263
APIToken string `json:"api_token,omitempty"`
6364
NatsJWT string `json:"nats_jwt,omitempty"`
6465
NatsUserSeed string `json:"nats_user_seed,omitempty"`
6566
}
6667

67-
// Register sends a single registration request and returns the node ID and
68-
// optional credentials (API token for agent workers, NATS JWT when configured).
69-
func (c *RegistrationClient) Register(ctx context.Context, body map[string]any) (nodeID, apiToken, natsJWT, natsSeed string, err error) {
68+
// RegisterFull sends a single registration request and returns the full
69+
// response (node ID, approval status, and optional API token / NATS creds).
70+
// Re-registration is idempotent: the frontend preserves the node row and mints
71+
// a fresh NATS JWT each call, so this doubles as the credential-refresh call.
72+
func (c *RegistrationClient) RegisterFull(ctx context.Context, body map[string]any) (*RegisterResponse, error) {
7073
jsonBody, _ := json.Marshal(body)
7174
url := c.baseURL() + "/api/node/register"
7275

7376
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(jsonBody))
7477
if err != nil {
75-
return "", "", "", "", fmt.Errorf("creating request: %w", err)
78+
return nil, fmt.Errorf("creating request: %w", err)
7679
}
7780
req.Header.Set("Content-Type", "application/json")
7881
c.setAuth(req)
7982

8083
resp, err := c.httpClient().Do(req)
8184
if err != nil {
82-
return "", "", "", "", fmt.Errorf("posting to %s: %w", url, err)
85+
return nil, fmt.Errorf("posting to %s: %w", url, err)
8386
}
8487
defer resp.Body.Close()
8588

8689
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
87-
return "", "", "", "", fmt.Errorf("registration failed with status %d", resp.StatusCode)
90+
return nil, fmt.Errorf("registration failed with status %d", resp.StatusCode)
8891
}
8992

9093
var result RegisterResponse
9194
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
92-
return "", "", "", "", fmt.Errorf("decoding response: %w", err)
95+
return nil, fmt.Errorf("decoding response: %w", err)
96+
}
97+
return &result, nil
98+
}
99+
100+
// Register sends a single registration request and returns the node ID and
101+
// optional credentials (API token for agent workers, NATS JWT when configured).
102+
func (c *RegistrationClient) Register(ctx context.Context, body map[string]any) (nodeID, apiToken, natsJWT, natsSeed string, err error) {
103+
res, err := c.RegisterFull(ctx, body)
104+
if err != nil {
105+
return "", "", "", "", err
93106
}
94-
return result.ID, result.APIToken, result.NatsJWT, result.NatsUserSeed, nil
107+
return res.ID, res.APIToken, res.NatsJWT, res.NatsUserSeed, nil
95108
}
96109

97110
// RegisterWithRetry retries registration with exponential backoff.
Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,200 @@
1+
package workerregistry
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"sync"
7+
"time"
8+
9+
"github.com/mudler/LocalAI/pkg/natsauth"
10+
"github.com/mudler/xlog"
11+
)
12+
13+
// statusPending mirrors nodes.StatusPending. It is duplicated rather than
14+
// imported so the lightweight registration client does not pull in the nodes
15+
// package (and its gorm/DB dependencies).
16+
const statusPending = "pending"
17+
18+
// defaultMaxAttempts bounds how many times Acquire registers (and how many
19+
// consecutive times RefreshLoop may fail) before giving up. It is high enough
20+
// to ride out a slow admin approval or a transient frontend outage, but finite
21+
// so an unauthorized/unapprovable worker exits and surfaces the problem (via a
22+
// non-zero exit and the resulting restart) rather than waiting forever.
23+
const defaultMaxAttempts = 100
24+
25+
// RegisterFunc performs one idempotent registration round-trip.
26+
type RegisterFunc func(ctx context.Context) (*RegisterResponse, error)
27+
28+
// NATSCredentialManager acquires NATS credentials at startup — waiting through
29+
// admin approval when required — and refreshes them before the minted JWT
30+
// expires, by re-registering (which mints a fresh JWT). The live NATS
31+
// connection adopts a refreshed JWT on its next reconnect via Provider. Safe
32+
// for concurrent use.
33+
//
34+
// It addresses two failure modes: a worker that needs credentials but registers
35+
// while still pending approval (it would otherwise give up and never connect),
36+
// and a long-running worker whose 24h JWT expires with no way to renew it.
37+
type NATSCredentialManager struct {
38+
register RegisterFunc
39+
requireCreds bool // block until credentials are present (frontend minting in use)
40+
41+
// Tunables; defaults set by NewNATSCredentialManager, overridable in tests.
42+
initialBackoff time.Duration
43+
maxBackoff time.Duration
44+
maxAttempts int // bound on Acquire attempts / consecutive refresh failures (<=0 = unlimited)
45+
refreshLead float64 // refresh once this fraction of the JWT lifetime has elapsed
46+
refreshRetry time.Duration
47+
expiryOf func(jwt string) (time.Time, bool)
48+
49+
mu sync.RWMutex
50+
jwt string
51+
seed string
52+
nodeID string
53+
}
54+
55+
// NewNATSCredentialManager builds a manager over register. When requireCreds is
56+
// true, Acquire blocks until the node is approved and credentials are minted.
57+
func NewNATSCredentialManager(register RegisterFunc, requireCreds bool) *NATSCredentialManager {
58+
return &NATSCredentialManager{
59+
register: register,
60+
requireCreds: requireCreds,
61+
initialBackoff: 2 * time.Second,
62+
maxBackoff: 30 * time.Second,
63+
maxAttempts: defaultMaxAttempts,
64+
refreshLead: 0.75,
65+
refreshRetry: 30 * time.Second,
66+
expiryOf: jwtExpiry,
67+
}
68+
}
69+
70+
// jwtExpiry decodes the expiry of a minted user JWT. ok is false when the token
71+
// is empty/undecodable or carries no expiry (e.g. a non-expiring service JWT).
72+
func jwtExpiry(token string) (time.Time, bool) {
73+
if token == "" {
74+
return time.Time{}, false
75+
}
76+
uc, err := natsauth.DecodeUserClaims(token)
77+
if err != nil || uc.Expires == 0 {
78+
return time.Time{}, false
79+
}
80+
return time.Unix(uc.Expires, 0), true
81+
}
82+
83+
func (m *NATSCredentialManager) store(res *RegisterResponse) {
84+
m.mu.Lock()
85+
defer m.mu.Unlock()
86+
m.nodeID = res.ID
87+
if res.NatsJWT != "" && res.NatsUserSeed != "" {
88+
m.jwt, m.seed = res.NatsJWT, res.NatsUserSeed
89+
}
90+
}
91+
92+
// Current returns the latest NATS credentials (both empty until acquired).
93+
func (m *NATSCredentialManager) Current() (jwt, seed string) {
94+
m.mu.RLock()
95+
defer m.mu.RUnlock()
96+
return m.jwt, m.seed
97+
}
98+
99+
// NodeID returns the node ID from the most recent registration.
100+
func (m *NATSCredentialManager) NodeID() string {
101+
m.mu.RLock()
102+
defer m.mu.RUnlock()
103+
return m.nodeID
104+
}
105+
106+
// Provider returns a callback compatible with messaging.WithUserJWTProvider,
107+
// supplying the current credentials on each (re)connect.
108+
func (m *NATSCredentialManager) Provider() func() (string, string) {
109+
return m.Current
110+
}
111+
112+
// HasCredentials reports whether complete NATS credentials have been obtained.
113+
func (m *NATSCredentialManager) HasCredentials() bool {
114+
jwt, seed := m.Current()
115+
return jwt != "" && seed != ""
116+
}
117+
118+
// Acquire registers and, when requireCreds is set, keeps re-registering with
119+
// exponential backoff until the node is approved (status != pending) and
120+
// credentials are minted. Without requireCreds it returns the first successful
121+
// response (the historical one-shot behavior, preserved for anonymous NATS).
122+
func (m *NATSCredentialManager) Acquire(ctx context.Context) (*RegisterResponse, error) {
123+
backoff := m.initialBackoff
124+
var lastReason error
125+
for attempt := 1; m.maxAttempts <= 0 || attempt <= m.maxAttempts; attempt++ {
126+
res, err := m.register(ctx)
127+
switch {
128+
case err != nil:
129+
lastReason = err
130+
xlog.Warn("Registration failed, retrying", "attempt", attempt, "next_retry", backoff, "error", err)
131+
case !m.requireCreds:
132+
m.store(res)
133+
return res, nil
134+
case res.Status == statusPending:
135+
lastReason = fmt.Errorf("node %s still pending admin approval", res.ID)
136+
xlog.Info("Node pending admin approval; waiting", "node", res.ID, "attempt", attempt, "next_retry", backoff)
137+
case res.NatsJWT == "" || res.NatsUserSeed == "":
138+
lastReason = fmt.Errorf("node %s approved but NATS credentials not minted", res.ID)
139+
xlog.Info("Node approved but NATS credentials not yet minted; waiting", "node", res.ID, "attempt", attempt, "next_retry", backoff)
140+
default:
141+
m.store(res)
142+
return res, nil
143+
}
144+
select {
145+
case <-ctx.Done():
146+
return nil, ctx.Err()
147+
case <-time.After(backoff):
148+
}
149+
backoff = min(backoff*2, m.maxBackoff)
150+
}
151+
return nil, fmt.Errorf("giving up acquiring NATS credentials after %d attempts: %w", m.maxAttempts, lastReason)
152+
}
153+
154+
// RefreshLoop re-registers to mint a fresh JWT before the current one expires,
155+
// updating the credentials returned by Current/Provider so the NATS connection
156+
// adopts them on its next reconnect. It returns nil when ctx is cancelled or
157+
// when the current credential has no expiry (nothing to refresh), and a non-nil
158+
// error after maxAttempts consecutive refresh failures — letting the caller
159+
// exit the worker so it restarts and re-acquires (or surfaces the outage)
160+
// rather than silently drifting toward an expired, unrenewable JWT.
161+
func (m *NATSCredentialManager) RefreshLoop(ctx context.Context) error {
162+
failures := 0
163+
for {
164+
jwt, _ := m.Current()
165+
exp, ok := m.expiryOf(jwt)
166+
if !ok {
167+
xlog.Debug("NATS credential has no expiry; refresh loop exiting")
168+
return nil
169+
}
170+
wait := max(time.Duration(float64(time.Until(exp))*m.refreshLead), 0)
171+
select {
172+
case <-ctx.Done():
173+
return nil
174+
case <-time.After(wait):
175+
}
176+
177+
res, err := m.register(ctx)
178+
if err == nil && res.NatsJWT != "" && res.NatsUserSeed != "" {
179+
m.store(res)
180+
failures = 0
181+
xlog.Info("Refreshed NATS credentials", "node", res.ID)
182+
continue
183+
}
184+
failures++
185+
if err != nil {
186+
xlog.Warn("NATS credential refresh failed; will retry", "attempt", failures, "error", err)
187+
} else {
188+
xlog.Warn("NATS credential refresh returned no credentials; will retry", "attempt", failures)
189+
}
190+
if m.maxAttempts > 0 && failures >= m.maxAttempts {
191+
return fmt.Errorf("NATS credential refresh failed %d times in a row", failures)
192+
}
193+
// Back off before retrying so a persistent failure near expiry does not spin.
194+
select {
195+
case <-ctx.Done():
196+
return nil
197+
case <-time.After(m.refreshRetry):
198+
}
199+
}
200+
}

0 commit comments

Comments
 (0)