1
1
use crate :: config:: get_config;
2
2
use crate :: errors:: Error ;
3
3
use arc_swap:: ArcSwap ;
4
- use log:: { debug, error, info} ;
4
+ use log:: { debug, error, info, warn } ;
5
5
use once_cell:: sync:: Lazy ;
6
6
use std:: collections:: { HashMap , HashSet } ;
7
7
use std:: io;
8
8
use std:: net:: IpAddr ;
9
9
use std:: sync:: Arc ;
10
10
use std:: sync:: RwLock ;
11
11
use tokio:: time:: { sleep, Duration } ;
12
- use trust_dns_resolver:: error:: ResolveResult ;
12
+ use trust_dns_resolver:: error:: { ResolveError , ResolveResult } ;
13
13
use trust_dns_resolver:: lookup_ip:: LookupIp ;
14
14
use trust_dns_resolver:: TokioAsyncResolver ;
15
15
16
16
/// Cached Resolver Globally available
17
- pub static CACHED_RESOLVER : Lazy < ArcSwap < Option < ArcSwap < CachedResolver > > > > =
18
- Lazy :: new ( || ArcSwap :: from_pointee ( None ) ) ;
17
+ pub static CACHED_RESOLVER : Lazy < ArcSwap < CachedResolver > > =
18
+ Lazy :: new ( || ArcSwap :: from_pointee ( CachedResolver :: default ( ) ) ) ;
19
19
20
20
// Ip addressed are returned as a set of addresses
21
21
// so we can compare.
@@ -70,23 +70,45 @@ impl From<LookupIp> for AddrSet {
70
70
/// // You can now check if an 'old' lookup differs from what it's currently
71
71
/// // store in cache by using `has_changed`.
72
72
/// resolver.has_changed("www.example.com.", addrset)
73
+ #[ derive( Default ) ]
73
74
pub struct CachedResolver {
74
75
// The configuration of the cached_resolver.
75
76
config : CachedResolverConfig ,
76
77
77
78
// This is the hash that contains the hash.
78
- data : Arc < RwLock < HashMap < String , AddrSet > > > ,
79
+ data : Option < RwLock < HashMap < String , AddrSet > > > ,
79
80
80
81
// The resolver to be used for DNS queries.
81
- resolver : Arc < TokioAsyncResolver > ,
82
+ resolver : Option < TokioAsyncResolver > ,
83
+
84
+ // The RefreshLoop
85
+ refresh_loop : RwLock < Option < tokio:: task:: JoinHandle < ( ) > > > ,
82
86
}
83
87
84
88
///
85
89
/// Configuration
86
- #[ derive( Clone , Debug ) ]
90
+ #[ derive( Clone , Debug , Default , PartialEq ) ]
87
91
pub struct CachedResolverConfig {
88
92
/// Amount of time in secods that a resolved dns address is considered stale.
89
- pub dns_max_ttl : u64 ,
93
+ dns_max_ttl : u64 ,
94
+
95
+ /// Enabled or disabled? (this is so we can reload config)
96
+ enabled : bool ,
97
+ }
98
+
99
+ impl CachedResolverConfig {
100
+ fn new ( dns_max_ttl : u64 , enabled : bool ) -> Self {
101
+ CachedResolverConfig {
102
+ dns_max_ttl,
103
+ enabled,
104
+ }
105
+ }
106
+ }
107
+
108
+ impl From < crate :: config:: Config > for CachedResolverConfig {
109
+ fn from ( config : crate :: config:: Config ) -> Self {
110
+ CachedResolverConfig :: new ( config. general . dns_max_ttl , config. general . dns_cache_enabled )
111
+ }
90
112
}
91
113
92
114
impl CachedResolver {
@@ -109,24 +131,39 @@ impl CachedResolver {
109
131
/// # })
110
132
/// ```
111
133
///
112
- pub async fn new ( config : CachedResolverConfig ) -> io :: Result < Arc < Self > > {
134
+ pub async fn new ( config : CachedResolverConfig , data : Option < HashMap < String , AddrSet > > ) -> Result < Arc < Self > , io :: Error > {
113
135
// Construct a new Resolver with default configuration options
114
- let resolver = Arc :: new ( TokioAsyncResolver :: tokio_from_system_conf ( ) ?) ;
115
- let data = Arc :: new ( RwLock :: new ( HashMap :: new ( ) ) ) ;
136
+ let resolver = Some ( TokioAsyncResolver :: tokio_from_system_conf ( ) ?) ;
137
+
138
+ let data = if let Some ( hash) = data {
139
+ Some ( RwLock :: new ( hash) )
140
+ } else {
141
+ Some ( RwLock :: new ( HashMap :: new ( ) ) )
142
+ } ;
116
143
117
- let self_ref = Arc :: new ( Self {
144
+ let instance = Arc :: new ( Self {
118
145
config,
119
146
resolver,
120
147
data,
148
+ refresh_loop : RwLock :: new ( None ) ,
121
149
} ) ;
122
- let clone_self_ref = self_ref. clone ( ) ;
123
150
124
- info ! ( "Scheduling DNS refresh loop" ) ;
125
- tokio:: task:: spawn ( async move {
126
- clone_self_ref. refresh_dns_entries_loop ( ) . await ;
127
- } ) ;
151
+ if instance. enabled ( ) {
152
+ info ! ( "Scheduling DNS refresh loop" ) ;
153
+ let refresh_loop = tokio:: task:: spawn ( {
154
+ let instance = instance. clone ( ) ;
155
+ async move {
156
+ instance. refresh_dns_entries_loop ( ) . await ;
157
+ }
158
+ } ) ;
159
+ * ( instance. refresh_loop . write ( ) . unwrap ( ) ) = Some ( refresh_loop) ;
160
+ }
161
+
162
+ Ok ( instance)
163
+ }
128
164
129
- Ok ( self_ref)
165
+ pub fn enabled ( & self ) -> bool {
166
+ self . config . enabled
130
167
}
131
168
132
169
// Schedules the refresher
@@ -139,8 +176,10 @@ impl CachedResolver {
139
176
// an array with keys.
140
177
let mut hostnames: Vec < String > = Vec :: new ( ) ;
141
178
{
142
- for hostname in self . data . read ( ) . unwrap ( ) . keys ( ) {
143
- hostnames. push ( hostname. clone ( ) ) ;
179
+ if let Some ( ref data) = self . data {
180
+ for hostname in data. read ( ) . unwrap ( ) . keys ( ) {
181
+ hostnames. push ( hostname. clone ( ) ) ;
182
+ }
144
183
}
145
184
}
146
185
@@ -208,10 +247,14 @@ impl CachedResolver {
208
247
}
209
248
None => {
210
249
debug ! ( "Not found, executing a dns query!" ) ;
211
- let addr_set = AddrSet :: from ( self . resolver . lookup_ip ( host) . await ?) ;
212
- debug ! ( "Obtained: {:?}" , addr_set) ;
213
- self . store_in_cache ( host, addr_set. clone ( ) ) ;
214
- Ok ( addr_set)
250
+ if let Some ( ref resolver) = self . resolver {
251
+ let addr_set = AddrSet :: from ( resolver. lookup_ip ( host) . await ?) ;
252
+ debug ! ( "Obtained: {:?}" , addr_set) ;
253
+ self . store_in_cache ( host, addr_set. clone ( ) ) ;
254
+ Ok ( addr_set)
255
+ } else {
256
+ Err ( ResolveError :: from ( "No resolver available" ) )
257
+ }
215
258
}
216
259
}
217
260
}
@@ -227,71 +270,92 @@ impl CachedResolver {
227
270
228
271
// Fetches an AddrSet from the inner cache adquiring the read lock.
229
272
fn fetch_from_cache ( & self , key : & str ) -> Option < AddrSet > {
230
- let hash = & self . data . read ( ) . unwrap ( ) ;
231
- if let Some ( addr_set) = hash. get ( key) {
232
- return Some ( addr_set. clone ( ) ) ;
273
+ if let Some ( ref hash) = self . data {
274
+ if let Some ( addr_set) = hash. read ( ) . unwrap ( ) . get ( key) {
275
+ return Some ( addr_set. clone ( ) ) ;
276
+ }
233
277
}
234
278
None
235
279
}
236
280
237
281
// Sets up the global CACHED_RESOLVER static variable so we can globally use DNS
238
282
// cache.
239
283
pub async fn from_config ( ) -> Result < ( ) , Error > {
240
- let config = get_config ( ) ;
241
-
242
- // Configure dns_cache if enabled
243
- if config. general . dns_cache_enabled {
244
- info ! ( "Starting Dns cache" ) ;
245
- let cached_resolver_config = CachedResolverConfig {
246
- dns_max_ttl : config. general . dns_max_ttl ,
247
- } ;
248
- return match CachedResolver :: new ( cached_resolver_config) . await {
284
+ let cached_resolver = CACHED_RESOLVER . load ( ) ;
285
+ let desired_config = CachedResolverConfig :: from ( get_config ( ) ) ;
286
+
287
+ if cached_resolver. config != desired_config {
288
+ if let Some ( ref refresh_loop) = * ( cached_resolver. refresh_loop . write ( ) . unwrap ( ) ) {
289
+ warn ! ( "Killing Dnscache refresh loop as its configuration is being reloaded" ) ;
290
+ refresh_loop. abort ( )
291
+ }
292
+ let new_resolver = if let Some ( ref data) = cached_resolver. data {
293
+ let data = Some ( data. read ( ) . unwrap ( ) . clone ( ) ) ;
294
+ CachedResolver :: new ( desired_config, data) . await
295
+ } else {
296
+ CachedResolver :: new ( desired_config, None ) . await
297
+ } ;
298
+
299
+ match new_resolver {
249
300
Ok ( ok) => {
250
- let value = Some ( ArcSwap :: from ( ok) ) ;
251
- CACHED_RESOLVER . store ( Arc :: new ( value) ) ;
252
- Ok ( ( ) )
301
+ CACHED_RESOLVER . store ( ok) ;
302
+ Ok ( ( ) )
253
303
}
254
304
Err ( err) => {
255
- let message = format ! ( "Error Starting cached_resolver error : {:?}, will continue without this feature." , err) ;
256
- Err ( Error :: DNSCachedError ( message) )
305
+ let message = format ! ( "Error setting up cached_resolver. Error : {:?}, will continue without this feature." , err) ;
306
+ Err ( Error :: DNSCachedError ( message) )
257
307
}
258
- } ;
259
- }
260
- Ok ( ( ) )
308
+ }
309
+ } else {
310
+ Ok ( ( ) )
311
+ }
261
312
}
262
313
263
314
// Stores the AddrSet in cache adquiring the write lock.
264
315
fn store_in_cache ( & self , host : & str , addr_set : AddrSet ) {
265
- self . data
266
- . write ( )
267
- . unwrap ( )
268
- . insert ( host. to_string ( ) , addr_set) ;
316
+ if let Some ( ref data) = self . data {
317
+ data
318
+ . write ( )
319
+ . unwrap ( )
320
+ . insert ( host. to_string ( ) , addr_set) ;
321
+ } else {
322
+ error ! ( "Could not insert, Hash not initialized" ) ;
323
+ }
269
324
}
270
- }
271
325
326
+ }
272
327
#[ cfg( test) ]
273
328
mod tests {
274
329
use super :: * ;
275
330
use trust_dns_resolver:: error:: ResolveError ;
276
331
277
332
#[ tokio:: test]
278
333
async fn new ( ) {
279
- let config = CachedResolverConfig { dns_max_ttl : 10 } ;
334
+ let config = CachedResolverConfig {
335
+ dns_max_ttl : 10 ,
336
+ enabled : true ,
337
+ } ;
280
338
let resolver = CachedResolver :: new ( config) . await ;
281
339
assert ! ( resolver. is_ok( ) ) ;
282
340
}
283
341
284
342
#[ tokio:: test]
285
343
async fn lookup_ip ( ) {
286
- let config = CachedResolverConfig { dns_max_ttl : 10 } ;
344
+ let config = CachedResolverConfig {
345
+ dns_max_ttl : 10 ,
346
+ enabled : true ,
347
+ } ;
287
348
let resolver = CachedResolver :: new ( config) . await . unwrap ( ) ;
288
349
let response = resolver. lookup_ip ( "www.google.com." ) . await ;
289
350
assert ! ( response. is_ok( ) ) ;
290
351
}
291
352
292
353
#[ tokio:: test]
293
354
async fn has_changed ( ) {
294
- let config = CachedResolverConfig { dns_max_ttl : 10 } ;
355
+ let config = CachedResolverConfig {
356
+ dns_max_ttl : 10 ,
357
+ enabled : true ,
358
+ } ;
295
359
let resolver = CachedResolver :: new ( config) . await . unwrap ( ) ;
296
360
let hostname = "www.google.com." ;
297
361
let response = resolver. lookup_ip ( hostname) . await ;
@@ -301,7 +365,10 @@ mod tests {
301
365
302
366
#[ tokio:: test]
303
367
async fn unknown_host ( ) {
304
- let config = CachedResolverConfig { dns_max_ttl : 10 } ;
368
+ let config = CachedResolverConfig {
369
+ dns_max_ttl : 10 ,
370
+ enabled : true ,
371
+ } ;
305
372
let resolver = CachedResolver :: new ( config) . await . unwrap ( ) ;
306
373
let hostname = "www.idontexists." ;
307
374
let response = resolver. lookup_ip ( hostname) . await ;
@@ -310,7 +377,10 @@ mod tests {
310
377
311
378
#[ tokio:: test]
312
379
async fn incorrect_address ( ) {
313
- let config = CachedResolverConfig { dns_max_ttl : 10 } ;
380
+ let config = CachedResolverConfig {
381
+ dns_max_ttl : 10 ,
382
+ enabled : true ,
383
+ } ;
314
384
let resolver = CachedResolver :: new ( config) . await . unwrap ( ) ;
315
385
let hostname = "w ww.idontexists." ;
316
386
let response = resolver. lookup_ip ( hostname) . await ;
@@ -324,7 +394,10 @@ mod tests {
324
394
// if I cache here, it will miss after one cache iteration or two.
325
395
async fn thread ( ) {
326
396
env_logger:: init ( ) ;
327
- let config = CachedResolverConfig { dns_max_ttl : 10 } ;
397
+ let config = CachedResolverConfig {
398
+ dns_max_ttl : 10 ,
399
+ enabled : true ,
400
+ } ;
328
401
let resolver = CachedResolver :: new ( config) . await . unwrap ( ) ;
329
402
let hostname = "www.google.com." ;
330
403
let response = resolver. lookup_ip ( hostname) . await ;
0 commit comments