diff --git a/Cargo.toml b/Cargo.toml index 5eab716..a2d74fe 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,8 +27,7 @@ ipnet = "2.3.1" log = "0.4.14" rand = "0.8.4" bytes = "1.1.0" -thiserror = "1.0.29" -anyhow = "1.0.44" +thiserror = "1.0" [target.'cfg(not(windows))'.dependencies] nix = "0.23" diff --git a/src/buffer/buffer_test.rs b/src/buffer/buffer_test.rs index 7cdb7fb..a6c494c 100644 --- a/src/buffer/buffer_test.rs +++ b/src/buffer/buffer_test.rs @@ -1,5 +1,5 @@ -use super::error::Error; use super::*; +use crate::error::Error; use tokio::time::{sleep, Duration}; use tokio_test::assert_ok; @@ -21,7 +21,7 @@ async fn test_buffer() { // Read deadline let result = buffer.read(&mut packet, Some(Duration::new(0, 1))).await; assert!(result.is_err()); - assert!(Error::ErrTimeout.equal(&result.unwrap_err())); + assert_eq!(Error::ErrTimeout, result.unwrap_err()); // Write twice let n = assert_ok!(buffer.write(&[2, 3, 4]).await); @@ -58,7 +58,7 @@ async fn test_buffer() { // Until EOF let result = buffer.read(&mut packet, None).await; assert!(result.is_err()); - assert!(Error::ErrBufferClosed.equal(&result.unwrap_err())); + assert_eq!(Error::ErrBufferClosed, result.unwrap_err()); } async fn test_wraparound(grow: bool) { @@ -138,7 +138,7 @@ async fn test_buffer_async() { let result = buffer2.read(&mut packet, None).await; assert!(result.is_err()); - assert!(Error::ErrBufferClosed.equal(&result.unwrap_err())); + assert_eq!(Error::ErrBufferClosed, result.unwrap_err()); drop(done_tx); }); @@ -178,7 +178,7 @@ async fn test_buffer_limit_count() { let result = buffer.write(&[4, 5]).await; assert!(result.is_err()); if let Err(err) = result { - assert!(Error::ErrBufferFull.equal(&err)); + assert_eq!(Error::ErrBufferFull, err); } assert_eq!(2, buffer.count().await); @@ -198,7 +198,7 @@ async fn test_buffer_limit_count() { let result = buffer.write(&[8, 9]).await; assert!(result.is_err()); if let Err(err) = result { - assert!(Error::ErrBufferFull.equal(&err)); + assert_eq!(Error::ErrBufferFull, err); } assert_eq!(2, buffer.count().await); @@ -236,7 +236,7 @@ async fn test_buffer_limit_size() { let result = buffer.write(&[4, 5]).await; assert!(result.is_err()); if let Err(err) = result { - assert!(Error::ErrBufferFull.equal(&err)); + assert_eq!(Error::ErrBufferFull, err); } assert_eq!(8, buffer.size().await); @@ -261,7 +261,7 @@ async fn test_buffer_limit_size() { let result = buffer.write(&[9, 10]).await; assert!(result.is_err()); if let Err(err) = result { - assert!(Error::ErrBufferFull.equal(&err)); + assert_eq!(Error::ErrBufferFull, err); } assert_eq!(11, buffer.size().await); @@ -320,7 +320,7 @@ async fn test_buffer_limit_sizes() { // Next write is expected to be errored. let result = buffer.write(&pkt).await; assert!(result.is_err(), "{}", name); - assert!(Error::ErrBufferFull.equal(&result.unwrap_err()), "{}", name); + assert_eq!(Error::ErrBufferFull, result.unwrap_err(), "{}", name); let mut packet = vec![0; size]; for _ in 0..n_packets { @@ -343,7 +343,7 @@ async fn test_buffer_misc() { let result = buffer.read(&mut packet, None).await; assert!(result.is_err()); if let Err(err) = result { - assert!(Error::ErrBufferShort.equal(&err)); + assert_eq!(Error::ErrBufferShort, err); } // Close diff --git a/src/buffer/error.rs b/src/buffer/error.rs deleted file mode 100644 index 9cd11d1..0000000 --- a/src/buffer/error.rs +++ /dev/null @@ -1,25 +0,0 @@ -use thiserror::Error; - -#[derive(Error, Debug, PartialEq)] -pub enum Error { - #[error("buffer: full")] - ErrBufferFull, - #[error("buffer: closed")] - ErrBufferClosed, - #[error("buffer: short")] - ErrBufferShort, - #[error("packet too big")] - ErrPacketTooBig, - #[error("i/o timeout")] - ErrTimeout, - - #[allow(non_camel_case_types)] - #[error("{0}")] - new(String), -} - -impl Error { - pub fn equal(&self, err: &anyhow::Error) -> bool { - err.downcast_ref::().map_or(false, |e| e == self) - } -} diff --git a/src/buffer/mod.rs b/src/buffer/mod.rs index f980a17..c75e77f 100644 --- a/src/buffer/mod.rs +++ b/src/buffer/mod.rs @@ -1,11 +1,8 @@ #[cfg(test)] mod buffer_test; -pub mod error; +use crate::error::{Error, Result}; -use error::Error; - -use anyhow::Result; use std::sync::Arc; use tokio::sync::{mpsc, Mutex}; use tokio::time::{timeout, Duration}; diff --git a/src/conn/conn_udp_listener.rs b/src/conn/conn_udp_listener.rs index 9910e3d..4bee41f 100644 --- a/src/conn/conn_udp_listener.rs +++ b/src/conn/conn_udp_listener.rs @@ -1,8 +1,7 @@ -use super::error::Error; use super::*; - +use crate::error::Error; use crate::Buffer; -use anyhow::Result; + use core::sync::atomic::Ordering; use std::collections::HashMap; use std::future::Future; @@ -239,7 +238,7 @@ impl Conn for UdpConn { } async fn recv(&self, buf: &mut [u8]) -> Result { - self.buffer.read(buf, None).await + Ok(self.buffer.read(buf, None).await?) } async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, SocketAddr)> { diff --git a/src/conn/conn_udp_listener_test.rs b/src/conn/conn_udp_listener_test.rs index 741e33a..1443133 100644 --- a/src/conn/conn_udp_listener_test.rs +++ b/src/conn/conn_udp_listener_test.rs @@ -1,6 +1,6 @@ use super::conn_udp_listener::*; -use super::error::Error; use super::*; +use crate::error::{Error, Result}; use std::future::Future; use std::pin::Pin; @@ -41,7 +41,7 @@ async fn pipe() -> Result<( let result = String::from_utf8(buf[..n].to_vec())?; if handshake != result { - Err(Error::new(format!("errHandshakeFailed: {} != {}", handshake, result)).into()) + Err(Error::Other(format!("errHandshakeFailed: {} != {}", handshake, result)).into()) } else { Ok((listener, l_conn, d_conn)) } @@ -118,7 +118,7 @@ async fn test_listener_accept_filter() -> Result<()> { let (c, _raddr) = match listener2.accept().await { Ok((c, raddr)) => (c, raddr), Err(err) => { - assert!(Error::ErrClosedListener.equal(&err)); + assert_eq!(Error::ErrClosedListener, err); return Result::<()>::Ok(()); } }; @@ -198,10 +198,7 @@ async fn test_listener_concurrent() -> Result<()> { conn.close().await?; } Err(err) => { - assert!( - Error::ErrClosedListener.equal(&err) - || Error::ErrClosedListenerAcceptCh.equal(&err) - ); + assert!(Error::ErrClosedListener == err || Error::ErrClosedListenerAcceptCh == err); } } diff --git a/src/conn/error.rs b/src/conn/error.rs deleted file mode 100644 index ce26356..0000000 --- a/src/conn/error.rs +++ /dev/null @@ -1,21 +0,0 @@ -use thiserror::Error; - -#[derive(Error, Debug, PartialEq)] -pub enum Error { - #[error("udp: listener closed")] - ErrClosedListener, - #[error("udp: listen queue exceeded")] - ErrListenQueueExceeded, - #[error("udp: listener accept ch closed")] - ErrClosedListenerAcceptCh, - - #[allow(non_camel_case_types)] - #[error("{0}")] - new(String), -} - -impl Error { - pub fn equal(&self, err: &anyhow::Error) -> bool { - err.downcast_ref::().map_or(false, |e| e == self) - } -} diff --git a/src/conn/mod.rs b/src/conn/mod.rs index 11d93fe..e831716 100644 --- a/src/conn/mod.rs +++ b/src/conn/mod.rs @@ -3,7 +3,6 @@ pub mod conn_disconnected_packet; pub mod conn_pipe; pub mod conn_udp; pub mod conn_udp_listener; -pub mod error; #[cfg(test)] mod conn_bridge_test; @@ -17,12 +16,13 @@ mod conn_test; #[cfg(test)] mod conn_udp_listener_test; -use anyhow::Result; use async_trait::async_trait; use std::net::SocketAddr; use std::sync::Arc; use tokio::net::ToSocketAddrs; +use crate::error::Result; + #[async_trait] pub trait Conn { async fn connect(&self, addr: SocketAddr) -> Result<()>; diff --git a/src/vnet/error.rs b/src/error.rs similarity index 53% rename from src/vnet/error.rs rename to src/error.rs index 775bda1..5b75ee0 100644 --- a/src/vnet/error.rs +++ b/src/error.rs @@ -1,7 +1,32 @@ +#![allow(dead_code)] + +use std::io; +use std::net; +use std::num::ParseIntError; +use std::string::FromUtf8Error; use thiserror::Error; +pub type Result = std::result::Result; + #[derive(Error, Debug, PartialEq)] +#[non_exhaustive] pub enum Error { + #[error("buffer: full")] + ErrBufferFull, + #[error("buffer: closed")] + ErrBufferClosed, + #[error("buffer: short")] + ErrBufferShort, + #[error("packet too big")] + ErrPacketTooBig, + #[error("i/o timeout")] + ErrTimeout, + #[error("udp: listener closed")] + ErrClosedListener, + #[error("udp: listen queue exceeded")] + ErrListenQueueExceeded, + #[error("udp: listener accept ch closed")] + ErrClosedListenerAcceptCh, #[error("obs cannot be nil")] ErrObsCannotBeNil, #[error("se of closed network connection")] @@ -80,14 +105,70 @@ pub enum Error { ErrNoIpaddrEth0, #[error("Invalid mask")] ErrInvalidMask, - - #[allow(non_camel_case_types)] + #[error("parse ipnet: {0}")] + ParseIpnet(#[from] ipnet::AddrParseError), + #[error("parse ip: {0}")] + ParseIp(#[from] net::AddrParseError), + #[error("parse int: {0}")] + ParseInt(#[from] ParseIntError), + #[error("{0}")] + Io(#[source] IoError), + #[error("utf8: {0}")] + Utf8(#[from] FromUtf8Error), + #[error("{0}")] + Std(#[source] StdError), #[error("{0}")] - new(String), + Other(String), } impl Error { - pub fn equal(&self, err: &anyhow::Error) -> bool { - err.downcast_ref::().map_or(false, |e| e == self) + pub fn from_std(error: T) -> Self + where + T: std::error::Error + Send + Sync + 'static, + { + Error::Std(StdError(Box::new(error))) + } + + pub fn downcast_ref(&self) -> Option<&T> { + if let Error::Std(s) = self { + return s.0.downcast_ref(); + } + + None + } +} + +#[derive(Debug, Error)] +#[error("io error: {0}")] +pub struct IoError(#[from] pub io::Error); + +// Workaround for wanting PartialEq for io::Error. +impl PartialEq for IoError { + fn eq(&self, other: &Self) -> bool { + self.0.kind() == other.0.kind() + } +} + +impl From for Error { + fn from(e: io::Error) -> Self { + Error::Io(IoError(e)) + } +} + +/// An escape hatch to preserve stack traces when we don't know the error. +/// +/// This crate exports some traits such as `Conn` and `Listener`. The trait functions +/// produce the local error `util::Error`. However when used in crates higher up the stack, +/// we are forced to handle errors that are local to that crate. For example we use +/// `Listener` the `dtls` crate and it needs to handle `dtls::Error`. +/// +/// By using `util::Error::from_std` we can preserve the underlying error (and stack trace!). +#[derive(Debug, Error)] +#[error("{0}")] +pub struct StdError(pub Box); + +impl PartialEq for StdError { + fn eq(&self, _: &Self) -> bool { + false } } diff --git a/src/lib.rs b/src/lib.rs index 9533295..9dd33b4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,8 +1,10 @@ #![warn(rust_2018_idioms)] #![allow(dead_code)] -use anyhow::Result; use async_trait::async_trait; +use thiserror::Error; + +use std::io; #[cfg(feature = "vnet")] #[macro_use] @@ -15,7 +17,10 @@ extern crate bitflags; pub mod fixed_big_int; pub mod replay_detector; -/// KeyingMaterialExporter to extract keying material +/// KeyingMaterialExporter to extract keying material. +/// +/// This trait sits here to avoid getting a direct dependency between +/// the dtls and srtp crates. #[async_trait] pub trait KeyingMaterialExporter { async fn export_keying_material( @@ -23,7 +28,35 @@ pub trait KeyingMaterialExporter { label: &str, context: &[u8], length: usize, - ) -> Result>; + ) -> Result, KeyingMaterialExporterError>; +} + +/// Possible errors while exporting keying material. +/// +/// These errors might have been more logically kept in the dtls +/// crate, but that would have required a direct depdency between +/// srtp and dtls. +#[derive(Debug, Error, PartialEq)] +#[non_exhaustive] +pub enum KeyingMaterialExporterError { + #[error("tls handshake is in progress")] + HandshakeInProgress, + #[error("context is not supported for export_keying_material")] + ContextUnsupported, + #[error("export_keying_material can not be used with a reserved label")] + ReservedExportKeyingMaterial, + #[error("no cipher suite for export_keying_material")] + CipherSuiteUnset, + #[error("export_keying_material io: {0}")] + Io(#[source] error::IoError), + #[error("export_keying_material hash: {0}")] + Hash(String), +} + +impl From for KeyingMaterialExporterError { + fn from(e: io::Error) -> Self { + KeyingMaterialExporterError::Io(error::IoError(e)) + } } #[cfg(feature = "buffer")] @@ -49,3 +82,7 @@ pub use crate::conn::Conn; #[cfg(feature = "marshal")] pub use crate::marshal::{exact_size_buf::ExactSizeBuf, Marshal, MarshalSize, Unmarshal}; + +mod error; + +pub use error::Error; diff --git a/src/marshal/error.rs b/src/marshal/error.rs deleted file mode 100644 index 575c44f..0000000 --- a/src/marshal/error.rs +++ /dev/null @@ -1,14 +0,0 @@ -use thiserror::Error; - -#[derive(Error, Debug, PartialEq)] -pub enum Error { - #[allow(non_camel_case_types)] - #[error("{0}")] - new(String), -} - -impl Error { - pub fn equal(&self, err: &anyhow::Error) -> bool { - err.downcast_ref::().map_or(false, |e| e == self) - } -} diff --git a/src/marshal/mod.rs b/src/marshal/mod.rs index 3a9c6b9..4f568bf 100644 --- a/src/marshal/mod.rs +++ b/src/marshal/mod.rs @@ -1,9 +1,8 @@ -pub mod error; pub mod exact_size_buf; -use anyhow::Result; use bytes::{Buf, Bytes, BytesMut}; -use error::Error; + +use crate::error::{Error, Result}; pub trait MarshalSize { fn marshal_size(&self) -> usize; @@ -18,7 +17,7 @@ pub trait Marshal: MarshalSize { buf.resize(l, 0); let n = self.marshal_to(&mut buf)?; if n != l { - Err(Error::new(format!("marshal_to output size {}, but expect {}", n, l)).into()) + Err(Error::Other(format!("marshal_to output size {}, but expect {}", n, l)).into()) } else { Ok(buf.freeze()) } diff --git a/src/vnet/chunk.rs b/src/vnet/chunk.rs index 1e8f4e5..4868925 100644 --- a/src/vnet/chunk.rs +++ b/src/vnet/chunk.rs @@ -2,8 +2,8 @@ mod chunk_test; use super::net::*; +use crate::error::Result; -use anyhow::Result; use std::fmt; use std::net::{IpAddr, SocketAddr}; use std::ops::{BitAnd, BitOr}; diff --git a/src/vnet/chunk/chunk_test.rs b/src/vnet/chunk/chunk_test.rs index 891336e..20e6bf0 100644 --- a/src/vnet/chunk/chunk_test.rs +++ b/src/vnet/chunk/chunk_test.rs @@ -1,3 +1,5 @@ +use crate::error::Result; + use super::*; #[test] diff --git a/src/vnet/chunk_queue/chunk_queue_test.rs b/src/vnet/chunk_queue/chunk_queue_test.rs index 510098f..59c1265 100644 --- a/src/vnet/chunk_queue/chunk_queue_test.rs +++ b/src/vnet/chunk_queue/chunk_queue_test.rs @@ -1,6 +1,7 @@ +use crate::error::Result; + use super::*; -use anyhow::Result; use std::net::SocketAddr; use std::str::FromStr; diff --git a/src/vnet/conn.rs b/src/vnet/conn.rs index 6184299..beaa762 100644 --- a/src/vnet/conn.rs +++ b/src/vnet/conn.rs @@ -1,11 +1,10 @@ #[cfg(test)] mod conn_test; -use super::error::*; use crate::conn::Conn; +use crate::error::*; use crate::vnet::chunk::{Chunk, ChunkUdp}; -use anyhow::Result; use std::net::{IpAddr, SocketAddr}; use tokio::sync::{mpsc, Mutex}; diff --git a/src/vnet/conn/conn_test.rs b/src/vnet/conn/conn_test.rs index 769cbfc..35a464d 100644 --- a/src/vnet/conn/conn_test.rs +++ b/src/vnet/conn/conn_test.rs @@ -16,7 +16,9 @@ impl ConnObserver for DummyObserver { let read_ch_tx = self.read_ch_tx.lock().await; if let Some(tx) = &*read_ch_tx { - tx.send(Box::new(chunk)).await?; + tx.send(Box::new(chunk)) + .await + .map_err(|e| Error::Other(e.to_string()))?; } Ok(()) } diff --git a/src/vnet/conn_map.rs b/src/vnet/conn_map.rs index 41094fc..01c074c 100644 --- a/src/vnet/conn_map.rs +++ b/src/vnet/conn_map.rs @@ -1,11 +1,10 @@ #[cfg(test)] mod conn_map_test; -use super::error::*; +use crate::error::*; use crate::vnet::conn::UdpConn; use crate::Conn; -use anyhow::Result; use std::collections::HashMap; use std::net::SocketAddr; use std::sync::Arc; diff --git a/src/vnet/interface.rs b/src/vnet/interface.rs index 5ccf466..a66ceef 100644 --- a/src/vnet/interface.rs +++ b/src/vnet/interface.rs @@ -1,5 +1,4 @@ -use super::error::*; -use anyhow::Result; +use crate::error::*; use ipnet::*; use std::net::SocketAddr; use std::str::FromStr; diff --git a/src/vnet/mod.rs b/src/vnet/mod.rs index bef6e92..41b7c45 100644 --- a/src/vnet/mod.rs +++ b/src/vnet/mod.rs @@ -2,7 +2,6 @@ pub mod chunk; pub(crate) mod chunk_queue; pub(crate) mod conn; pub(crate) mod conn_map; -pub mod error; pub mod interface; pub mod nat; pub mod net; diff --git a/src/vnet/nat.rs b/src/vnet/nat.rs index f5689ae..d1a3181 100644 --- a/src/vnet/nat.rs +++ b/src/vnet/nat.rs @@ -1,11 +1,10 @@ #[cfg(test)] mod nat_test; -use super::error::*; +use crate::error::*; use crate::vnet::chunk::Chunk; use crate::vnet::net::UDP_STR; -use anyhow::Result; use std::collections::{HashMap, HashSet}; use std::net::IpAddr; use std::ops::Add; @@ -309,7 +308,7 @@ impl NetworkAddressTranslator { let dst_port = from.destination_addr().port(); to.set_destination_addr(&format!("{}:{}", dst_ip, dst_port))?; } else { - return Err(Error::new(format!( + return Err(Error::Other(format!( "drop {} as {:?}", from, Error::ErrNoAssociatedLocalAddress @@ -333,7 +332,7 @@ impl NetworkAddressTranslator { { let filters = m.filters.lock().await; if !filters.contains(&filter_key) { - return Err(Error::new(format!( + return Err(Error::Other(format!( "drop {} as the remote {} {:?}", from, filter_key, @@ -353,7 +352,7 @@ impl NetworkAddressTranslator { to.set_destination_addr(&m.local)?; } else { - return Err(Error::new(format!( + return Err(Error::Other(format!( "drop {} as {:?}", from, Error::ErrNoNatBindingFound diff --git a/src/vnet/net.rs b/src/vnet/net.rs index 0a82e39..35cc680 100644 --- a/src/vnet/net.rs +++ b/src/vnet/net.rs @@ -2,14 +2,13 @@ mod net_test; use super::conn_map::*; -use super::error::*; use super::interface::*; +use crate::error::*; use crate::vnet::chunk::Chunk; use crate::vnet::conn::{ConnObserver, UdpConn}; use crate::vnet::router::*; use crate::{conn, ifaces, Conn}; -use anyhow::Result; use async_trait::async_trait; use ipnet::IpNet; use std::collections::HashMap; @@ -318,7 +317,7 @@ impl VNet { if (use_ipv4 && remote_addr.is_ipv4()) || (!use_ipv4 && remote_addr.is_ipv6()) { Ok(remote_addr) } else { - Err(Error::new(format!( + Err(Error::Other(format!( "No available {} IP address found!", if use_ipv4 { "ipv4" } else { "ipv6" }, )) diff --git a/src/vnet/net/net_test.rs b/src/vnet/net/net_test.rs index 3383321..57f8ce5 100644 --- a/src/vnet/net/net_test.rs +++ b/src/vnet/net/net_test.rs @@ -628,7 +628,7 @@ async fn test_net_virtual_loopback2() -> Result<()> { Ok(()) } -async fn get_ipaddr(nic: &Arc>) -> Result { +async fn get_ipaddr(nic: &Arc>) -> Result { let n = nic.lock().await; let eth0 = n .get_interface("eth0") diff --git a/src/vnet/resolver.rs b/src/vnet/resolver.rs index b956c4c..059895a 100644 --- a/src/vnet/resolver.rs +++ b/src/vnet/resolver.rs @@ -1,9 +1,8 @@ #[cfg(test)] mod resolver_test; -use super::error::*; +use crate::error::*; -use anyhow::Result; use std::collections::HashMap; use std::future::Future; use std::net::IpAddr; diff --git a/src/vnet/router.rs b/src/vnet/router.rs index 54eab63..d4560fc 100644 --- a/src/vnet/router.rs +++ b/src/vnet/router.rs @@ -1,15 +1,14 @@ #[cfg(test)] mod router_test; +use crate::error::*; use crate::vnet::chunk::*; use crate::vnet::chunk_queue::*; -use crate::vnet::error::*; use crate::vnet::interface::*; use crate::vnet::nat::*; use crate::vnet::net::*; use crate::vnet::resolver::*; -use anyhow::Result; use async_trait::async_trait; use ipnet::*; use std::collections::HashMap; diff --git a/src/vnet/router/router_test.rs b/src/vnet/router/router_test.rs index 2a45b92..2afeb69 100644 --- a/src/vnet/router/router_test.rs +++ b/src/vnet/router/router_test.rs @@ -103,7 +103,7 @@ impl Nic for DummyNic { } } -async fn get_ipaddr(nic: &Arc>) -> Result { +async fn get_ipaddr(nic: &Arc>) -> Result { let n = nic.lock().await; let eth0 = n .get_interface("eth0")