1
1
use crate :: config:: get_config;
2
2
use crate :: errors:: Error ;
3
3
use arc_swap:: ArcSwap ;
4
- use log:: { debug, error, info} ;
4
+ use log:: { debug, error, info, warn } ;
5
5
use once_cell:: sync:: Lazy ;
6
6
use std:: collections:: { HashMap , HashSet } ;
7
7
use std:: io;
8
8
use std:: net:: IpAddr ;
9
9
use std:: sync:: Arc ;
10
10
use std:: sync:: RwLock ;
11
11
use tokio:: time:: { sleep, Duration } ;
12
- use trust_dns_resolver:: error:: ResolveResult ;
12
+ use trust_dns_resolver:: error:: { ResolveError , ResolveResult } ;
13
13
use trust_dns_resolver:: lookup_ip:: LookupIp ;
14
14
use trust_dns_resolver:: TokioAsyncResolver ;
15
15
16
16
/// Cached Resolver Globally available
17
- pub static CACHED_RESOLVER : Lazy < ArcSwap < Option < ArcSwap < CachedResolver > > > > =
18
- Lazy :: new ( || ArcSwap :: from_pointee ( None ) ) ;
17
+ pub static CACHED_RESOLVER : Lazy < ArcSwap < CachedResolver > > =
18
+ Lazy :: new ( || ArcSwap :: from_pointee ( CachedResolver :: default ( ) ) ) ;
19
19
20
20
// Ip addressed are returned as a set of addresses
21
21
// so we can compare.
@@ -57,8 +57,8 @@ impl From<LookupIp> for AddrSet {
57
57
/// use pgcat::dns_cache::{CachedResolverConfig, CachedResolver};
58
58
///
59
59
/// # tokio_test::block_on(async {
60
- /// let config = CachedResolverConfig{dns_max_ttl: 10} ;
61
- /// let resolver = CachedResolver::new(config).await.unwrap();
60
+ /// let config = CachedResolverConfig::default() ;
61
+ /// let resolver = CachedResolver::new(config, None ).await.unwrap();
62
62
/// let addrset = resolver.lookup_ip("www.example.com.").await.unwrap();
63
63
/// # })
64
64
/// ```
@@ -70,23 +70,45 @@ impl From<LookupIp> for AddrSet {
70
70
/// // You can now check if an 'old' lookup differs from what it's currently
71
71
/// // store in cache by using `has_changed`.
72
72
/// resolver.has_changed("www.example.com.", addrset)
73
+ #[ derive( Default ) ]
73
74
pub struct CachedResolver {
74
75
// The configuration of the cached_resolver.
75
76
config : CachedResolverConfig ,
76
77
77
78
// This is the hash that contains the hash.
78
- data : Arc < RwLock < HashMap < String , AddrSet > > > ,
79
+ data : Option < RwLock < HashMap < String , AddrSet > > > ,
79
80
80
81
// The resolver to be used for DNS queries.
81
- resolver : Arc < TokioAsyncResolver > ,
82
+ resolver : Option < TokioAsyncResolver > ,
83
+
84
+ // The RefreshLoop
85
+ refresh_loop : RwLock < Option < tokio:: task:: JoinHandle < ( ) > > > ,
82
86
}
83
87
84
88
///
85
89
/// Configuration
86
- #[ derive( Clone , Debug ) ]
90
+ #[ derive( Clone , Debug , Default , PartialEq ) ]
87
91
pub struct CachedResolverConfig {
88
92
/// Amount of time in secods that a resolved dns address is considered stale.
89
- pub dns_max_ttl : u64 ,
93
+ dns_max_ttl : u64 ,
94
+
95
+ /// Enabled or disabled? (this is so we can reload config)
96
+ enabled : bool ,
97
+ }
98
+
99
+ impl CachedResolverConfig {
100
+ fn new ( dns_max_ttl : u64 , enabled : bool ) -> Self {
101
+ CachedResolverConfig {
102
+ dns_max_ttl,
103
+ enabled,
104
+ }
105
+ }
106
+ }
107
+
108
+ impl From < crate :: config:: Config > for CachedResolverConfig {
109
+ fn from ( config : crate :: config:: Config ) -> Self {
110
+ CachedResolverConfig :: new ( config. general . dns_max_ttl , config. general . dns_cache_enabled )
111
+ }
90
112
}
91
113
92
114
impl CachedResolver {
@@ -104,29 +126,47 @@ impl CachedResolver {
104
126
/// use pgcat::dns_cache::{CachedResolverConfig, CachedResolver};
105
127
///
106
128
/// # tokio_test::block_on(async {
107
- /// let config = CachedResolverConfig{dns_max_ttl: 10} ;
108
- /// let resolver = CachedResolver::new(config);
129
+ /// let config = CachedResolverConfig::default() ;
130
+ /// let resolver = CachedResolver::new(config, None).await.unwrap( );
109
131
/// # })
110
132
/// ```
111
133
///
112
- pub async fn new ( config : CachedResolverConfig ) -> io:: Result < Arc < Self > > {
134
+ pub async fn new (
135
+ config : CachedResolverConfig ,
136
+ data : Option < HashMap < String , AddrSet > > ,
137
+ ) -> Result < Arc < Self > , io:: Error > {
113
138
// Construct a new Resolver with default configuration options
114
- let resolver = Arc :: new ( TokioAsyncResolver :: tokio_from_system_conf ( ) ?) ;
115
- let data = Arc :: new ( RwLock :: new ( HashMap :: new ( ) ) ) ;
139
+ let resolver = Some ( TokioAsyncResolver :: tokio_from_system_conf ( ) ?) ;
140
+
141
+ let data = if let Some ( hash) = data {
142
+ Some ( RwLock :: new ( hash) )
143
+ } else {
144
+ Some ( RwLock :: new ( HashMap :: new ( ) ) )
145
+ } ;
116
146
117
- let self_ref = Arc :: new ( Self {
147
+ let instance = Arc :: new ( Self {
118
148
config,
119
149
resolver,
120
150
data,
151
+ refresh_loop : RwLock :: new ( None ) ,
121
152
} ) ;
122
- let clone_self_ref = self_ref. clone ( ) ;
123
153
124
- info ! ( "Scheduling DNS refresh loop" ) ;
125
- tokio:: task:: spawn ( async move {
126
- clone_self_ref. refresh_dns_entries_loop ( ) . await ;
127
- } ) ;
154
+ if instance. enabled ( ) {
155
+ info ! ( "Scheduling DNS refresh loop" ) ;
156
+ let refresh_loop = tokio:: task:: spawn ( {
157
+ let instance = instance. clone ( ) ;
158
+ async move {
159
+ instance. refresh_dns_entries_loop ( ) . await ;
160
+ }
161
+ } ) ;
162
+ * ( instance. refresh_loop . write ( ) . unwrap ( ) ) = Some ( refresh_loop) ;
163
+ }
128
164
129
- Ok ( self_ref)
165
+ Ok ( instance)
166
+ }
167
+
168
+ pub fn enabled ( & self ) -> bool {
169
+ self . config . enabled
130
170
}
131
171
132
172
// Schedules the refresher
@@ -139,8 +179,10 @@ impl CachedResolver {
139
179
// an array with keys.
140
180
let mut hostnames: Vec < String > = Vec :: new ( ) ;
141
181
{
142
- for hostname in self . data . read ( ) . unwrap ( ) . keys ( ) {
143
- hostnames. push ( hostname. clone ( ) ) ;
182
+ if let Some ( ref data) = self . data {
183
+ for hostname in data. read ( ) . unwrap ( ) . keys ( ) {
184
+ hostnames. push ( hostname. clone ( ) ) ;
185
+ }
144
186
}
145
187
}
146
188
@@ -193,8 +235,8 @@ impl CachedResolver {
193
235
/// use pgcat::dns_cache::{CachedResolverConfig, CachedResolver};
194
236
///
195
237
/// # tokio_test::block_on(async {
196
- /// let config = CachedResolverConfig { dns_max_ttl: 10 } ;
197
- /// let resolver = CachedResolver::new(config).await.unwrap();
238
+ /// let config = CachedResolverConfig::default() ;
239
+ /// let resolver = CachedResolver::new(config, None ).await.unwrap();
198
240
/// let response = resolver.lookup_ip("www.google.com.");
199
241
/// # })
200
242
/// ```
@@ -208,10 +250,14 @@ impl CachedResolver {
208
250
}
209
251
None => {
210
252
debug ! ( "Not found, executing a dns query!" ) ;
211
- let addr_set = AddrSet :: from ( self . resolver . lookup_ip ( host) . await ?) ;
212
- debug ! ( "Obtained: {:?}" , addr_set) ;
213
- self . store_in_cache ( host, addr_set. clone ( ) ) ;
214
- Ok ( addr_set)
253
+ if let Some ( ref resolver) = self . resolver {
254
+ let addr_set = AddrSet :: from ( resolver. lookup_ip ( host) . await ?) ;
255
+ debug ! ( "Obtained: {:?}" , addr_set) ;
256
+ self . store_in_cache ( host, addr_set. clone ( ) ) ;
257
+ Ok ( addr_set)
258
+ } else {
259
+ Err ( ResolveError :: from ( "No resolver available" ) )
260
+ }
215
261
}
216
262
}
217
263
}
@@ -227,72 +273,89 @@ impl CachedResolver {
227
273
228
274
// Fetches an AddrSet from the inner cache adquiring the read lock.
229
275
fn fetch_from_cache ( & self , key : & str ) -> Option < AddrSet > {
230
- let hash = & self . data . read ( ) . unwrap ( ) ;
231
- if let Some ( addr_set) = hash. get ( key) {
232
- return Some ( addr_set. clone ( ) ) ;
276
+ if let Some ( ref hash) = self . data {
277
+ if let Some ( addr_set) = hash. read ( ) . unwrap ( ) . get ( key) {
278
+ return Some ( addr_set. clone ( ) ) ;
279
+ }
233
280
}
234
281
None
235
282
}
236
283
237
284
// Sets up the global CACHED_RESOLVER static variable so we can globally use DNS
238
285
// cache.
239
286
pub async fn from_config ( ) -> Result < ( ) , Error > {
240
- let config = get_config ( ) ;
287
+ let cached_resolver = CACHED_RESOLVER . load ( ) ;
288
+ let desired_config = CachedResolverConfig :: from ( get_config ( ) ) ;
241
289
242
- // Configure dns_cache if enabled
243
- if config. general . dns_cache_enabled {
244
- info ! ( "Starting Dns cache" ) ;
245
- let cached_resolver_config = CachedResolverConfig {
246
- dns_max_ttl : config. general . dns_max_ttl ,
290
+ if cached_resolver. config != desired_config {
291
+ if let Some ( ref refresh_loop) = * ( cached_resolver. refresh_loop . write ( ) . unwrap ( ) ) {
292
+ warn ! ( "Killing Dnscache refresh loop as its configuration is being reloaded" ) ;
293
+ refresh_loop. abort ( )
294
+ }
295
+ let new_resolver = if let Some ( ref data) = cached_resolver. data {
296
+ let data = Some ( data. read ( ) . unwrap ( ) . clone ( ) ) ;
297
+ CachedResolver :: new ( desired_config, data) . await
298
+ } else {
299
+ CachedResolver :: new ( desired_config, None ) . await
247
300
} ;
248
- return match CachedResolver :: new ( cached_resolver_config) . await {
301
+
302
+ match new_resolver {
249
303
Ok ( ok) => {
250
- let value = Some ( ArcSwap :: from ( ok) ) ;
251
- CACHED_RESOLVER . store ( Arc :: new ( value) ) ;
304
+ CACHED_RESOLVER . store ( ok) ;
252
305
Ok ( ( ) )
253
306
}
254
307
Err ( err) => {
255
- let message = format ! ( "Error Starting cached_resolver error : {:?}, will continue without this feature." , err) ;
308
+ let message = format ! ( "Error setting up cached_resolver. Error : {:?}, will continue without this feature." , err) ;
256
309
Err ( Error :: DNSCachedError ( message) )
257
310
}
258
- } ;
311
+ }
312
+ } else {
313
+ Ok ( ( ) )
259
314
}
260
- Ok ( ( ) )
261
315
}
262
316
263
317
// Stores the AddrSet in cache adquiring the write lock.
264
318
fn store_in_cache ( & self , host : & str , addr_set : AddrSet ) {
265
- self . data
266
- . write ( )
267
- . unwrap ( )
268
- . insert ( host. to_string ( ) , addr_set) ;
319
+ if let Some ( ref data) = self . data {
320
+ data. write ( ) . unwrap ( ) . insert ( host. to_string ( ) , addr_set) ;
321
+ } else {
322
+ error ! ( "Could not insert, Hash not initialized" ) ;
323
+ }
269
324
}
270
325
}
271
-
272
326
#[ cfg( test) ]
273
327
mod tests {
274
328
use super :: * ;
275
329
use trust_dns_resolver:: error:: ResolveError ;
276
330
277
331
#[ tokio:: test]
278
332
async fn new ( ) {
279
- let config = CachedResolverConfig { dns_max_ttl : 10 } ;
280
- let resolver = CachedResolver :: new ( config) . await ;
333
+ let config = CachedResolverConfig {
334
+ dns_max_ttl : 10 ,
335
+ enabled : true ,
336
+ } ;
337
+ let resolver = CachedResolver :: new ( config, None ) . await ;
281
338
assert ! ( resolver. is_ok( ) ) ;
282
339
}
283
340
284
341
#[ tokio:: test]
285
342
async fn lookup_ip ( ) {
286
- let config = CachedResolverConfig { dns_max_ttl : 10 } ;
287
- let resolver = CachedResolver :: new ( config) . await . unwrap ( ) ;
343
+ let config = CachedResolverConfig {
344
+ dns_max_ttl : 10 ,
345
+ enabled : true ,
346
+ } ;
347
+ let resolver = CachedResolver :: new ( config, None ) . await . unwrap ( ) ;
288
348
let response = resolver. lookup_ip ( "www.google.com." ) . await ;
289
349
assert ! ( response. is_ok( ) ) ;
290
350
}
291
351
292
352
#[ tokio:: test]
293
353
async fn has_changed ( ) {
294
- let config = CachedResolverConfig { dns_max_ttl : 10 } ;
295
- let resolver = CachedResolver :: new ( config) . await . unwrap ( ) ;
354
+ let config = CachedResolverConfig {
355
+ dns_max_ttl : 10 ,
356
+ enabled : true ,
357
+ } ;
358
+ let resolver = CachedResolver :: new ( config, None ) . await . unwrap ( ) ;
296
359
let hostname = "www.google.com." ;
297
360
let response = resolver. lookup_ip ( hostname) . await ;
298
361
let addr_set = response. unwrap ( ) ;
@@ -301,17 +364,23 @@ mod tests {
301
364
302
365
#[ tokio:: test]
303
366
async fn unknown_host ( ) {
304
- let config = CachedResolverConfig { dns_max_ttl : 10 } ;
305
- let resolver = CachedResolver :: new ( config) . await . unwrap ( ) ;
367
+ let config = CachedResolverConfig {
368
+ dns_max_ttl : 10 ,
369
+ enabled : true ,
370
+ } ;
371
+ let resolver = CachedResolver :: new ( config, None ) . await . unwrap ( ) ;
306
372
let hostname = "www.idontexists." ;
307
373
let response = resolver. lookup_ip ( hostname) . await ;
308
374
assert ! ( matches!( response, Err ( ResolveError { .. } ) ) ) ;
309
375
}
310
376
311
377
#[ tokio:: test]
312
378
async fn incorrect_address ( ) {
313
- let config = CachedResolverConfig { dns_max_ttl : 10 } ;
314
- let resolver = CachedResolver :: new ( config) . await . unwrap ( ) ;
379
+ let config = CachedResolverConfig {
380
+ dns_max_ttl : 10 ,
381
+ enabled : true ,
382
+ } ;
383
+ let resolver = CachedResolver :: new ( config, None ) . await . unwrap ( ) ;
315
384
let hostname = "w ww.idontexists." ;
316
385
let response = resolver. lookup_ip ( hostname) . await ;
317
386
assert ! ( matches!( response, Err ( ResolveError { .. } ) ) ) ;
@@ -323,9 +392,11 @@ mod tests {
323
392
// and does not responds with every available ip everytime, so
324
393
// if I cache here, it will miss after one cache iteration or two.
325
394
async fn thread ( ) {
326
- env_logger:: init ( ) ;
327
- let config = CachedResolverConfig { dns_max_ttl : 10 } ;
328
- let resolver = CachedResolver :: new ( config) . await . unwrap ( ) ;
395
+ let config = CachedResolverConfig {
396
+ dns_max_ttl : 10 ,
397
+ enabled : true ,
398
+ } ;
399
+ let resolver = CachedResolver :: new ( config, None ) . await . unwrap ( ) ;
329
400
let hostname = "www.google.com." ;
330
401
let response = resolver. lookup_ip ( hostname) . await ;
331
402
let addr_set = response. unwrap ( ) ;
0 commit comments