diff --git a/examples/get_route.rs b/examples/get_route.rs index fb402e5..c609f2f 100644 --- a/examples/get_route.rs +++ b/examples/get_route.rs @@ -1,6 +1,7 @@ // SPDX-License-Identifier: MIT -use futures::stream::TryStreamExt; +use futures::{Stream, TryStreamExt}; +use netlink_packet_route::RouteMessage; use rtnetlink::{new_connection, Error, Handle, IpVersion}; #[tokio::main] @@ -27,7 +28,12 @@ async fn dump_addresses( handle: Handle, ip_version: IpVersion, ) -> Result<(), Error> { - let mut routes = handle.route().get(ip_version).execute(); + let mut routes: Box< + dyn Stream> + Unpin, + > = match ip_version { + IpVersion::V4 => Box::new(handle.route().get().v4().execute()), + IpVersion::V6 => Box::new(handle.route().get().v6().execute()), + }; while let Some(route) = routes.try_next().await? { println!("{route:?}"); } diff --git a/examples/get_route_to.rs b/examples/get_route_to.rs new file mode 100644 index 0000000..cee4997 --- /dev/null +++ b/examples/get_route_to.rs @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT + +use std::net::IpAddr; + +use futures::{Stream, TryStreamExt}; +use netlink_packet_route::RouteMessage; +use rtnetlink::{new_connection, Error, Handle}; + +#[tokio::main] +async fn main() { + let (connection, handle, _) = new_connection().unwrap(); + tokio::spawn(connection); + + let destinations = [ + "8.8.8.8".parse().unwrap(), + "127.0.0.8".parse().unwrap(), + "2001:4860:4860::8888".parse().unwrap(), + "::1".parse().unwrap(), + ]; + for dest in destinations { + println!("getting best route to {}", dest); + if let Err(e) = dump_route_to(handle.clone(), dest).await { + eprintln!("{e}"); + } + println!(); + } +} + +async fn dump_route_to(handle: Handle, dest: IpAddr) -> Result<(), Error> { + let mut routes: Box< + dyn Stream> + Unpin, + > = match dest { + IpAddr::V4(v4) => Box::new(handle.route().get().v4().to(v4).execute()), + IpAddr::V6(v6) => Box::new(handle.route().get().v6().to(v6).execute()), + }; + if let Some(route) = routes.try_next().await? { + println!("{route:?}"); + } + Ok(()) +} diff --git a/src/route/get.rs b/src/route/get.rs index f3cdf47..a7a0a06 100644 --- a/src/route/get.rs +++ b/src/route/get.rs @@ -1,22 +1,34 @@ // SPDX-License-Identifier: MIT +use std::{ + marker::PhantomData, + net::{Ipv4Addr, Ipv6Addr}, +}; + use futures::{ future::{self, Either}, - stream::{StreamExt, TryStream}, - FutureExt, + stream::StreamExt, + FutureExt, Stream, }; use netlink_packet_core::{NetlinkMessage, NLM_F_DUMP, NLM_F_REQUEST}; use netlink_packet_route::{ - RouteMessage, RtnlMessage, AF_INET, AF_INET6, RTN_UNSPEC, RTPROT_UNSPEC, - RT_SCOPE_UNIVERSE, RT_TABLE_UNSPEC, + route::Nla, RouteMessage, RtnlMessage, AF_INET, AF_INET6, RTN_UNSPEC, + RTPROT_UNSPEC, RT_SCOPE_UNIVERSE, RT_TABLE_UNSPEC, }; use crate::{try_rtnl, Error, Handle}; -pub struct RouteGetRequest { +pub struct RouteGetRequest { handle: Handle, message: RouteMessage, + // There are two ways to retrieve routes: we can either dump them + // all and filter the result, or if we already know the destination + // of the route we're looking for, we can just retrieve + // that one. If `dump` is `true`, all the routes are fetched. + // Otherwise, only the best route to the destination is fetched. + dump: bool, + _phantom: PhantomData, } /// Internet Protocol (IP) version. @@ -37,10 +49,9 @@ impl IpVersion { } } -impl RouteGetRequest { - pub(crate) fn new(handle: Handle, ip_version: IpVersion) -> Self { +impl RouteGetRequest { + pub(crate) fn new(handle: Handle) -> Self { let mut message = RouteMessage::default(); - message.header.address_family = ip_version.family(); // As per rtnetlink(7) documentation, setting the following // fields to 0 gets us all the routes from all the tables @@ -58,21 +69,118 @@ impl RouteGetRequest { message.header.table = RT_TABLE_UNSPEC; message.header.protocol = RTPROT_UNSPEC; - RouteGetRequest { handle, message } + RouteGetRequest { + handle, + message, + dump: true, + _phantom: Default::default(), + } + } + + /// Sets the output interface index. + pub fn output_interface(mut self, index: u32) -> Self { + self.message.nlas.push(Nla::Oif(index)); + self } pub fn message_mut(&mut self) -> &mut RouteMessage { &mut self.message } +} + +impl RouteGetRequest<()> { + pub fn v4(mut self) -> RouteGetRequest { + self.message.header.address_family = AF_INET as u8; + RouteGetRequest:: { + _phantom: PhantomData::, + handle: self.handle, + message: self.message, + dump: self.dump, + } + } + + pub fn v6(mut self) -> RouteGetRequest { + self.message.header.address_family = AF_INET6 as u8; + RouteGetRequest:: { + _phantom: PhantomData::, + handle: self.handle, + message: self.message, + dump: self.dump, + } + } +} + +impl RouteGetRequest { + /// Get the best route to this destination + pub fn to(mut self, ip: Ipv4Addr) -> Self { + self.message.nlas.push(Nla::Destination(ip.octets().into())); + self.message.header.destination_prefix_length = 32; + self.dump = false; + self + } + + pub fn from(mut self, ip: Ipv6Addr) -> Self { + self.message.nlas.push(Nla::Source(ip.octets().into())); + self.message.header.source_prefix_length = 32; + self + } + + pub fn execute(self) -> impl Stream> { + let RouteGetRequest { + mut handle, + message, + dump, + _phantom, + } = self; + + let mut req = NetlinkMessage::from(RtnlMessage::GetRoute(message)); + req.header.flags = if dump { + NLM_F_REQUEST | NLM_F_DUMP + } else { + NLM_F_REQUEST + }; + + match handle.request(req) { + Ok(response) => Either::Left( + response + .map(move |msg| Ok(try_rtnl!(msg, RtnlMessage::NewRoute))), + ), + Err(e) => Either::Right( + future::err::(e).into_stream(), + ), + } + } +} + +impl RouteGetRequest { + /// Get the best route to this destination + pub fn to(mut self, ip: Ipv6Addr) -> Self { + self.message.nlas.push(Nla::Destination(ip.octets().into())); + self.message.header.destination_prefix_length = 32; + self.dump = false; + self + } + + pub fn from(mut self, ip: Ipv6Addr) -> Self { + self.message.nlas.push(Nla::Source(ip.octets().into())); + self.message.header.source_prefix_length = 32; + self + } - pub fn execute(self) -> impl TryStream { + pub fn execute(self) -> impl Stream> { let RouteGetRequest { mut handle, message, + dump, + _phantom, } = self; let mut req = NetlinkMessage::from(RtnlMessage::GetRoute(message)); - req.header.flags = NLM_F_REQUEST | NLM_F_DUMP; + req.header.flags = if dump { + NLM_F_REQUEST | NLM_F_DUMP + } else { + NLM_F_REQUEST + }; match handle.request(req) { Ok(response) => Either::Left( diff --git a/src/route/handle.rs b/src/route/handle.rs index 2b9488e..ff2a869 100644 --- a/src/route/handle.rs +++ b/src/route/handle.rs @@ -1,8 +1,6 @@ // SPDX-License-Identifier: MIT -use crate::{ - Handle, IpVersion, RouteAddRequest, RouteDelRequest, RouteGetRequest, -}; +use crate::{Handle, RouteAddRequest, RouteDelRequest, RouteGetRequest}; use netlink_packet_route::RouteMessage; pub struct RouteHandle(Handle); @@ -14,8 +12,8 @@ impl RouteHandle { /// Retrieve the list of routing table entries (equivalent to `ip route /// show`) - pub fn get(&self, ip_version: IpVersion) -> RouteGetRequest { - RouteGetRequest::new(self.0.clone(), ip_version) + pub fn get(&self) -> RouteGetRequest { + RouteGetRequest::new(self.0.clone()) } /// Add an routing table entry (equivalent to `ip route add`)