diff --git a/url_serde/src/lib.rs b/url_serde/src/lib.rs index 3c8d110da..11f8c203f 100644 --- a/url_serde/src/lib.rs +++ b/url_serde/src/lib.rs @@ -61,13 +61,14 @@ extern crate serde; extern crate url; use serde::{Deserialize, Serialize, Serializer, Deserializer}; +use serde::de::{EnumAccess, Unexpected, VariantAccess, Visitor, self}; use std::cmp::PartialEq; use std::error::Error; use std::fmt; use std::io::Write; use std::ops::{Deref, DerefMut}; use std::str; -use url::{Url, Host}; +use url::{Url, Host, ParseError}; /// Serialises `value` with a given serializer. /// @@ -113,6 +114,34 @@ impl<'a> Serialize for Ser<'a, Option> { } } +/// Serializes this ParseError into a `serde` stream. +impl<'a> Serialize for Ser<'a, ParseError> { + fn serialize(&self, serializer: S) -> Result where S: Serializer { + match *self.0 { + ParseError::EmptyHost => + serializer.serialize_unit_variant("ParseError", 0, "EmptyHost"), + ParseError::IdnaError => + serializer.serialize_unit_variant("ParseError", 1, "IdnaError"), + ParseError::InvalidPort => + serializer.serialize_unit_variant("ParseError", 2, "InvalidPort"), + ParseError::InvalidIpv4Address => + serializer.serialize_unit_variant("ParseError", 3, "InvalidIpv4Address"), + ParseError::InvalidIpv6Address => + serializer.serialize_unit_variant("ParseError", 4, "InvalidIpv6Address"), + ParseError::InvalidDomainCharacter => + serializer.serialize_unit_variant("ParseError", 5, "InvalidDomainCharacter"), + ParseError::RelativeUrlWithoutBase => + serializer.serialize_unit_variant("ParseError", 6, "RelativeUrlWithoutBase"), + ParseError::RelativeUrlWithCannotBeABaseBase => + serializer.serialize_unit_variant( + "ParseError", 7, "RelativeUrlWithCannotBeABaseBase"), + ParseError::SetHostOnCannotBeABaseUrl => + serializer.serialize_unit_variant("ParseError", 8, "SetHostOnCannotBeABaseUrl"), + ParseError::Overflow => serializer.serialize_unit_variant("ParseError", 9, "Overflow"), + } + } +} + impl<'a, String> Serialize for Ser<'a, Host> where String: AsRef { fn serialize(&self, serializer: S) -> Result where S: Serializer { match *self.0 { @@ -202,6 +231,141 @@ impl<'de> Deserialize<'de> for De> { } } +/// Deserializes this ParseError from a `serde` stream. +impl<'de> Deserialize<'de> for De { + fn deserialize(deserializer: D) -> Result where D: Deserializer<'de> { + enum Variant { + EmptyHost, + IdnaError, + InvalidPort, + InvalidIpv4Address, + InvalidIpv6Address, + InvalidDomainCharacter, + RelativeUrlWithoutBase, + RelativeUrlWithCannotBeABaseBase, + SetHostOnCannotBeABaseUrl, + Overflow, + } + + struct VariantVisitor; + + impl<'de> Visitor<'de> for VariantVisitor { + type Value = Variant; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("variant identifier") + } + + fn visit_u32(self, v: u32) -> Result where E: de::Error { + match v { + 0 => Ok(Variant::EmptyHost), + 1 => Ok(Variant::IdnaError), + 2 => Ok(Variant::InvalidPort), + 3 => Ok(Variant::InvalidIpv4Address), + 4 => Ok(Variant::InvalidIpv6Address), + 5 => Ok(Variant::InvalidDomainCharacter), + 6 => Ok(Variant::RelativeUrlWithoutBase), + 7 => Ok(Variant::RelativeUrlWithCannotBeABaseBase), + 8 => Ok(Variant::SetHostOnCannotBeABaseUrl), + 9 => Ok(Variant::Overflow), + _ => Err(de::Error::invalid_value(Unexpected::Unsigned(v as u64), + &"variant index 0 <= i < 10")), + } + } + + fn visit_str(self, v: &str) -> Result where E: de::Error { + match v { + "EmptyHost" => Ok(Variant::EmptyHost), + "IdnaError" => Ok(Variant::IdnaError), + "InvalidPort" => Ok(Variant::InvalidPort), + "InvalidIpv4Address" => Ok(Variant::InvalidIpv4Address), + "InvalidIpv6Address" => Ok(Variant::InvalidIpv6Address), + "InvalidDomainCharacter" => Ok(Variant::InvalidDomainCharacter), + "RelativeUrlWithoutBase" => Ok(Variant::RelativeUrlWithoutBase), + "RelativeUrlWithCannotBeABaseBase" => Ok(Variant::RelativeUrlWithCannotBeABaseBase), + "SetHostOnCannotBeABaseUrl" => Ok(Variant::SetHostOnCannotBeABaseUrl), + "Overflow" => Ok(Variant::Overflow), + _ => Err(de::Error::unknown_variant(v, VARIANTS)), + } + } + + fn visit_bytes(self, v: &[u8]) -> Result where E: de::Error { + match v { + b"EmptyHost" => Ok(Variant::EmptyHost), + b"IdnaError" => Ok(Variant::IdnaError), + b"InvalidPort" => Ok(Variant::InvalidPort), + b"InvalidIpv4Address" => Ok(Variant::InvalidIpv4Address), + b"InvalidIpv6Address" => Ok(Variant::InvalidIpv6Address), + b"InvalidDomainCharacter" => Ok(Variant::InvalidDomainCharacter), + b"RelativeUrlWithoutBase" => Ok(Variant::RelativeUrlWithoutBase), + b"RelativeUrlWithCannotBeABaseBase" => Ok(Variant::RelativeUrlWithCannotBeABaseBase), + b"SetHostOnCannotBeABaseUrl" => Ok(Variant::SetHostOnCannotBeABaseUrl), + b"Overflow" => Ok(Variant::Overflow), + _ => { + let s = String::from_utf8_lossy(v); + Err(de::Error::unknown_variant(&s, VARIANTS)) + }, + } + } + } + + impl<'de> Deserialize<'de> for Variant { + fn deserialize(deserializer: D) -> Result where D: Deserializer<'de> { + deserializer.deserialize_identifier(VariantVisitor) + } + } + + struct ParseErrorVisitor; + + impl<'de> Visitor<'de> for ParseErrorVisitor { + type Value = ParseError; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("enum ParseError") + } + + fn visit_enum(self, data: A) -> Result where A: EnumAccess<'de> { + match data.variant()? { + (Variant::EmptyHost, variant) => + variant.unit_variant().map(|()| ParseError::EmptyHost), + (Variant::IdnaError, variant) => + variant.unit_variant().map(|()| ParseError::IdnaError), + (Variant::InvalidPort, variant) => + variant.unit_variant().map(|()| ParseError::InvalidPort), + (Variant::InvalidIpv4Address, variant) => + variant.unit_variant().map(|()| ParseError::InvalidIpv4Address), + (Variant::InvalidIpv6Address, variant) => + variant.unit_variant().map(|()| ParseError::InvalidIpv6Address), + (Variant::InvalidDomainCharacter, variant) => + variant.unit_variant().map(|()| ParseError::InvalidDomainCharacter), + (Variant::RelativeUrlWithoutBase, variant) => + variant.unit_variant().map(|()| ParseError::RelativeUrlWithoutBase), + (Variant::RelativeUrlWithCannotBeABaseBase, variant) => + variant.unit_variant().map(|()| ParseError::RelativeUrlWithCannotBeABaseBase), + (Variant::SetHostOnCannotBeABaseUrl, variant) => + variant.unit_variant().map(|()| ParseError::SetHostOnCannotBeABaseUrl), + (Variant::Overflow, variant) => + variant.unit_variant().map(|()| ParseError::Overflow), + } + } + } + + const VARIANTS: &'static [&'static str] = &[ + "EmptyHost", + "IdnaError", + "InvalidPort", + "InvalidIpv4Address", + "InvalidIpv6Address", + "InvalidDomainCharacter", + "RelativeUrlWithoutBase", + "RelativeUrlWithCannotBeABaseBase", + "SetHostOnCannotBeABaseUrl", + "Overflow" + ]; + deserializer.deserialize_enum("ParseError", VARIANTS, ParseErrorVisitor).map(De) + } +} + impl<'de> Deserialize<'de> for De { fn deserialize(deserializer: D) -> Result where D: Deserializer<'de> { let string_representation: String = Deserialize::deserialize(deserializer)?; @@ -396,6 +560,14 @@ fn test_derive_with_for_url() { assert_eq!(json_string, got); } +#[test] +fn test_parse_error() { + let err = ParseError::EmptyHost; + let json = serde_json::to_string(&Ser(&err)).unwrap(); + let de: De = serde_json::from_str(&json).unwrap(); + assert_eq!(de.into_inner(), ParseError::EmptyHost); +} + #[test] fn test_host() { for host in &[