diff --git a/Cargo.lock b/Cargo.lock index f75b2f3..fee9e3d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -154,7 +154,7 @@ dependencies = [ [[package]] name = "balatro_mod_index" -version = "0.3.0" +version = "0.3.1" dependencies = [ "bytes", "cached", diff --git a/Cargo.toml b/Cargo.toml index 8b56a5b..788c036 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "balatro_mod_index" description = "a library for parsing a git lfs repo into a BMM-compatible index" -version = "0.3.0" +version = "0.3.1" edition = "2024" license = "GPL-3.0-only" exclude = [".gitignore", "flake.*"] diff --git a/src/lfs.rs b/src/lfs.rs index 20572a6..41b98e4 100644 --- a/src/lfs.rs +++ b/src/lfs.rs @@ -77,20 +77,15 @@ pub async fn mut_fetch_download_urls( concurrency_factor: usize, refresh_available: bool, ) -> Result<(), String> { - use futures::{StreamExt, stream}; use std::cmp::min; let tree = blobs.first().ok_or("no blobs to fetch")?.tree; - let pointers = if refresh_available { - blobs.iter().map(|b| &b.pointer).collect::>() - } else { - blobs - .iter() - .filter(|b| b.url.is_none()) - .map(|b| &b.pointer) - .collect::>() - }; + let pointers = blobs + .iter() + .filter(|b| refresh_available || b.url.is_none()) + .map(|b| &b.pointer) + .collect::>(); if pointers.is_empty() { log::debug!("no lfs info to fetch"); return Ok(()); @@ -111,8 +106,9 @@ pub async fn mut_fetch_download_urls( let batch = &pointers[offset..next]; - let future = async move { + tasks.push(async move { log::debug!("getting lfs object info at offset {offset}"); + let resp = client .post(format!( "https://{}/{}/{}.git/info/lfs/objects/batch", @@ -134,35 +130,31 @@ pub async fn mut_fetch_download_urls( .map_err(|e| format!("couldn't read raw response: {e}"))? .to_vec(); - let data: BatchResponse = serde_json::from_slice(&data) + let data = serde_json::from_slice::(&data) .map_err(|_| match String::from_utf8(data) { Ok(s) => format!("response was not json, but a string: {s}"), Err(e) => format!("response was not json, and not a valid utf-8 string: {e}"), }) .map_err(|e| format!("couldn't parse response: {e}"))?; - Ok(data - .objects - .into_iter() - .map(|obj| obj.actions.download.href) - .collect::>()) - }; - tasks.push(future); + Ok::<_, String>( + data.objects + .into_iter() + .map(|obj| obj.actions.download.href), + ) + }); offset = next; } - let download_urls = stream::iter(tasks) - .buffer_unordered(concurrency_factor) - .collect::, String>>>() + let download_urls = buffer_unordered(tasks, concurrency_factor) .await .into_iter() - .map(Result::ok) + .map(|r| r.map_err(|e| format!("couldn't fetch download urls: {e}"))) .try_fold(Vec::new(), |mut acc, result| { acc.extend(result?); - Some(acc) - }) - .ok_or("couldn't fetch download urls")?; + Ok::<_, String>(acc) + })?; for (blob, url) in blobs.iter_mut().zip(download_urls) { blob.url = Some(url); @@ -176,19 +168,52 @@ pub async fn mut_fetch_blobs( blobs: &mut [&mut Blob<'_>], client: &reqwest::Client, concurrency_factor: usize, -) -> Result<(), String> { - use futures::{StreamExt, stream}; - - stream::iter(blobs.iter_mut().filter_map(|b| { - b.url.as_ref().map(|url| async { - b.data = fetch_one(client, url, &b.pointer.oid).await; - }) - })) - .buffer_unordered(concurrency_factor) - .collect::>() +) -> () { + let thumbnails = buffer_unordered( + blobs.iter().filter_map(|b| { + b.url + .as_ref() + .map(|url| async move { fetch_one(client, url, &b.pointer.oid).await }) + }), + concurrency_factor, + ) .await; - Ok(()) + for (blob, data) in blobs.iter_mut().zip(thumbnails) { + blob.data = data; + } +} + +#[cfg(feature = "reqwest")] +async fn buffer_unordered(tasks: I, concurrency_factor: usize) -> Vec +where + Fut: std::future::Future, + I: IntoIterator, +{ + use futures::{StreamExt, stream::FuturesUnordered}; + + let mut tasks = tasks + .into_iter() + .enumerate() + .map(|(i, task)| async move { (i, task.await) }) + .collect::>(); + + let mut results = Vec::new(); + let mut futures = FuturesUnordered::new(); + while let Some(next_job) = tasks.pop() { + while futures.len() >= concurrency_factor { + if let Some(result) = futures.next().await { + results.push(result); + } + } + futures.push(next_job); + } + while let Some(result) = futures.next().await { + results.push(result); + } + + results.sort_by_key(|(i, _)| *i); + results.into_iter().map(|(_, result)| result).collect() } #[cfg(feature = "reqwest")] diff --git a/src/mods.rs b/src/mods.rs index ed75a4c..c32117e 100644 --- a/src/mods.rs +++ b/src/mods.rs @@ -189,7 +189,7 @@ impl ModIndex<'_> { .collect::>(); lfs::mut_fetch_download_urls(blobs, client, concurrency_factor, refresh_urls).await?; - lfs::mut_fetch_blobs(blobs, client, concurrency_factor).await?; + lfs::mut_fetch_blobs(blobs, client, concurrency_factor).await; Ok(next) }