Skip to content

Commit 9a4eac1

Browse files
committed
Allow reloading dns cached
1 parent e2f1aa2 commit 9a4eac1

File tree

4 files changed

+154
-95
lines changed

4 files changed

+154
-95
lines changed

Cargo.lock

-10
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/config.rs

+5
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ use std::sync::Arc;
1212
use tokio::fs::File;
1313
use tokio::io::AsyncReadExt;
1414

15+
use crate::dns_cache::CachedResolver;
1516
use crate::errors::Error;
1617
use crate::pool::{ClientServerMap, ConnectionPool};
1718
use crate::sharding::ShardingFunction;
@@ -1032,6 +1033,10 @@ pub async fn reload_config(client_server_map: ClientServerMap) -> Result<bool, E
10321033
}
10331034
};
10341035
let new_config = get_config();
1036+
match CachedResolver::from_config().await {
1037+
Ok(_) => (),
1038+
Err(err) => error!("DNS cache reinitialization error: {:?}", err),
1039+
};
10351040

10361041
if old_config.pools != new_config.pools {
10371042
info!("Pool configuration changed");

src/dns_cache.rs

+134-63
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,21 @@
11
use crate::config::get_config;
22
use crate::errors::Error;
33
use arc_swap::ArcSwap;
4-
use log::{debug, error, info};
4+
use log::{debug, error, info, warn};
55
use once_cell::sync::Lazy;
66
use std::collections::{HashMap, HashSet};
77
use std::io;
88
use std::net::IpAddr;
99
use std::sync::Arc;
1010
use std::sync::RwLock;
1111
use tokio::time::{sleep, Duration};
12-
use trust_dns_resolver::error::ResolveResult;
12+
use trust_dns_resolver::error::{ResolveError, ResolveResult};
1313
use trust_dns_resolver::lookup_ip::LookupIp;
1414
use trust_dns_resolver::TokioAsyncResolver;
1515

1616
/// Cached Resolver Globally available
17-
pub static CACHED_RESOLVER: Lazy<ArcSwap<Option<ArcSwap<CachedResolver>>>> =
18-
Lazy::new(|| ArcSwap::from_pointee(None));
17+
pub static CACHED_RESOLVER: Lazy<ArcSwap<CachedResolver>> =
18+
Lazy::new(|| ArcSwap::from_pointee(CachedResolver::default()));
1919

2020
// Ip addressed are returned as a set of addresses
2121
// so we can compare.
@@ -57,8 +57,8 @@ impl From<LookupIp> for AddrSet {
5757
/// use pgcat::dns_cache::{CachedResolverConfig, CachedResolver};
5858
///
5959
/// # tokio_test::block_on(async {
60-
/// let config = CachedResolverConfig{dns_max_ttl: 10};
61-
/// let resolver = CachedResolver::new(config).await.unwrap();
60+
/// let config = CachedResolverConfig::default();
61+
/// let resolver = CachedResolver::new(config, None).await.unwrap();
6262
/// let addrset = resolver.lookup_ip("www.example.com.").await.unwrap();
6363
/// # })
6464
/// ```
@@ -70,23 +70,45 @@ impl From<LookupIp> for AddrSet {
7070
/// // You can now check if an 'old' lookup differs from what it's currently
7171
/// // store in cache by using `has_changed`.
7272
/// resolver.has_changed("www.example.com.", addrset)
73+
#[derive(Default)]
7374
pub struct CachedResolver {
7475
// The configuration of the cached_resolver.
7576
config: CachedResolverConfig,
7677

7778
// This is the hash that contains the hash.
78-
data: Arc<RwLock<HashMap<String, AddrSet>>>,
79+
data: Option<RwLock<HashMap<String, AddrSet>>>,
7980

8081
// The resolver to be used for DNS queries.
81-
resolver: Arc<TokioAsyncResolver>,
82+
resolver: Option<TokioAsyncResolver>,
83+
84+
// The RefreshLoop
85+
refresh_loop: RwLock<Option<tokio::task::JoinHandle<()>>>,
8286
}
8387

8488
///
8589
/// Configuration
86-
#[derive(Clone, Debug)]
90+
#[derive(Clone, Debug, Default, PartialEq)]
8791
pub struct CachedResolverConfig {
8892
/// Amount of time in secods that a resolved dns address is considered stale.
89-
pub dns_max_ttl: u64,
93+
dns_max_ttl: u64,
94+
95+
/// Enabled or disabled? (this is so we can reload config)
96+
enabled: bool,
97+
}
98+
99+
impl CachedResolverConfig {
100+
fn new(dns_max_ttl: u64, enabled: bool) -> Self {
101+
CachedResolverConfig {
102+
dns_max_ttl,
103+
enabled,
104+
}
105+
}
106+
}
107+
108+
impl From<crate::config::Config> for CachedResolverConfig {
109+
fn from(config: crate::config::Config) -> Self {
110+
CachedResolverConfig::new(config.general.dns_max_ttl, config.general.dns_cache_enabled)
111+
}
90112
}
91113

92114
impl CachedResolver {
@@ -104,29 +126,47 @@ impl CachedResolver {
104126
/// use pgcat::dns_cache::{CachedResolverConfig, CachedResolver};
105127
///
106128
/// # tokio_test::block_on(async {
107-
/// let config = CachedResolverConfig{dns_max_ttl: 10};
108-
/// let resolver = CachedResolver::new(config);
129+
/// let config = CachedResolverConfig::default();
130+
/// let resolver = CachedResolver::new(config, None).await.unwrap();
109131
/// # })
110132
/// ```
111133
///
112-
pub async fn new(config: CachedResolverConfig) -> io::Result<Arc<Self>> {
134+
pub async fn new(
135+
config: CachedResolverConfig,
136+
data: Option<HashMap<String, AddrSet>>,
137+
) -> Result<Arc<Self>, io::Error> {
113138
// Construct a new Resolver with default configuration options
114-
let resolver = Arc::new(TokioAsyncResolver::tokio_from_system_conf()?);
115-
let data = Arc::new(RwLock::new(HashMap::new()));
139+
let resolver = Some(TokioAsyncResolver::tokio_from_system_conf()?);
140+
141+
let data = if let Some(hash) = data {
142+
Some(RwLock::new(hash))
143+
} else {
144+
Some(RwLock::new(HashMap::new()))
145+
};
116146

117-
let self_ref = Arc::new(Self {
147+
let instance = Arc::new(Self {
118148
config,
119149
resolver,
120150
data,
151+
refresh_loop: RwLock::new(None),
121152
});
122-
let clone_self_ref = self_ref.clone();
123153

124-
info!("Scheduling DNS refresh loop");
125-
tokio::task::spawn(async move {
126-
clone_self_ref.refresh_dns_entries_loop().await;
127-
});
154+
if instance.enabled() {
155+
info!("Scheduling DNS refresh loop");
156+
let refresh_loop = tokio::task::spawn({
157+
let instance = instance.clone();
158+
async move {
159+
instance.refresh_dns_entries_loop().await;
160+
}
161+
});
162+
*(instance.refresh_loop.write().unwrap()) = Some(refresh_loop);
163+
}
128164

129-
Ok(self_ref)
165+
Ok(instance)
166+
}
167+
168+
pub fn enabled(&self) -> bool {
169+
self.config.enabled
130170
}
131171

132172
// Schedules the refresher
@@ -139,8 +179,10 @@ impl CachedResolver {
139179
// an array with keys.
140180
let mut hostnames: Vec<String> = Vec::new();
141181
{
142-
for hostname in self.data.read().unwrap().keys() {
143-
hostnames.push(hostname.clone());
182+
if let Some(ref data) = self.data {
183+
for hostname in data.read().unwrap().keys() {
184+
hostnames.push(hostname.clone());
185+
}
144186
}
145187
}
146188

@@ -193,8 +235,8 @@ impl CachedResolver {
193235
/// use pgcat::dns_cache::{CachedResolverConfig, CachedResolver};
194236
///
195237
/// # tokio_test::block_on(async {
196-
/// let config = CachedResolverConfig { dns_max_ttl: 10 };
197-
/// let resolver = CachedResolver::new(config).await.unwrap();
238+
/// let config = CachedResolverConfig::default();
239+
/// let resolver = CachedResolver::new(config, None).await.unwrap();
198240
/// let response = resolver.lookup_ip("www.google.com.");
199241
/// # })
200242
/// ```
@@ -208,10 +250,14 @@ impl CachedResolver {
208250
}
209251
None => {
210252
debug!("Not found, executing a dns query!");
211-
let addr_set = AddrSet::from(self.resolver.lookup_ip(host).await?);
212-
debug!("Obtained: {:?}", addr_set);
213-
self.store_in_cache(host, addr_set.clone());
214-
Ok(addr_set)
253+
if let Some(ref resolver) = self.resolver {
254+
let addr_set = AddrSet::from(resolver.lookup_ip(host).await?);
255+
debug!("Obtained: {:?}", addr_set);
256+
self.store_in_cache(host, addr_set.clone());
257+
Ok(addr_set)
258+
} else {
259+
Err(ResolveError::from("No resolver available"))
260+
}
215261
}
216262
}
217263
}
@@ -227,72 +273,89 @@ impl CachedResolver {
227273

228274
// Fetches an AddrSet from the inner cache adquiring the read lock.
229275
fn fetch_from_cache(&self, key: &str) -> Option<AddrSet> {
230-
let hash = &self.data.read().unwrap();
231-
if let Some(addr_set) = hash.get(key) {
232-
return Some(addr_set.clone());
276+
if let Some(ref hash) = self.data {
277+
if let Some(addr_set) = hash.read().unwrap().get(key) {
278+
return Some(addr_set.clone());
279+
}
233280
}
234281
None
235282
}
236283

237284
// Sets up the global CACHED_RESOLVER static variable so we can globally use DNS
238285
// cache.
239286
pub async fn from_config() -> Result<(), Error> {
240-
let config = get_config();
287+
let cached_resolver = CACHED_RESOLVER.load();
288+
let desired_config = CachedResolverConfig::from(get_config());
241289

242-
// Configure dns_cache if enabled
243-
if config.general.dns_cache_enabled {
244-
info!("Starting Dns cache");
245-
let cached_resolver_config = CachedResolverConfig {
246-
dns_max_ttl: config.general.dns_max_ttl,
290+
if cached_resolver.config != desired_config {
291+
if let Some(ref refresh_loop) = *(cached_resolver.refresh_loop.write().unwrap()) {
292+
warn!("Killing Dnscache refresh loop as its configuration is being reloaded");
293+
refresh_loop.abort()
294+
}
295+
let new_resolver = if let Some(ref data) = cached_resolver.data {
296+
let data = Some(data.read().unwrap().clone());
297+
CachedResolver::new(desired_config, data).await
298+
} else {
299+
CachedResolver::new(desired_config, None).await
247300
};
248-
return match CachedResolver::new(cached_resolver_config).await {
301+
302+
match new_resolver {
249303
Ok(ok) => {
250-
let value = Some(ArcSwap::from(ok));
251-
CACHED_RESOLVER.store(Arc::new(value));
304+
CACHED_RESOLVER.store(ok);
252305
Ok(())
253306
}
254307
Err(err) => {
255-
let message = format!("Error Starting cached_resolver error: {:?}, will continue without this feature.", err);
308+
let message = format!("Error setting up cached_resolver. Error: {:?}, will continue without this feature.", err);
256309
Err(Error::DNSCachedError(message))
257310
}
258-
};
311+
}
312+
} else {
313+
Ok(())
259314
}
260-
Ok(())
261315
}
262316

263317
// Stores the AddrSet in cache adquiring the write lock.
264318
fn store_in_cache(&self, host: &str, addr_set: AddrSet) {
265-
self.data
266-
.write()
267-
.unwrap()
268-
.insert(host.to_string(), addr_set);
319+
if let Some(ref data) = self.data {
320+
data.write().unwrap().insert(host.to_string(), addr_set);
321+
} else {
322+
error!("Could not insert, Hash not initialized");
323+
}
269324
}
270325
}
271-
272326
#[cfg(test)]
273327
mod tests {
274328
use super::*;
275329
use trust_dns_resolver::error::ResolveError;
276330

277331
#[tokio::test]
278332
async fn new() {
279-
let config = CachedResolverConfig { dns_max_ttl: 10 };
280-
let resolver = CachedResolver::new(config).await;
333+
let config = CachedResolverConfig {
334+
dns_max_ttl: 10,
335+
enabled: true,
336+
};
337+
let resolver = CachedResolver::new(config, None).await;
281338
assert!(resolver.is_ok());
282339
}
283340

284341
#[tokio::test]
285342
async fn lookup_ip() {
286-
let config = CachedResolverConfig { dns_max_ttl: 10 };
287-
let resolver = CachedResolver::new(config).await.unwrap();
343+
let config = CachedResolverConfig {
344+
dns_max_ttl: 10,
345+
enabled: true,
346+
};
347+
let resolver = CachedResolver::new(config, None).await.unwrap();
288348
let response = resolver.lookup_ip("www.google.com.").await;
289349
assert!(response.is_ok());
290350
}
291351

292352
#[tokio::test]
293353
async fn has_changed() {
294-
let config = CachedResolverConfig { dns_max_ttl: 10 };
295-
let resolver = CachedResolver::new(config).await.unwrap();
354+
let config = CachedResolverConfig {
355+
dns_max_ttl: 10,
356+
enabled: true,
357+
};
358+
let resolver = CachedResolver::new(config, None).await.unwrap();
296359
let hostname = "www.google.com.";
297360
let response = resolver.lookup_ip(hostname).await;
298361
let addr_set = response.unwrap();
@@ -301,17 +364,23 @@ mod tests {
301364

302365
#[tokio::test]
303366
async fn unknown_host() {
304-
let config = CachedResolverConfig { dns_max_ttl: 10 };
305-
let resolver = CachedResolver::new(config).await.unwrap();
367+
let config = CachedResolverConfig {
368+
dns_max_ttl: 10,
369+
enabled: true,
370+
};
371+
let resolver = CachedResolver::new(config, None).await.unwrap();
306372
let hostname = "www.idontexists.";
307373
let response = resolver.lookup_ip(hostname).await;
308374
assert!(matches!(response, Err(ResolveError { .. })));
309375
}
310376

311377
#[tokio::test]
312378
async fn incorrect_address() {
313-
let config = CachedResolverConfig { dns_max_ttl: 10 };
314-
let resolver = CachedResolver::new(config).await.unwrap();
379+
let config = CachedResolverConfig {
380+
dns_max_ttl: 10,
381+
enabled: true,
382+
};
383+
let resolver = CachedResolver::new(config, None).await.unwrap();
315384
let hostname = "w ww.idontexists.";
316385
let response = resolver.lookup_ip(hostname).await;
317386
assert!(matches!(response, Err(ResolveError { .. })));
@@ -323,9 +392,11 @@ mod tests {
323392
// and does not responds with every available ip everytime, so
324393
// if I cache here, it will miss after one cache iteration or two.
325394
async fn thread() {
326-
env_logger::init();
327-
let config = CachedResolverConfig { dns_max_ttl: 10 };
328-
let resolver = CachedResolver::new(config).await.unwrap();
395+
let config = CachedResolverConfig {
396+
dns_max_ttl: 10,
397+
enabled: true,
398+
};
399+
let resolver = CachedResolver::new(config, None).await.unwrap();
329400
let hostname = "www.google.com.";
330401
let response = resolver.lookup_ip(hostname).await;
331402
let addr_set = response.unwrap();

0 commit comments

Comments
 (0)