diff --git a/examples/get_route_to.rs b/examples/get_route_to.rs new file mode 100644 index 0000000..8397d5c --- /dev/null +++ b/examples/get_route_to.rs @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT + +use futures::stream::TryStreamExt; +use rtnetlink::{new_connection, Error, Handle}; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; + +#[tokio::main] +async fn main() -> Result<(), ()> { + let (connection, handle, _) = new_connection().unwrap(); + tokio::spawn(connection); + + println!("dumping specific destination route for IPv4"); + let dest_str = "127.0.0.1"; + let dest: Ipv4Addr = dest_str.parse().expect("Invalid IP address format"); + if let Err(e) = dump_route_to(handle.clone(), IpAddr::V4(dest)).await { + eprintln!("{e}"); + } + println!(); + + println!("dumping specific destination route for IPv6"); + let dest_str = "::1"; + let dest: Ipv6Addr = dest_str.parse().expect("Invalid IP address format"); + if let Err(e) = dump_route_to(handle.clone(), IpAddr::V6(dest)).await { + eprintln!("{e}"); + } + println!(); + + Ok(()) +} + +async fn dump_route_to( + handle: Handle, + destination: IpAddr, +) -> Result<(), Error> { + let mut routes = handle.route().get_to(destination).execute_to(); + while let Some(route) = routes.try_next().await? { + println!("{route:?}"); + } + Ok(()) +} diff --git a/src/route/get.rs b/src/route/get.rs index e870057..fe73961 100644 --- a/src/route/get.rs +++ b/src/route/get.rs @@ -8,10 +8,13 @@ use futures::{ use netlink_packet_core::{NetlinkMessage, NLM_F_DUMP, NLM_F_REQUEST}; use netlink_packet_route::{ - route::{RouteHeader, RouteMessage, RouteProtocol, RouteScope, RouteType}, - AddressFamily, RouteNetlinkMessage, + route::{RouteHeader, RouteMessage, RouteProtocol, RouteScope, + RouteType, RouteAddress, RouteAttribute}, + AddressFamily, RouteNetlinkMessage, }; +use std::net::IpAddr; + use crate::{try_rtnl, Error, Handle}; pub struct RouteGetRequest { @@ -37,6 +40,19 @@ impl IpVersion { } } +trait IpAddrExt { + fn version(&self) -> IpVersion; +} + +impl IpAddrExt for std::net::IpAddr { + fn version(&self) -> IpVersion { + match self { + std::net::IpAddr::V4(_) => IpVersion::V4, + std::net::IpAddr::V6(_) => IpVersion::V6, + } + } +} + impl RouteGetRequest { pub(crate) fn new(handle: Handle, ip_version: IpVersion) -> Self { let mut message = RouteMessage::default(); @@ -84,4 +100,45 @@ impl RouteGetRequest { ), } } + + pub(crate) fn new_to(handle: Handle, destination: IpAddr) -> Self { + let mut message = RouteMessage::default(); + message.header.address_family = destination.version().family(); + + message.header.source_prefix_length = 0; + message.header.scope = RouteScope::Universe; + message.header.kind = RouteType::Unspec; + + message.header.table = RouteHeader::RT_TABLE_UNSPEC; + message.header.protocol = RouteProtocol::Unspec; + + let addr = match destination { + IpAddr::V4(v4_addr) => RouteAddress::from(v4_addr), + IpAddr::V6(v6_addr) => RouteAddress::from(v6_addr), + }; + + message.attributes.push(RouteAttribute::Destination(addr)); + + RouteGetRequest { handle, message } + } + + pub fn execute_to(self) -> impl TryStream { + let RouteGetRequest { + mut handle, + message, + } = self; + + let mut req = + NetlinkMessage::from(RouteNetlinkMessage::GetRoute(message)); + req.header.flags = NLM_F_REQUEST; + + match handle.request(req) { + Ok(response) => Either::Left(response.map(move |msg| { + Ok(try_rtnl!(msg, RouteNetlinkMessage::NewRoute)) + })), + Err(e) => Either::Right( + future::err::(e).into_stream(), + ), + } + } } diff --git a/src/route/handle.rs b/src/route/handle.rs index e2d2116..dd2e84a 100644 --- a/src/route/handle.rs +++ b/src/route/handle.rs @@ -4,6 +4,7 @@ use crate::{ Handle, IpVersion, RouteAddRequest, RouteDelRequest, RouteGetRequest, }; use netlink_packet_route::route::RouteMessage; +use std::net::IpAddr; pub struct RouteHandle(Handle); @@ -27,4 +28,8 @@ impl RouteHandle { pub fn del(&self, route: RouteMessage) -> RouteDelRequest { RouteDelRequest::new(self.0.clone(), route) } + + pub fn get_to(&self, destination: IpAddr) -> RouteGetRequest { + RouteGetRequest::new_to(self.0.clone(), destination) + } }