Skip to content

refactor: Keep parsed sqlx-data.json in a cache instead of reparsing #1684

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 72 additions & 78 deletions sqlx-macros/src/query/data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,43 @@ pub mod offline {
use super::QueryData;
use crate::database::DatabaseExt;

use std::fmt::{self, Formatter};
use std::fs::File;
use std::io::{BufReader, BufWriter};
use std::path::Path;
use std::collections::BTreeMap;
use std::fs::{self, File};
use std::io::BufWriter;
use std::path::{Path, PathBuf};
use std::sync::Mutex;

use once_cell::sync::Lazy;
use proc_macro2::Span;
use serde::de::{Deserializer, IgnoredAny, MapAccess, Visitor};
use sqlx_core::describe::Describe;

// Process-wide cache of parsed `sqlx-data.json` files, keyed by file path, so
// each data file is read and deserialized at most once instead of once per
// `query!()` invocation. Guarded by a `Mutex` because proc-macro expansions
// may run on multiple threads.
static OFFLINE_DATA_CACHE: Lazy<Mutex<BTreeMap<PathBuf, OfflineData>>> =
Lazy::new(|| Mutex::new(BTreeMap::new()));

// One stored entry of `sqlx-data.json`: the original query text plus its
// describe output, kept as raw JSON here.
#[derive(serde::Deserialize)]
struct BaseQuery {
// The SQL text as originally written; compared against the incoming query
// to detect hash collisions.
query: String,
// The stored describe data, left as an untyped `serde_json::Value`.
describe: serde_json::Value,
}

// The fully parsed contents of a `sqlx-data.json` file.
#[derive(serde::Deserialize)]
struct OfflineData {
// Name of the database the data was generated for (the top-level "db" key).
db: String,
// `flatten` collects every remaining top-level key; those keys are the
// hex-encoded SHA-256 hashes of queries, mapped to their stored data.
#[serde(flatten)]
hash_to_query: BTreeMap<String, BaseQuery>,
}

impl OfflineData {
    /// Looks up the stored entry for the given hex-encoded query hash,
    /// assembling an owned `DynQueryData` for it, or `None` if the hash
    /// is not present in this data file.
    fn get_query_from_hash(&self, hash: &str) -> Option<DynQueryData> {
        let base_query = self.hash_to_query.get(hash)?;
        Some(DynQueryData {
            db_name: self.db.clone(),
            query: base_query.query.clone(),
            describe: base_query.describe.clone(),
            hash: hash.to_owned(),
        })
    }
}

#[derive(serde::Deserialize)]
pub struct DynQueryData {
#[serde(skip)]
Expand All @@ -61,15 +89,44 @@ pub mod offline {
/// Find and deserialize the data table for this query from a shared `sqlx-data.json`
/// file. The expected structure is a JSON map keyed by the SHA-256 hash of queries in hex.
pub fn from_data_file(path: impl AsRef<Path>, query: &str) -> crate::Result<Self> {
let this = serde_json::Deserializer::from_reader(BufReader::new(
File::open(path.as_ref()).map_err(|e| {
format!("failed to open path {}: {}", path.as_ref().display(), e)
})?,
))
.deserialize_map(DataFileVisitor {
query,
hash: hash_string(query),
})?;
let path = path.as_ref();

let query_data = {
    let mut cache = OFFLINE_DATA_CACHE
        .lock()
        // A poisoned lock only means another expansion thread panicked while
        // holding it; the map may be partially populated, so just reset the
        // cache rather than propagating the panic.
        .unwrap_or_else(|poison_err| {
            let mut guard = poison_err.into_inner();
            *guard = BTreeMap::new();
            guard
        });

    // Parse the data file on first use; subsequent queries against the same
    // file hit the cache instead of re-reading and re-deserializing it.
    if !cache.contains_key(path) {
        let offline_data_contents = fs::read_to_string(path)
            .map_err(|e| format!("failed to read path {}: {}", path.display(), e))?;
        let offline_data: OfflineData = serde_json::from_str(&offline_data_contents)?;
        let _ = cache.insert(path.to_owned(), offline_data);
    }

    let offline_data = cache
        .get(path)
        .expect("Missing data should have just been added");

    let query_hash = hash_string(query);
    let query_data = offline_data
        .get_query_from_hash(&query_hash)
        .ok_or_else(|| format!("failed to find data for query {}", query))?;

    // Same hash but different query text: a genuine collision between the
    // incoming query and the stored one.
    if query != query_data.query {
        return Err(format!(
            "hash collision for stored queries:\n{:?}\n{:?}",
            query, query_data.query
        )
        .into());
    }

    query_data
};

#[cfg(procmacr2_semver_exempt)]
{
Expand All @@ -84,7 +141,7 @@ pub mod offline {
proc_macro::tracked_path::path(path);
}

Ok(this)
Ok(query_data)
}
}

Expand Down Expand Up @@ -138,67 +195,4 @@ pub mod offline {

hex::encode(Sha256::digest(query.as_bytes()))
}

// Serde visitor that lazily deserializes only the `QueryData` entry for the
// query we're looking for, skipping every other key in the data file.
struct DataFileVisitor<'a> {
// The SQL text being expanded, used to verify the stored entry matches.
query: &'a str,
// Hex-encoded SHA-256 hash of `query`; the map key we search for.
hash: String,
}

// Streams through the top-level JSON map, deserializing only the entry whose
// key equals our query hash and ignoring everything else. Relies on the "db"
// key appearing before any hash keys.
impl<'de> Visitor<'de> for DataFileVisitor<'_> {
type Value = DynQueryData;

fn expecting(&self, f: &mut Formatter) -> fmt::Result {
write!(f, "expected map key {:?} or \"db\"", self.hash)
}

fn visit_map<A>(self, mut map: A) -> Result<Self::Value, <A as MapAccess<'de>>::Error>
where
A: MapAccess<'de>,
{
// Captured from the "db" key; must be seen before the matching hash key.
let mut db_name: Option<String> = None;

// Walk keys in order until we hit our hash; running out of keys first
// means the data file has no entry for this query.
let query_data = loop {
// unfortunately we can't avoid this copy because deserializing from `io::Read`
// doesn't support deserializing borrowed values
let key = map.next_key::<String>()?.ok_or_else(|| {
serde::de::Error::custom(format_args!(
"failed to find data for query {}",
self.hash
))
})?;

// lazily deserialize the query data only
if key == "db" {
db_name = Some(map.next_value::<String>()?);
} else if key == self.hash {
let db_name = db_name.ok_or_else(|| {
serde::de::Error::custom("expected \"db\" key before query hash keys")
})?;

let mut query_data: DynQueryData = map.next_value()?;

// Same hash but different query text means a real SHA-256 collision.
if query_data.query == self.query {
query_data.db_name = db_name;
query_data.hash = self.hash.clone();
break query_data;
} else {
return Err(serde::de::Error::custom(format_args!(
"hash collision for stored queries:\n{:?}\n{:?}",
self.query, query_data.query
)));
};
} else {
// we don't care about entries that don't match our hash
let _ = map.next_value::<IgnoredAny>()?;
}
};

// Serde expects us to consume the whole map; fortunately they've got a convenient
// type to let us do just that
while let Some(_) = map.next_entry::<IgnoredAny, IgnoredAny>()? {}

Ok(query_data)
}
}
}