diff --git a/src/lib.rs b/src/lib.rs index d65772983..fd5fd4912 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -130,12 +130,13 @@ use std::mem; use std::net::{ToSocketAddrs, IpAddr}; use std::ops::{Range, RangeFrom, RangeTo}; use std::path::{Path, PathBuf}; +use std::rc::Rc; use std::str; pub use origin::{Origin, OpaqueOrigin}; pub use host::{Host, HostAndPort, SocketAddrs}; pub use path_segments::PathSegmentsMut; -pub use parser::ParseError; +pub use parser::{ParseError, SyntaxViolation}; pub use slicing::Position; mod encoding; @@ -182,11 +183,11 @@ impl HeapSizeOf for Url { } /// Full configuration for the URL parser. -#[derive(Copy, Clone)] +#[derive(Clone)] pub struct ParseOptions<'a> { base_url: Option<&'a Url>, encoding_override: encoding::EncodingOverride, - log_syntax_violation: Option<&'a Fn(&'static str)>, + syntax_violation_callback: Option>, } impl<'a> ParseOptions<'a> { @@ -209,19 +210,61 @@ impl<'a> ParseOptions<'a> { self } - /// Call the provided function or closure on non-fatal parse errors. + /// Call the provided function or closure on non-fatal parse errors, passing + /// a static string description. This method is deprecated in favor of + /// `syntax_violation_callback` and is implemented as an adaptor for the + /// latter, passing the `SyntaxViolation` description. Only the last value + /// passed to either method will be used by a parser. + #[deprecated] pub fn log_syntax_violation(mut self, new: Option<&'a Fn(&'static str)>) -> Self { - self.log_syntax_violation = new; + self.syntax_violation_callback = match new { + Some(f) => Some(Rc::new(move |v: SyntaxViolation| f(v.description()))), + None => None + }; + self + } + + /// Call the provided function or closure for a non-fatal `SyntaxViolation` + /// when it occurs during parsing. Note that since the provided function is + /// `Fn`, the caller might need to utilize _interior mutability_, such as with + /// a `RefCell`, to collect the violations. + /// + /// ## Example + /// ``` + /// use std::cell::RefCell; + /// use url::{Url, SyntaxViolation}; + /// # use url::ParseError; + /// # fn run() -> Result<(), url::ParseError> { + /// let violations = RefCell::new(Vec::new()); + /// let url = Url::options(). + /// syntax_violation_callback(Some(|v| { + /// violations.borrow_mut().push(v) + /// })). + /// parse("https:////example.com")?; + /// assert_eq!(url.as_str(), "https://example.com/"); + /// assert_eq!(vec!(SyntaxViolation::ExpectedDoubleSlash), + /// violations.into_inner()); + /// # Ok(()) + /// # } + /// # run().unwrap(); + /// ``` + pub fn syntax_violation_callback(mut self, new: Option) -> Self + where F: Fn(SyntaxViolation) + 'a + { + self.syntax_violation_callback = match new { + Some(f) => Some(Rc::new(f)), + None => None + }; self } /// Parse an URL string with the configuration so far. - pub fn parse(self, input: &str) -> Result { + pub fn parse(&self, input: &str) -> Result { Parser { serialization: String::with_capacity(input.len()), base_url: self.base_url, query_encoding_override: self.encoding_override, - log_syntax_violation: self.log_syntax_violation, + syntax_violation_callback: self.syntax_violation_callback.clone(), context: Context::UrlParser, }.parse_url(input) } @@ -229,11 +272,14 @@ impl<'a> ParseOptions<'a> { impl<'a> Debug for ParseOptions<'a> { fn fmt(&self, f: &mut Formatter) -> fmt::Result { - write!(f, "ParseOptions {{ base_url: {:?}, encoding_override: {:?}, log_syntax_violation: ", self.base_url, self.encoding_override)?; - match self.log_syntax_violation { - Some(_) => write!(f, "Some(Fn(&'static str)) }}"), - None => write!(f, "None }}") - } + write!(f, "ParseOptions {{ base_url: {:?}, encoding_override: {:?}, \ + syntax_violation_callback: {} }}", + self.base_url, + self.encoding_override, + match self.syntax_violation_callback { + Some(_) => "Some(Fn(SyntaxViolation))", + None => "None" + }) } } @@ -363,7 +409,7 @@ impl Url { ParseOptions { base_url: None, encoding_override: EncodingOverride::utf8(), - log_syntax_violation: None, + syntax_violation_callback: None, } } diff --git a/src/parser.rs b/src/parser.rs index b16ecb7f6..c3192f6d8 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -9,6 +9,7 @@ use std::ascii::AsciiExt; use std::error::Error; use std::fmt::{self, Formatter, Write}; +use std::rc::Rc; use std::str; use Url; @@ -70,6 +71,55 @@ impl From<::idna::uts46::Errors> for ParseError { fn from(_: ::idna::uts46::Errors) -> ParseError { ParseError::IdnaError } } +macro_rules! syntax_violation_enum { + ($($name: ident => $description: expr,)+) => { + /// Non-fatal syntax violations that can occur during parsing. + #[derive(PartialEq, Eq, Clone, Copy, Debug)] + pub enum SyntaxViolation { + $( + $name, + )+ + } + + impl SyntaxViolation { + pub fn description(&self) -> &'static str { + match *self { + $( + SyntaxViolation::$name => $description, + )+ + } + } + } + } +} + +syntax_violation_enum! { + Backslash => "backslash", + C0SpaceIgnored => + "leading or trailing control or space character are ignored in URLs", + EmbeddedCredentials => + "embedding authentication information (username or password) \ + in an URL is not recommended", + ExpectedDoubleSlash => "expected //", + ExpectedFileDoubleSlash => "expected // after file:", + FileWithHostandWindowsDriveLetter => + "file: with host and Windows drive letter", + NonUrlCodePoint => "non-URL code point", + NullInFragment => "NULL characters are ignored in URL fragment identifiers", + PercentDecode => "expected 2 hex digits after %", + TabOrNewlineIgnored => "tabs or newlines are ignored in URLs", + UnencodedAtSign => "unencoded @ sign in username or password", +} + +#[cfg(feature = "heapsize")] +known_heap_size!(0, SyntaxViolation); + +impl fmt::Display for SyntaxViolation { + fn fmt(&self, fmt: &mut Formatter) -> fmt::Result { + self.description().fmt(fmt) + } +} + #[derive(Copy, Clone)] pub enum SchemeType { File, @@ -115,15 +165,15 @@ impl<'i> Input<'i> { Input::with_log(input, None) } - pub fn with_log(original_input: &'i str, log_syntax_violation: Option<&Fn(&'static str)>) + pub fn with_log(original_input: &'i str, callback: Option>) -> Self { let input = original_input.trim_matches(c0_control_or_space); - if let Some(log) = log_syntax_violation { + if let Some(cb) = callback { if input.len() < original_input.len() { - log("leading or trailing control or space character are ignored in URLs") + cb(SyntaxViolation::C0SpaceIgnored) } if input.chars().any(|c| matches!(c, '\t' | '\n' | '\r')) { - log("tabs or newlines are ignored in URLs") + cb(SyntaxViolation::TabOrNewlineIgnored) } } Input { chars: input.chars() } @@ -220,7 +270,7 @@ pub struct Parser<'a> { pub serialization: String, pub base_url: Option<&'a Url>, pub query_encoding_override: EncodingOverride, - pub log_syntax_violation: Option<&'a Fn(&'static str)>, + pub syntax_violation_callback: Option>, pub context: Context, } @@ -237,29 +287,29 @@ impl<'a> Parser<'a> { serialization: serialization, base_url: None, query_encoding_override: EncodingOverride::utf8(), - log_syntax_violation: None, + syntax_violation_callback: None, context: Context::Setter, } } - fn syntax_violation(&self, reason: &'static str) { - if let Some(log) = self.log_syntax_violation { - log(reason) + fn syntax_violation(&self, v: SyntaxViolation) { + if let Some(ref cb) = self.syntax_violation_callback { + cb(v) } } - fn syntax_violation_if bool>(&self, reason: &'static str, test: F) { + fn syntax_violation_if bool>(&self, v: SyntaxViolation, test: F) { // Skip test if not logging. - if let Some(log) = self.log_syntax_violation { + if let Some(ref cb) = self.syntax_violation_callback { if test() { - log(reason) + cb(v) } } } /// https://url.spec.whatwg.org/#concept-basic-url-parser pub fn parse_url(mut self, input: &str) -> ParseResult { - let input = Input::with_log(input, self.log_syntax_violation); + let input = Input::with_log(input, self.syntax_violation_callback.clone()); if let Ok(remaining) = self.parse_scheme(input.clone()) { return self.parse_with_scheme(remaining) } @@ -310,12 +360,13 @@ impl<'a> Parser<'a> { } fn parse_with_scheme(mut self, input: Input) -> ParseResult { + use SyntaxViolation::{ExpectedFileDoubleSlash, ExpectedDoubleSlash}; let scheme_end = to_u32(self.serialization.len())?; let scheme_type = SchemeType::from(&self.serialization); self.serialization.push(':'); match scheme_type { SchemeType::File => { - self.syntax_violation_if("expected // after file:", || !input.starts_with("//")); + self.syntax_violation_if(ExpectedFileDoubleSlash, || !input.starts_with("//")); let base_file_url = self.base_url.and_then(|base| { if base.scheme() == "file" { Some(base) } else { None } }); @@ -335,7 +386,7 @@ impl<'a> Parser<'a> { } } // special authority slashes state - self.syntax_violation_if("expected //", || { + self.syntax_violation_if(ExpectedDoubleSlash, || { input.clone().take_while(|&c| matches!(c, '/' | '\\')) .collect::() != "//" }); @@ -371,6 +422,7 @@ impl<'a> Parser<'a> { } fn parse_file(mut self, input: Input, mut base_file_url: Option<&Url>) -> ParseResult { + use SyntaxViolation::Backslash; // file state debug_assert!(self.serialization.is_empty()); let (first_char, input_after_first_char) = input.split_first(); @@ -468,10 +520,10 @@ impl<'a> Parser<'a> { } } Some('/') | Some('\\') => { - self.syntax_violation_if("backslash", || first_char == Some('\\')); + self.syntax_violation_if(Backslash, || first_char == Some('\\')); // file slash state let (next_char, input_after_next_char) = input_after_first_char.split_first(); - self.syntax_violation_if("backslash", || next_char == Some('\\')); + self.syntax_violation_if(Backslash, || next_char == Some('\\')); if matches!(next_char, Some('/') | Some('\\')) { // file host state self.serialization.push_str("file://"); @@ -623,7 +675,7 @@ impl<'a> Parser<'a> { Some('/') | Some('\\') => { let (slashes_count, remaining) = input.count_matching(|c| matches!(c, '/' | '\\')); if slashes_count >= 2 { - self.syntax_violation_if("expected //", || { + self.syntax_violation_if(SyntaxViolation::ExpectedDoubleSlash, || { input.clone().take_while(|&c| matches!(c, '/' | '\\')) .collect::() != "//" }); @@ -687,11 +739,9 @@ impl<'a> Parser<'a> { match c { '@' => { if last_at.is_some() { - self.syntax_violation("unencoded @ sign in username or password") + self.syntax_violation(SyntaxViolation::UnencodedAtSign) } else { - self.syntax_violation( - "embedding authentication information (username or password) \ - in an URL is not recommended") + self.syntax_violation(SyntaxViolation::EmbeddedCredentials) } last_at = Some((char_count, remaining.clone())) }, @@ -889,7 +939,7 @@ impl<'a> Parser<'a> { match input.split_first() { (Some('/'), remaining) => input = remaining, (Some('\\'), remaining) => if scheme_type.is_special() { - self.syntax_violation("backslash"); + self.syntax_violation(SyntaxViolation::Backslash); input = remaining }, _ => {} @@ -917,7 +967,7 @@ impl<'a> Parser<'a> { }, '\\' if self.context != Context::PathSegmentSetter && scheme_type.is_special() => { - self.syntax_violation("backslash"); + self.syntax_violation(SyntaxViolation::Backslash); ends_with_slash = true; break }, @@ -958,7 +1008,7 @@ impl<'a> Parser<'a> { self.serialization.push(':'); } if *has_host { - self.syntax_violation("file: with host and Windows drive letter"); + self.syntax_violation(SyntaxViolation::FileWithHostandWindowsDriveLetter); *has_host = false; // FIXME account for this in callers } } @@ -1100,7 +1150,7 @@ impl<'a> Parser<'a> { pub fn parse_fragment(&mut self, mut input: Input) { while let Some((c, utf8_c)) = input.next_utf8() { if c == '\0' { - self.syntax_violation("NULL characters are ignored in URL fragment identifiers") + self.syntax_violation(SyntaxViolation::NullInFragment) } else { self.check_url_code_point(c, &input); self.serialization.extend(utf8_percent_encode(utf8_c, @@ -1110,15 +1160,15 @@ impl<'a> Parser<'a> { } fn check_url_code_point(&self, c: char, input: &Input) { - if let Some(log) = self.log_syntax_violation { + if let Some(ref cb) = self.syntax_violation_callback { if c == '%' { let mut input = input.clone(); if !matches!((input.next(), input.next()), (Some(a), Some(b)) if is_ascii_hex_digit(a) && is_ascii_hex_digit(b)) { - log("expected 2 hex digits after %") + cb(SyntaxViolation::PercentDecode) } } else if !is_url_code_point(c) { - log("non-URL code point") + cb(SyntaxViolation::NonUrlCodePoint) } } } diff --git a/tests/unit.rs b/tests/unit.rs index b76a1f80d..af3200af6 100644 --- a/tests/unit.rs +++ b/tests/unit.rs @@ -13,6 +13,7 @@ extern crate url; use std::ascii::AsciiExt; use std::borrow::Cow; +use std::cell::{Cell, RefCell}; use std::net::{Ipv4Addr, Ipv6Addr}; use std::path::{Path, PathBuf}; use url::{Host, HostAndPort, Url, form_urlencoded}; @@ -477,3 +478,72 @@ fn test_windows_unc_path() { let url = Url::from_file_path(Path::new(r"\\.\some\path\file.txt")); assert!(url.is_err()); } + +// Test the now deprecated log_syntax_violation method for backward +// compatibility +#[test] +#[allow(deprecated)] +fn test_old_log_violation_option() { + let violation = Cell::new(None); + let url = { + let vfn = |s: &str| violation.set(Some(s.to_owned())); + let options = Url::options().log_syntax_violation(Some(&vfn)); + options.parse("http:////mozilla.org:42").unwrap() + }; + assert_eq!(url.port(), Some(42)); + + let violation = violation.take(); + assert_eq!(violation, Some("expected //".to_string())); +} + +#[test] +fn test_syntax_violation_callback() { + use url::SyntaxViolation::*; + let violation = Cell::new(None); + let url = Url::options(). + syntax_violation_callback(Some(|v| violation.set(Some(v)))). + parse("http:////mozilla.org:42"). + unwrap(); + assert_eq!(url.port(), Some(42)); + + let v = violation.take().unwrap(); + assert_eq!(v, ExpectedDoubleSlash); + assert_eq!(v.description(), "expected //"); +} + +#[test] +fn test_syntax_violation_callback_lifetimes() { + use url::SyntaxViolation::*; + let violation = Cell::new(None); + let vfn = |s| violation.set(Some(s)); + + let url = Url::options().syntax_violation_callback(Some(&vfn)). + parse("http:////mozilla.org:42"). + unwrap(); + assert_eq!(url.port(), Some(42)); + assert_eq!(violation.take(), Some(ExpectedDoubleSlash)); + + let url = Url::options().syntax_violation_callback(Some(&vfn)). + parse("http://mozilla.org\\path"). + unwrap(); + assert_eq!(url.path(), "/path"); + assert_eq!(violation.take(), Some(Backslash)); +} + +#[test] +fn test_options_reuse() { + use url::SyntaxViolation::*; + let violations = RefCell::new(Vec::new()); + + let options = Url::options(). + syntax_violation_callback(Some(|v| { + violations.borrow_mut().push(v); + })); + let url = options.parse("http:////mozilla.org").unwrap(); + + let options = options.base_url(Some(&url)); + let url = options.parse("/sub\\path").unwrap(); + assert_eq!(url.as_str(), "http://mozilla.org/sub/path"); + assert_eq!(*violations.borrow(), + vec!(ExpectedDoubleSlash, Backslash)); +}