Skip to content

Commit ab932a9

Browse files
committed
Allow reloading dns cached
1 parent 621ac95 commit ab932a9

File tree

3 files changed

+157
-77
lines changed

3 files changed

+157
-77
lines changed

src/config.rs

+5
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ use std::sync::Arc;
1212
use tokio::fs::File;
1313
use tokio::io::AsyncReadExt;
1414

15+
use crate::dns_cache::CachedResolver;
1516
use crate::errors::Error;
1617
use crate::pool::{ClientServerMap, ConnectionPool};
1718
use crate::sharding::ShardingFunction;
@@ -878,6 +879,10 @@ pub async fn reload_config(client_server_map: ClientServerMap) -> Result<bool, E
878879
}
879880
};
880881
let new_config = get_config();
882+
match CachedResolver::from_config().await {
883+
Ok(_) => (),
884+
Err(err) => error!("DNS cache reinitialization error: {:?}", err),
885+
};
881886

882887
if old_config.pools != new_config.pools {
883888
info!("Pool configuration changed");

src/dns_cache.rs

+128-55
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,21 @@
11
use crate::config::get_config;
22
use crate::errors::Error;
33
use arc_swap::ArcSwap;
4-
use log::{debug, error, info};
4+
use log::{debug, error, info, warn};
55
use once_cell::sync::Lazy;
66
use std::collections::{HashMap, HashSet};
77
use std::io;
88
use std::net::IpAddr;
99
use std::sync::Arc;
1010
use std::sync::RwLock;
1111
use tokio::time::{sleep, Duration};
12-
use trust_dns_resolver::error::ResolveResult;
12+
use trust_dns_resolver::error::{ResolveError, ResolveResult};
1313
use trust_dns_resolver::lookup_ip::LookupIp;
1414
use trust_dns_resolver::TokioAsyncResolver;
1515

1616
/// Cached Resolver Globally available
17-
pub static CACHED_RESOLVER: Lazy<ArcSwap<Option<ArcSwap<CachedResolver>>>> =
18-
Lazy::new(|| ArcSwap::from_pointee(None));
17+
pub static CACHED_RESOLVER: Lazy<ArcSwap<CachedResolver>> =
18+
Lazy::new(|| ArcSwap::from_pointee(CachedResolver::default()));
1919

2020
// Ip addressed are returned as a set of addresses
2121
// so we can compare.
@@ -70,23 +70,45 @@ impl From<LookupIp> for AddrSet {
7070
/// // You can now check if an 'old' lookup differs from what it's currently
7171
/// // store in cache by using `has_changed`.
7272
/// resolver.has_changed("www.example.com.", addrset)
73+
#[derive(Default)]
7374
pub struct CachedResolver {
7475
// The configuration of the cached_resolver.
7576
config: CachedResolverConfig,
7677

7778
// This is the hash that contains the hash.
78-
data: Arc<RwLock<HashMap<String, AddrSet>>>,
79+
data: Option<RwLock<HashMap<String, AddrSet>>>,
7980

8081
// The resolver to be used for DNS queries.
81-
resolver: Arc<TokioAsyncResolver>,
82+
resolver: Option<TokioAsyncResolver>,
83+
84+
// The RefreshLoop
85+
refresh_loop: RwLock<Option<tokio::task::JoinHandle<()>>>,
8286
}
8387

8488
///
8589
/// Configuration
86-
#[derive(Clone, Debug)]
90+
#[derive(Clone, Debug, Default, PartialEq)]
8791
pub struct CachedResolverConfig {
8892
/// Amount of time in secods that a resolved dns address is considered stale.
89-
pub dns_max_ttl: u64,
93+
dns_max_ttl: u64,
94+
95+
/// Enabled or disabled? (this is so we can reload config)
96+
enabled: bool,
97+
}
98+
99+
impl CachedResolverConfig {
100+
fn new(dns_max_ttl: u64, enabled: bool) -> Self {
101+
CachedResolverConfig {
102+
dns_max_ttl,
103+
enabled,
104+
}
105+
}
106+
}
107+
108+
impl From<crate::config::Config> for CachedResolverConfig {
109+
fn from(config: crate::config::Config) -> Self {
110+
CachedResolverConfig::new(config.general.dns_max_ttl, config.general.dns_cache_enabled)
111+
}
90112
}
91113

92114
impl CachedResolver {
@@ -109,24 +131,39 @@ impl CachedResolver {
109131
/// # })
110132
/// ```
111133
///
112-
pub async fn new(config: CachedResolverConfig) -> io::Result<Arc<Self>> {
134+
pub async fn new(config: CachedResolverConfig, data: Option<HashMap<String, AddrSet>>) -> Result<Arc<Self>, io::Error> {
113135
// Construct a new Resolver with default configuration options
114-
let resolver = Arc::new(TokioAsyncResolver::tokio_from_system_conf()?);
115-
let data = Arc::new(RwLock::new(HashMap::new()));
136+
let resolver = Some(TokioAsyncResolver::tokio_from_system_conf()?);
137+
138+
let data = if let Some(hash) = data {
139+
Some(RwLock::new(hash))
140+
} else {
141+
Some(RwLock::new(HashMap::new()))
142+
};
116143

117-
let self_ref = Arc::new(Self {
144+
let instance = Arc::new(Self {
118145
config,
119146
resolver,
120147
data,
148+
refresh_loop: RwLock::new(None),
121149
});
122-
let clone_self_ref = self_ref.clone();
123150

124-
info!("Scheduling DNS refresh loop");
125-
tokio::task::spawn(async move {
126-
clone_self_ref.refresh_dns_entries_loop().await;
127-
});
151+
if instance.enabled() {
152+
info!("Scheduling DNS refresh loop");
153+
let refresh_loop = tokio::task::spawn({
154+
let instance = instance.clone();
155+
async move {
156+
instance.refresh_dns_entries_loop().await;
157+
}
158+
});
159+
*(instance.refresh_loop.write().unwrap()) = Some(refresh_loop);
160+
}
161+
162+
Ok(instance)
163+
}
128164

129-
Ok(self_ref)
165+
pub fn enabled(&self) -> bool {
166+
self.config.enabled
130167
}
131168

132169
// Schedules the refresher
@@ -139,8 +176,10 @@ impl CachedResolver {
139176
// an array with keys.
140177
let mut hostnames: Vec<String> = Vec::new();
141178
{
142-
for hostname in self.data.read().unwrap().keys() {
143-
hostnames.push(hostname.clone());
179+
if let Some(ref data) = self.data {
180+
for hostname in data.read().unwrap().keys() {
181+
hostnames.push(hostname.clone());
182+
}
144183
}
145184
}
146185

@@ -208,10 +247,14 @@ impl CachedResolver {
208247
}
209248
None => {
210249
debug!("Not found, executing a dns query!");
211-
let addr_set = AddrSet::from(self.resolver.lookup_ip(host).await?);
212-
debug!("Obtained: {:?}", addr_set);
213-
self.store_in_cache(host, addr_set.clone());
214-
Ok(addr_set)
250+
if let Some(ref resolver) = self.resolver {
251+
let addr_set = AddrSet::from(resolver.lookup_ip(host).await?);
252+
debug!("Obtained: {:?}", addr_set);
253+
self.store_in_cache(host, addr_set.clone());
254+
Ok(addr_set)
255+
} else {
256+
Err(ResolveError::from("No resolver available"))
257+
}
215258
}
216259
}
217260
}
@@ -227,71 +270,92 @@ impl CachedResolver {
227270

228271
// Fetches an AddrSet from the inner cache adquiring the read lock.
229272
fn fetch_from_cache(&self, key: &str) -> Option<AddrSet> {
230-
let hash = &self.data.read().unwrap();
231-
if let Some(addr_set) = hash.get(key) {
232-
return Some(addr_set.clone());
273+
if let Some(ref hash) = self.data {
274+
if let Some(addr_set) = hash.read().unwrap().get(key) {
275+
return Some(addr_set.clone());
276+
}
233277
}
234278
None
235279
}
236280

237281
// Sets up the global CACHED_RESOLVER static variable so we can globally use DNS
238282
// cache.
239283
pub async fn from_config() -> Result<(), Error> {
240-
let config = get_config();
241-
242-
// Configure dns_cache if enabled
243-
if config.general.dns_cache_enabled {
244-
info!("Starting Dns cache");
245-
let cached_resolver_config = CachedResolverConfig {
246-
dns_max_ttl: config.general.dns_max_ttl,
247-
};
248-
return match CachedResolver::new(cached_resolver_config).await {
284+
let cached_resolver = CACHED_RESOLVER.load();
285+
let desired_config = CachedResolverConfig::from(get_config());
286+
287+
if cached_resolver.config != desired_config {
288+
if let Some(ref refresh_loop) = *(cached_resolver.refresh_loop.write().unwrap()) {
289+
warn!("Killing Dnscache refresh loop as its configuration is being reloaded");
290+
refresh_loop.abort()
291+
}
292+
let new_resolver = if let Some(ref data) = cached_resolver.data {
293+
let data = Some(data.read().unwrap().clone());
294+
CachedResolver::new(desired_config, data).await
295+
} else {
296+
CachedResolver::new(desired_config, None).await
297+
};
298+
299+
match new_resolver {
249300
Ok(ok) => {
250-
let value = Some(ArcSwap::from(ok));
251-
CACHED_RESOLVER.store(Arc::new(value));
252-
Ok(())
301+
CACHED_RESOLVER.store(ok);
302+
Ok(())
253303
}
254304
Err(err) => {
255-
let message = format!("Error Starting cached_resolver error: {:?}, will continue without this feature.", err);
256-
Err(Error::DNSCachedError(message))
305+
let message = format!("Error setting up cached_resolver. Error: {:?}, will continue without this feature.", err);
306+
Err(Error::DNSCachedError(message))
257307
}
258-
};
259-
}
260-
Ok(())
308+
}
309+
} else {
310+
Ok(())
311+
}
261312
}
262313

263314
// Stores the AddrSet in cache adquiring the write lock.
264315
fn store_in_cache(&self, host: &str, addr_set: AddrSet) {
265-
self.data
266-
.write()
267-
.unwrap()
268-
.insert(host.to_string(), addr_set);
316+
if let Some(ref data) = self.data {
317+
data
318+
.write()
319+
.unwrap()
320+
.insert(host.to_string(), addr_set);
321+
} else {
322+
error!("Could not insert, Hash not initialized");
323+
}
269324
}
270-
}
271325

326+
}
272327
#[cfg(test)]
273328
mod tests {
274329
use super::*;
275330
use trust_dns_resolver::error::ResolveError;
276331

277332
#[tokio::test]
278333
async fn new() {
279-
let config = CachedResolverConfig { dns_max_ttl: 10 };
334+
let config = CachedResolverConfig {
335+
dns_max_ttl: 10,
336+
enabled: true,
337+
};
280338
let resolver = CachedResolver::new(config).await;
281339
assert!(resolver.is_ok());
282340
}
283341

284342
#[tokio::test]
285343
async fn lookup_ip() {
286-
let config = CachedResolverConfig { dns_max_ttl: 10 };
344+
let config = CachedResolverConfig {
345+
dns_max_ttl: 10,
346+
enabled: true,
347+
};
287348
let resolver = CachedResolver::new(config).await.unwrap();
288349
let response = resolver.lookup_ip("www.google.com.").await;
289350
assert!(response.is_ok());
290351
}
291352

292353
#[tokio::test]
293354
async fn has_changed() {
294-
let config = CachedResolverConfig { dns_max_ttl: 10 };
355+
let config = CachedResolverConfig {
356+
dns_max_ttl: 10,
357+
enabled: true,
358+
};
295359
let resolver = CachedResolver::new(config).await.unwrap();
296360
let hostname = "www.google.com.";
297361
let response = resolver.lookup_ip(hostname).await;
@@ -301,7 +365,10 @@ mod tests {
301365

302366
#[tokio::test]
303367
async fn unknown_host() {
304-
let config = CachedResolverConfig { dns_max_ttl: 10 };
368+
let config = CachedResolverConfig {
369+
dns_max_ttl: 10,
370+
enabled: true,
371+
};
305372
let resolver = CachedResolver::new(config).await.unwrap();
306373
let hostname = "www.idontexists.";
307374
let response = resolver.lookup_ip(hostname).await;
@@ -310,7 +377,10 @@ mod tests {
310377

311378
#[tokio::test]
312379
async fn incorrect_address() {
313-
let config = CachedResolverConfig { dns_max_ttl: 10 };
380+
let config = CachedResolverConfig {
381+
dns_max_ttl: 10,
382+
enabled: true,
383+
};
314384
let resolver = CachedResolver::new(config).await.unwrap();
315385
let hostname = "w ww.idontexists.";
316386
let response = resolver.lookup_ip(hostname).await;
@@ -324,7 +394,10 @@ mod tests {
324394
// if I cache here, it will miss after one cache iteration or two.
325395
async fn thread() {
326396
env_logger::init();
327-
let config = CachedResolverConfig { dns_max_ttl: 10 };
397+
let config = CachedResolverConfig {
398+
dns_max_ttl: 10,
399+
enabled: true,
400+
};
328401
let resolver = CachedResolver::new(config).await.unwrap();
329402
let hostname = "www.google.com.";
330403
let response = resolver.lookup_ip(hostname).await;

src/server.rs

+24-22
Original file line numberDiff line numberDiff line change
@@ -90,25 +90,21 @@ impl Server {
9090
stats: Reporter,
9191
) -> Result<Server, Error> {
9292
let cached_resolver = CACHED_RESOLVER.load();
93-
let addr_set = match cached_resolver.as_ref() {
94-
Some(cached_resolver) => {
95-
if address.host.parse::<IpAddr>().is_err() {
96-
debug!("Resolving {}", &address.host);
97-
match cached_resolver.load().lookup_ip(&address.host).await {
98-
Ok(ok) => {
99-
debug!("Obtained: {:?}", ok);
100-
Some(ok)
101-
}
102-
Err(err) => {
103-
warn!("Error trying to resolve {}, ({:?})", &address.host, err);
104-
None
105-
}
106-
}
107-
} else {
93+
let mut addr_set: Option<AddrSet> = None;
94+
95+
// If we are caching addresses and hostname is not an IP
96+
if cached_resolver.enabled() && address.host.parse::<IpAddr>().is_err() {
97+
debug!("Resolving {}", &address.host);
98+
addr_set = match cached_resolver.lookup_ip(&address.host).await {
99+
Ok(ok) => {
100+
debug!("Obtained: {:?}", ok);
101+
Some(ok)
102+
}
103+
Err(err) => {
104+
warn!("Error trying to resolve {}, ({:?})", &address.host, err);
108105
None
109106
}
110107
}
111-
None => None,
112108
};
113109

114110
let mut stream =
@@ -604,13 +600,19 @@ impl Server {
604600
if self.bad {
605601
return self.bad;
606602
};
607-
608-
if let Some(cached_resolver) = CACHED_RESOLVER.load().as_ref() {
603+
let cached_resolver = CACHED_RESOLVER.load();
604+
if cached_resolver.enabled() {
609605
if let Some(addr_set) = &self.addr_set {
610-
if cached_resolver.load().has_changed(self.address.host.as_str(), addr_set) {
611-
warn!("DNS changed for {}, it was {:?}. Dropping server connection.", self.address.host.as_str(), addr_set);
612-
return true
613-
}
606+
if cached_resolver
607+
.has_changed(self.address.host.as_str(), addr_set)
608+
{
609+
warn!(
610+
"DNS changed for {}, it was {:?}. Dropping server connection.",
611+
self.address.host.as_str(),
612+
addr_set
613+
);
614+
return true;
615+
}
614616
}
615617
}
616618
false

0 commit comments

Comments
 (0)