Skip to content

Implement RFC 7231 compliant relative URI and fragment handling in redirects #13050

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Apr 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions crates/uv-client/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,4 @@ hyper = { version = "1.4.1", features = ["server", "http1"] }
hyper-util = { version = "0.1.8", features = ["tokio"] }
insta = { version = "1.40.0", features = ["filters", "json", "redactions"] }
tokio = { workspace = true }
wiremock = { workspace = true }
267 changes: 258 additions & 9 deletions crates/uv-client/src/base_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,17 @@ use std::sync::Arc;
use std::time::Duration;
use std::{env, iter};

use anyhow::anyhow;
use http::{HeaderMap, HeaderName, HeaderValue, StatusCode};
use itertools::Itertools;
use reqwest::{Client, ClientBuilder, Proxy, Response};
use reqwest::{multipart, Client, ClientBuilder, IntoUrl, Proxy, Request, Response};
use reqwest_middleware::{ClientWithMiddleware, Middleware};
use reqwest_retry::policies::ExponentialBackoff;
use reqwest_retry::{
DefaultRetryableStrategy, RetryTransientMiddleware, Retryable, RetryableStrategy,
};
use tracing::{debug, trace};
use url::ParseError;
use url::Url;

use uv_auth::{AuthMiddleware, UrlAuthPolicies};
Expand Down Expand Up @@ -60,6 +63,24 @@ pub struct BaseClientBuilder<'a> {
default_timeout: Duration,
extra_middleware: Option<ExtraMiddleware>,
proxies: Vec<Proxy>,
redirect_policy: RedirectPolicy,
}

/// The policy for handling redirects.
///
/// The variant selects both the redirect policy installed on the underlying `reqwest` client
/// (see `reqwest_policy`) and whether `RedirectClientWithMiddleware::execute` follows redirects
/// itself.
#[derive(Debug, Default, Clone, Copy)]
pub enum RedirectPolicy {
/// Use reqwest's default redirect policy; `RedirectClientWithMiddleware::execute` passes the
/// request straight to the wrapped client and does no redirect handling of its own.
#[default]
BypassMiddleware,
/// Disable reqwest's own redirect following (`Policy::none()`); redirects are instead followed
/// by `RedirectClientWithMiddleware`, which sends each redirected request through the
/// middleware pipeline again.
RetriggerMiddleware,
}

impl RedirectPolicy {
    /// Translates this policy into the redirect policy to install on a `reqwest` client.
    ///
    /// `BypassMiddleware` keeps reqwest's default redirect policy, while `RetriggerMiddleware`
    /// turns reqwest's redirect following off entirely.
    pub fn reqwest_policy(self) -> reqwest::redirect::Policy {
        use reqwest::redirect::Policy;

        match self {
            Self::BypassMiddleware => Policy::default(),
            Self::RetriggerMiddleware => Policy::none(),
        }
    }
}

/// A list of user-defined middlewares to be applied to the client.
Expand Down Expand Up @@ -95,6 +116,7 @@ impl BaseClientBuilder<'_> {
default_timeout: Duration::from_secs(30),
extra_middleware: None,
proxies: vec![],
redirect_policy: RedirectPolicy::default(),
}
}
}
Expand Down Expand Up @@ -172,6 +194,12 @@ impl<'a> BaseClientBuilder<'a> {
self
}

#[must_use]
pub fn redirect(mut self, policy: RedirectPolicy) -> Self {
self.redirect_policy = policy;
self
}

/// Returns `true` if this builder's connectivity is `Connectivity::Offline`.
pub fn is_offline(&self) -> bool {
matches!(self.connectivity, Connectivity::Offline)
}
Expand Down Expand Up @@ -228,6 +256,7 @@ impl<'a> BaseClientBuilder<'a> {
timeout,
ssl_cert_file_exists,
Security::Secure,
self.redirect_policy,
);

// Create an insecure client that accepts invalid certificates.
Expand All @@ -236,11 +265,18 @@ impl<'a> BaseClientBuilder<'a> {
timeout,
ssl_cert_file_exists,
Security::Insecure,
self.redirect_policy,
);

// Wrap in any relevant middleware and handle connectivity.
let client = self.apply_middleware(raw_client.clone());
let dangerous_client = self.apply_middleware(raw_dangerous_client.clone());
let client = RedirectClientWithMiddleware {
client: self.apply_middleware(raw_client.clone()),
redirect_policy: self.redirect_policy,
};
let dangerous_client = RedirectClientWithMiddleware {
client: self.apply_middleware(raw_dangerous_client.clone()),
redirect_policy: self.redirect_policy,
};

