Skip to content

Commit e2f1aa2

Browse files
committed
Add dns_cache so server addresses are cached and invalidated when DNS changes.
Adds a module to deal with the dns_cache feature. Its main struct is CachedResolver, a simple thread-safe hostname <-> IPs cache with the ability to refresh resolutions every `dns_max_ttl` seconds. This way, a client can check whether its IP address has changed.
1 parent 62b2d99 commit e2f1aa2

File tree

8 files changed

+711
-3
lines changed

8 files changed

+711
-3
lines changed

Cargo.lock

+298
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

+2
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ nix = "0.26.2"
3939
atomic_enum = "0.2.0"
4040
postgres-protocol = "0.6.5"
4141
fallible-iterator = "0.2"
42+
trust-dns-resolver = "0.22.0"
43+
tokio-test = "0.4.2"
4244

4345
[target.'cfg(not(target_env = "msvc"))'.dependencies]
4446
jemallocator = "0.5.0"

src/config.rs

+12
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,12 @@ pub struct General {
231231
#[serde(default)] // False
232232
pub log_client_disconnections: bool,
233233

234+
#[serde(default)] // False
235+
pub dns_cache_enabled: bool,
236+
237+
#[serde(default = "General::default_dns_max_ttl")]
238+
pub dns_max_ttl: u64,
239+
234240
#[serde(default = "General::default_shutdown_timeout")]
235241
pub shutdown_timeout: u64,
236242

@@ -298,6 +304,10 @@ impl General {
298304
60000
299305
}
300306

307+
/// Default for `dns_max_ttl` when the configuration does not set it.
pub fn default_dns_max_ttl() -> u64 {
    30
}
310+
301311
pub fn default_healthcheck_timeout() -> u64 {
302312
1000
303313
}
@@ -340,6 +350,8 @@ impl Default for General {
340350
log_client_connections: false,
341351
log_client_disconnections: false,
342352
autoreload: None,
353+
dns_cache_enabled: false,
354+
dns_max_ttl: Self::default_dns_max_ttl(),
343355
tls_certificate: None,
344356
tls_private_key: None,
345357
admin_username: String::from("admin"),

src/dns_cache.rs

+339
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,339 @@
1+
use crate::config::get_config;
2+
use crate::errors::Error;
3+
use arc_swap::ArcSwap;
4+
use log::{debug, error, info};
5+
use once_cell::sync::Lazy;
6+
use std::collections::{HashMap, HashSet};
7+
use std::io;
8+
use std::net::IpAddr;
9+
use std::sync::Arc;
10+
use std::sync::RwLock;
11+
use tokio::time::{sleep, Duration};
12+
use trust_dns_resolver::error::ResolveResult;
13+
use trust_dns_resolver::lookup_ip::LookupIp;
14+
use trust_dns_resolver::TokioAsyncResolver;
15+
16+
/// Cached Resolver Globally available
17+
pub static CACHED_RESOLVER: Lazy<ArcSwap<Option<ArcSwap<CachedResolver>>>> =
18+
Lazy::new(|| ArcSwap::from_pointee(None));
19+
20+
// Ip addressed are returned as a set of addresses
21+
// so we can compare.
22+
#[derive(Clone, PartialEq, Debug)]
23+
pub struct AddrSet {
24+
set: HashSet<IpAddr>,
25+
}
26+
27+
impl AddrSet {
28+
fn new() -> AddrSet {
29+
AddrSet {
30+
set: HashSet::new(),
31+
}
32+
}
33+
}
34+
35+
impl From<LookupIp> for AddrSet {
36+
fn from(lookup_ip: LookupIp) -> Self {
37+
let mut addr_set = AddrSet::new();
38+
for address in lookup_ip.iter() {
39+
addr_set.set.insert(address);
40+
}
41+
addr_set
42+
}
43+
}
44+
45+
///
46+
/// A CachedResolver is a DNS resolution cache mechanism with customizable expiration time.
47+
///
48+
/// The system works as follows:
49+
///
50+
/// When a host is to be resolved, if we have not resolved it before, a new resolution is
51+
/// executed and stored in the internal cache. Concurrently, every `dns_max_ttl` time, the
52+
/// cache is refreshed.
53+
///
54+
/// # Example:
55+
///
56+
/// ```
57+
/// use pgcat::dns_cache::{CachedResolverConfig, CachedResolver};
58+
///
59+
/// # tokio_test::block_on(async {
60+
/// let config = CachedResolverConfig{dns_max_ttl: 10};
61+
/// let resolver = CachedResolver::new(config).await.unwrap();
62+
/// let addrset = resolver.lookup_ip("www.example.com.").await.unwrap();
63+
/// # })
64+
/// ```
65+
///
66+
/// // Now the ip resolution is stored in local cache and subsequent
67+
/// // calls will be returned from cache. Also, the cache is refreshed
68+
/// // and updated every 10 seconds.
69+
///
70+
/// // You can now check if an 'old' lookup differs from what it's currently
71+
/// // store in cache by using `has_changed`.
72+
/// resolver.has_changed("www.example.com.", addrset)
73+
pub struct CachedResolver {
74+
// The configuration of the cached_resolver.
75+
config: CachedResolverConfig,
76+
77+
// This is the hash that contains the hash.
78+
data: Arc<RwLock<HashMap<String, AddrSet>>>,
79+
80+
// The resolver to be used for DNS queries.
81+
resolver: Arc<TokioAsyncResolver>,
82+
}
83+
84+
/// Configuration of a `CachedResolver`.
#[derive(Debug, Clone)]
pub struct CachedResolverConfig {
    /// Interval, in seconds, between two refreshes of the cached DNS entries.
    pub dns_max_ttl: u64,
}
91+
92+
impl CachedResolver {
93+
///
94+
/// Returns a new Arc<CachedResolver> based on passed configuration.
95+
/// It also starts the loop that will refresh cache entries.
96+
///
97+
/// # Arguments:
98+
///
99+
/// * `config` - The `CachedResolverConfig` to be used to create the resolver.
100+
///
101+
/// # Example:
102+
///
103+
/// ```
104+
/// use pgcat::dns_cache::{CachedResolverConfig, CachedResolver};
105+
///
106+
/// # tokio_test::block_on(async {
107+
/// let config = CachedResolverConfig{dns_max_ttl: 10};
108+
/// let resolver = CachedResolver::new(config);
109+
/// # })
110+
/// ```
111+
///
112+
pub async fn new(config: CachedResolverConfig) -> io::Result<Arc<Self>> {
113+
// Construct a new Resolver with default configuration options
114+
let resolver = Arc::new(TokioAsyncResolver::tokio_from_system_conf()?);
115+
let data = Arc::new(RwLock::new(HashMap::new()));
116+
117+
let self_ref = Arc::new(Self {
118+
config,
119+
resolver,
120+
data,
121+
});
122+
let clone_self_ref = self_ref.clone();
123+
124+
info!("Scheduling DNS refresh loop");
125+
tokio::task::spawn(async move {
126+
clone_self_ref.refresh_dns_entries_loop().await;
127+
});
128+
129+
Ok(self_ref)
130+
}
131+
132+
// Schedules the refresher
133+
async fn refresh_dns_entries_loop(&self) {
134+
let resolver = TokioAsyncResolver::tokio_from_system_conf().unwrap();
135+
let interval = Duration::from_secs(self.config.dns_max_ttl);
136+
loop {
137+
debug!("Begin refreshing cached DNS addresses.");
138+
// To minimize the time we hold the lock, we first create
139+
// an array with keys.
140+
let mut hostnames: Vec<String> = Vec::new();
141+
{
142+
for hostname in self.data.read().unwrap().keys() {
143+
hostnames.push(hostname.clone());
144+
}
145+
}
146+
147+
for hostname in hostnames.iter() {
148+
let addrset = self
149+
.fetch_from_cache(hostname.as_str())
150+
.expect("Could not obtain expected address from cache, this should not happen");
151+
152+
match resolver.lookup_ip(hostname).await {
153+
Ok(lookup_ip) => {
154+
let new_addrset = AddrSet::from(lookup_ip);
155+
debug!(
156+
"Obtained address for host ({}) -> ({:?})",
157+
hostname, new_addrset
158+
);
159+
160+
if addrset != new_addrset {
161+
debug!(
162+
"Addr changed from {:?} to {:?} updating cache.",
163+
addrset, new_addrset
164+
);
165+
self.store_in_cache(hostname, new_addrset);
166+
}
167+
}
168+
Err(err) => {
169+
error!(
170+
"There was an error trying to resolv {}: ({}).",
171+
hostname, err
172+
);
173+
}
174+
}
175+
}
176+
debug!("Finished refreshing cached DNS addresses.");
177+
sleep(interval).await;
178+
}
179+
}
180+
181+
/// Returns a `AddrSet` given the specified hostname.
182+
///
183+
/// This method first tries to fetch the value from the cache, if it misses
184+
/// then it is resolved and stored in the cache. TTL from records is ignored.
185+
///
186+
/// # Arguments
187+
///
188+
/// * `host` - A string slice referencing the hostname to be resolved.
189+
///
190+
/// # Example:
191+
///
192+
/// ```
193+
/// use pgcat::dns_cache::{CachedResolverConfig, CachedResolver};
194+
///
195+
/// # tokio_test::block_on(async {
196+
/// let config = CachedResolverConfig { dns_max_ttl: 10 };
197+
/// let resolver = CachedResolver::new(config).await.unwrap();
198+
/// let response = resolver.lookup_ip("www.google.com.");
199+
/// # })
200+
/// ```
201+
///
202+
pub async fn lookup_ip(&self, host: &str) -> ResolveResult<AddrSet> {
203+
debug!("Lookup up {} in cache", host);
204+
match self.fetch_from_cache(host) {
205+
Some(addr_set) => {
206+
debug!("Cache hit!");
207+
Ok(addr_set)
208+
}
209+
None => {
210+
debug!("Not found, executing a dns query!");
211+
let addr_set = AddrSet::from(self.resolver.lookup_ip(host).await?);
212+
debug!("Obtained: {:?}", addr_set);
213+
self.store_in_cache(host, addr_set.clone());
214+
Ok(addr_set)
215+
}
216+
}
217+
}
218+
219+
//
220+
// Returns true if the stored host resolution differs from the AddrSet passed.
221+
pub fn has_changed(&self, host: &str, addr_set: &AddrSet) -> bool {
222+
if let Some(fetched_addr_set) = self.fetch_from_cache(host) {
223+
return fetched_addr_set != *addr_set;
224+
}
225+
false
226+
}
227+
228+
// Fetches an AddrSet from the inner cache adquiring the read lock.
229+
fn fetch_from_cache(&self, key: &str) -> Option<AddrSet> {
230+
let hash = &self.data.read().unwrap();
231+
if let Some(addr_set) = hash.get(key) {
232+
return Some(addr_set.clone());
233+
}
234+
None
235+
}
236+
237+
// Sets up the global CACHED_RESOLVER static variable so we can globally use DNS
238+
// cache.
239+
pub async fn from_config() -> Result<(), Error> {
240+
let config = get_config();
241+
242+
// Configure dns_cache if enabled
243+
if config.general.dns_cache_enabled {
244+
info!("Starting Dns cache");
245+
let cached_resolver_config = CachedResolverConfig {
246+
dns_max_ttl: config.general.dns_max_ttl,
247+
};
248+
return match CachedResolver::new(cached_resolver_config).await {
249+
Ok(ok) => {
250+
let value = Some(ArcSwap::from(ok));
251+
CACHED_RESOLVER.store(Arc::new(value));
252+
Ok(())
253+
}
254+
Err(err) => {
255+
let message = format!("Error Starting cached_resolver error: {:?}, will continue without this feature.", err);
256+
Err(Error::DNSCachedError(message))
257+
}
258+
};
259+
}
260+
Ok(())
261+
}
262+
263+
// Stores the AddrSet in cache adquiring the write lock.
264+
fn store_in_cache(&self, host: &str, addr_set: AddrSet) {
265+
self.data
266+
.write()
267+
.unwrap()
268+
.insert(host.to_string(), addr_set);
269+
}
270+
}
271+
272+
#[cfg(test)]
mod tests {
    use super::*;
    use trust_dns_resolver::error::ResolveError;

    // Refresh interval, in seconds, used by every test resolver.
    const TEST_TTL: u64 = 10;

    // Creates a resolver with the test configuration, panicking on failure.
    async fn build_resolver() -> Arc<CachedResolver> {
        let config = CachedResolverConfig {
            dns_max_ttl: TEST_TTL,
        };
        CachedResolver::new(config).await.unwrap()
    }

    #[tokio::test]
    async fn new() {
        let config = CachedResolverConfig {
            dns_max_ttl: TEST_TTL,
        };
        assert!(CachedResolver::new(config).await.is_ok());
    }

    #[tokio::test]
    async fn lookup_ip() {
        let resolver = build_resolver().await;
        let response = resolver.lookup_ip("www.google.com.").await;
        assert!(response.is_ok());
    }

    #[tokio::test]
    async fn has_changed() {
        let resolver = build_resolver().await;
        let hostname = "www.google.com.";
        let addr_set = resolver.lookup_ip(hostname).await.unwrap();
        assert!(!resolver.has_changed(hostname, &addr_set));
    }

    #[tokio::test]
    async fn unknown_host() {
        let resolver = build_resolver().await;
        let response = resolver.lookup_ip("www.idontexists.").await;
        assert!(matches!(response, Err(ResolveError { .. })));
    }

    #[tokio::test]
    async fn incorrect_address() {
        let resolver = build_resolver().await;
        let hostname = "w  ww.idontexists.";
        let response = resolver.lookup_ip(hostname).await;
        assert!(matches!(response, Err(ResolveError { .. })));
        // A failed lookup stores nothing, so nothing can have changed.
        assert!(!resolver.has_changed(hostname, &AddrSet::new()));
    }

    #[tokio::test]
    // Ok, this test is based on the fact that google does DNS RR
    // and does not respond with every available ip every time, so
    // if I cache here, it will miss after one cache iteration or two.
    async fn thread() {
        env_logger::init();
        let resolver = build_resolver().await;
        let hostname = "www.google.com.";
        let addr_set = resolver.lookup_ip(hostname).await.unwrap();
        assert!(!resolver.has_changed(hostname, &addr_set));

        let resolver_for_refresher = Arc::clone(&resolver);
        let _thread_handle = tokio::task::spawn(async move {
            resolver_for_refresher.refresh_dns_entries_loop().await;
        });
        assert!(!resolver.has_changed(hostname, &addr_set));
    }
}

src/errors.rs

+1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ pub enum Error {
1919
ClientError(String),
2020
TlsError,
2121
StatementTimeout,
22+
DNSCachedError(String),
2223
ShuttingDown,
2324
ParseBytesError(String),
2425
AuthError(String),

src/lib.rs

+1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
pub mod auth_passthrough;
22
pub mod config;
33
pub mod constants;
4+
pub mod dns_cache;
45
pub mod errors;
56
pub mod messages;
67
pub mod mirrors;

0 commit comments

Comments
 (0)