BaseClient {
connectivity: self.connectivity,
Expand All @@ -257,8 +293,14 @@ impl<'a> BaseClientBuilder<'a> {
/// Share the underlying client between two different middleware configurations.
pub fn wrap_existing(&self, existing: &BaseClient) -> BaseClient {
// Wrap in any relevant middleware and handle connectivity.
let client = self.apply_middleware(existing.raw_client.clone());
let dangerous_client = self.apply_middleware(existing.raw_dangerous_client.clone());
let client = RedirectClientWithMiddleware {
client: self.apply_middleware(existing.raw_client.clone()),
redirect_policy: self.redirect_policy,
};
let dangerous_client = RedirectClientWithMiddleware {
client: self.apply_middleware(existing.raw_dangerous_client.clone()),
redirect_policy: self.redirect_policy,
};

BaseClient {
connectivity: self.connectivity,
Expand All @@ -278,14 +320,16 @@ impl<'a> BaseClientBuilder<'a> {
timeout: Duration,
ssl_cert_file_exists: bool,
security: Security,
redirect_policy: RedirectPolicy,
) -> Client {
// Configure the builder.
let client_builder = ClientBuilder::new()
.http1_title_case_headers()
.user_agent(user_agent)
.pool_max_idle_per_host(20)
.read_timeout(timeout)
.tls_built_in_root_certs(false);
.tls_built_in_root_certs(false)
.redirect(redirect_policy.reqwest_policy());

// If necessary, accept invalid certificates.
let client_builder = match security {
Expand Down Expand Up @@ -382,9 +426,9 @@ impl<'a> BaseClientBuilder<'a> {
#[derive(Debug, Clone)]
pub struct BaseClient {
/// The underlying HTTP client that enforces valid certificates.
client: ClientWithMiddleware,
client: RedirectClientWithMiddleware,
/// The underlying HTTP client that accepts invalid certificates.
dangerous_client: ClientWithMiddleware,
dangerous_client: RedirectClientWithMiddleware,
/// The HTTP client without middleware.
raw_client: Client,
/// The HTTP client that accepts invalid certificates without middleware.
Expand All @@ -409,14 +453,20 @@ enum Security {

impl BaseClient {
/// Selects the appropriate client based on the host's trustworthiness.
pub fn for_host(&self, url: &Url) -> &ClientWithMiddleware {
pub fn for_host(&self, url: &Url) -> &RedirectClientWithMiddleware {
if self.disable_ssl(url) {
&self.dangerous_client
} else {
&self.client
}
}

/// Executes a request, applying redirect policy.
///
/// The client is chosen with [`BaseClient::for_host`] from the request's own URL, and the
/// request is then handed to that client's `execute`.
pub async fn execute(&self, req: Request) -> reqwest_middleware::Result<Response> {
    self.for_host(req.url()).execute(req).await
}

/// Returns `true` if the host is trusted to use the insecure client.
pub fn disable_ssl(&self, url: &Url) -> bool {
self.allow_insecure_host
Expand All @@ -440,6 +490,205 @@ impl BaseClient {
}
}

/// Wrapper around [`ClientWithMiddleware`] that manages redirects.
#[derive(Debug, Clone)]
pub struct RedirectClientWithMiddleware {
/// The wrapped middleware-enabled client that actually performs requests.
client: ClientWithMiddleware,
/// Decides whether `execute` delegates directly to `client` or follows redirects itself.
redirect_policy: RedirectPolicy,
}

impl RedirectClientWithMiddleware {
/// Convenience method to make a `GET` request to a URL.
pub fn get<U: IntoUrl>(&self, url: U) -> RequestBuilder {
RequestBuilder::new(self.client.get(url), self)
}

/// Convenience method to make a `POST` request to a URL.
pub fn post<U: IntoUrl>(&self, url: U) -> RequestBuilder {
RequestBuilder::new(self.client.post(url), self)
}

/// Convenience method to make a `HEAD` request to a URL.
pub fn head<U: IntoUrl>(&self, url: U) -> RequestBuilder {
RequestBuilder::new(self.client.head(url), self)
}

/// Executes a request, applying the redirect policy.
pub async fn execute(&self, req: Request) -> reqwest_middleware::Result<Response> {
match self.redirect_policy {
RedirectPolicy::BypassMiddleware => self.client.execute(req).await,
RedirectPolicy::RetriggerMiddleware => self.execute_with_redirect_handling(req).await,
}
}

/// Executes a request. If the response is a redirect (one of HTTP 301, 302, 307, or 308), the
/// request is executed again with the redirect location URL (up to a maximum number of
/// redirects).
///
/// Unlike the built-in reqwest redirect policies, this sends the redirect request through the
/// entire middleware pipeline again.
///
/// See RFC 7231 7.1.2 <https://www.rfc-editor.org/rfc/rfc7231#section-7.1.2> for details on
/// redirect semantics.
async fn execute_with_redirect_handling(
&self,
req: Request,
) -> reqwest_middleware::Result<Response> {
let mut request = req;
let mut redirects = 0;
// This is the default used by reqwest.
let max_redirects = 10;

loop {
let request_url = request.url().clone();
let result = self
.client
.execute(request.try_clone().expect("HTTP request must be cloneable"))
.await;
if redirects == max_redirects {
return result;
}
let Ok(response) = result else {
return result;
};

// Handle redirect if we receive a 301, 302, 307, or 308.
let status = response.status();
if matches!(
status,
StatusCode::MOVED_PERMANENTLY
| StatusCode::FOUND
| StatusCode::TEMPORARY_REDIRECT
| StatusCode::PERMANENT_REDIRECT
) {
let location = response
.headers()
.get("location")
.ok_or(reqwest_middleware::Error::Middleware(anyhow!(
"Missing expected HTTP {status} 'Location' header"
)))?
.to_str()
.map_err(|_| {
reqwest_middleware::Error::Middleware(anyhow!(
"Invalid HTTP {status} 'Location' value: must only contain visible ascii characters"
))
})?;

let mut redirect_url = match Url::parse(location) {
Ok(url) => url,
// Per RFC 7231, URLs should be resolved against the request URL.
Err(ParseError::RelativeUrlWithoutBase) => request_url.join(location).map_err(|err| {
reqwest_middleware::Error::Middleware(anyhow!(
"Invalid HTTP {status} 'Location' value `{location}` relative to `{request_url}`: {err}"
))
})?,
Err(err) => {
return Err(reqwest_middleware::Error::Middleware(anyhow!(
"Invalid HTTP {status} 'Location' value `{location}`: {err}"
)));
}
};

// Ensure the URL is a valid HTTP URI.
if let Err(err) = redirect_url.as_str().parse::<http::Uri>() {
return Err(reqwest_middleware::Error::Middleware(anyhow!(
"Invalid HTTP {status} 'Location' value `{location}`: {err}"
)));
}

// Per RFC 7231, fragments must be propagated
if let Some(fragment) = request_url.fragment() {
redirect_url.set_fragment(Some(fragment));
}

debug!("Received HTTP {status} to {redirect_url}");
*request.url_mut() = redirect_url;
redirects += 1;
continue;
}

return Ok(response);
}
}

pub fn raw_client(&self) -> &ClientWithMiddleware {
&self.client
}
}

impl From<RedirectClientWithMiddleware> for ClientWithMiddleware {
fn from(item: RedirectClientWithMiddleware) -> ClientWithMiddleware {
item.client
}
}

/// A builder to construct the properties of a `Request`.
///
/// This wraps [`reqwest_middleware::RequestBuilder`] to ensure that the [`BaseClient`]
/// redirect policy is respected if `send()` is called.
#[derive(Debug)]
#[must_use]
pub struct RequestBuilder<'a> {
/// The wrapped builder that accumulates the request's properties.
builder: reqwest_middleware::RequestBuilder,
/// The client whose `execute` is used by `send()`, so its redirect policy applies.
client: &'a RedirectClientWithMiddleware,
}

impl<'a> RequestBuilder<'a> {
pub fn new(
builder: reqwest_middleware::RequestBuilder,
client: &'a RedirectClientWithMiddleware,
) -> Self {
Self { builder, client }
}

/// Add a `Header` to this Request.
pub fn header<K, V>(mut self, key: K, value: V) -> Self
where
HeaderName: TryFrom<K>,
<HeaderName as TryFrom<K>>::Error: Into<http::Error>,
HeaderValue: TryFrom<V>,
<HeaderValue as TryFrom<V>>::Error: Into<http::Error>,
{
self.builder = self.builder.header(key, value);
self
}

/// Add a set of Headers to the existing ones on this Request.
///
/// The headers will be merged in to any already set.
pub fn headers(mut self, headers: HeaderMap) -> Self {
self.builder = self.builder.headers(headers);
self
}

#[cfg(not(target_arch = "wasm32"))]
pub fn version(mut self, version: reqwest::Version) -> Self {
self.builder = self.builder.version(version);
self
}

#[cfg_attr(docsrs, doc(cfg(feature = "multipart")))]
pub fn multipart(mut self, multipart: multipart::Form) -> Self {
self.builder = self.builder.multipart(multipart);
self
}

/// Build a `Request`.
pub fn build(self) -> reqwest::Result<Request> {
self.builder.build()
}

/// Constructs the Request and sends it to the target URL, returning a
/// future Response.
pub async fn send(self) -> reqwest_middleware::Result<Response> {
self.client.execute(self.build()?).await
}

pub fn raw_builder(&self) -> &reqwest_middleware::RequestBuilder {
&self.builder
}
}

/// Extends [`DefaultRetryableStrategy`] to log transient request failures and additional retry cases.
pub struct UvRetryableStrategy;

Expand Down
2 changes: 0 additions & 2 deletions crates/uv-client/src/cached_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -510,7 +510,6 @@ impl CachedClient {
debug!("Sending revalidation request for: {url}");
let response = self
.0
.for_host(req.url())
.execute(req)
.instrument(info_span!("revalidation_request", url = url.as_str()))
.await
Expand Down Expand Up @@ -551,7 +550,6 @@ impl CachedClient {
let cache_policy_builder = CachePolicyBuilder::new(&req);
let response = self
.0
.for_host(&url)
.execute(req)
.await
.map_err(|err| ErrorKind::from_reqwest_middleware(url.clone(), err))?
Expand Down
Loading
Loading