diff --git a/ingot-types/src/ip.rs b/ingot-types/src/ip.rs index b8caaa2..a99a62e 100644 --- a/ingot-types/src/ip.rs +++ b/ingot-types/src/ip.rs @@ -22,7 +22,7 @@ impl Ipv4Addr { /// Return the bytes of the address. #[inline] - pub fn octets(&self) -> [u8; 4] { + pub const fn octets(&self) -> [u8; 4] { self.inner } @@ -31,6 +31,102 @@ impl Ipv4Addr { pub const fn from_octets(bytes: [u8; 4]) -> Self { Self { inner: bytes } } + + /// Private function to convert to a `core::net::Ipv4Addr` + /// in a const context as `From` implementations are not + /// allowed in const contexts. + /// + /// This can be simplied once [`from_octets` and `from_segements`] is + /// stabilized. + /// + /// [`from_octets` and `from_segements`]: https://github.com/rust-lang/rust/issues/131360 + #[inline] + const fn into_core(self) -> core::net::Ipv4Addr { + core::net::Ipv4Addr::new( + self.inner[0], + self.inner[1], + self.inner[2], + self.inner[3], + ) + } + + /// Returns true if the address is a multicast address. + #[inline] + pub const fn is_multicast(&self) -> bool { + self.into_core().is_multicast() + } + + /// Returns true if the address is a local broadcast address. + #[inline] + pub const fn is_broadcast(&self) -> bool { + self.into_core().is_broadcast() + } + + /// Returns true if the address is a private address. + #[inline] + pub const fn is_private(&self) -> bool { + self.into_core().is_private() + } + + /// Returns true if the address is a loopback address. + #[inline] + pub const fn is_loopback(&self) -> bool { + self.into_core().is_loopback() + } + + /// Returns true if the address is a unicast address. + #[inline] + pub const fn is_unicast(&self) -> bool { + !self.is_multicast() && !self.is_broadcast() + } + + /// Returns true if the address is a link-local address. + #[inline] + pub const fn is_link_local(&self) -> bool { + self.into_core().is_link_local() + } + + /// Returns true if the address is a global unicast address. + #[inline] + pub const fn is_global(&self) -> bool { + !self.is_multicast() + && !self.is_private() + && !self.is_loopback() + && !self.is_link_local() + && !self.is_broadcast() + } + + /// Returns true if the address is a documentation address. + /// There are three such unicast ranges [IETF RFC 5737]: + /// * 192.0.2.0/24 + /// * 198.51.100.0/24 + /// * 203.0.113.0/24 + /// + /// And one multicast ([IETF RFC 5771] / [IETF RFC 6676]) one: + /// * 233.252.0.0/24 + /// + /// [IETF RFC 5737]: https://tools.ietf.org/html/rfc5737 + /// [IETF RFC 5771]: https://tools.ietf.org/html/rfc5771 + /// [IETF RFC 6676]: https://tools.ietf.org/html/rfc6676 + #[inline] + pub const fn is_documentation(&self) -> bool { + matches!( + self.octets(), + [192, 0, 2, _] + | [198, 51, 100, _] + | [203, 0, 113, _] + | [233, 252, 0, _] + ) + } + + /// Returns true if the address is a reserved address. + /// + /// Note: The underlying `core::net` version is not yet stable as + /// of Rust 1.84.1. + #[inline] + pub const fn is_reserved(&self) -> bool { + self.octets()[0] & 240 == 240 && !self.is_broadcast() + } } impl From for Ipv4Addr { @@ -64,7 +160,7 @@ impl Ipv6Addr { /// Return the bytes of the address. #[inline] - pub fn octets(&self) -> [u8; 16] { + pub const fn octets(&self) -> [u8; 16] { self.inner } @@ -92,6 +188,109 @@ impl Ipv6Addr { ], } } + + /// Returns an eight element 16-bit array representation of the address. + /// + /// This is taken from the core `Ipv6Addr` implementation. + #[inline] + pub const fn segments(&self) -> [u16; 8] { + // All elements in `self.octets` must be big endian. + // SAFETY: `[u8; 16]` is always safe to transmute to `[u16; 8]`. + let [a, b, c, d, e, f, g, h] = unsafe { + core::mem::transmute::<[u8; 16], [u16; 8]>(self.octets()) + }; + // We want native endian u16 + [ + u16::from_be(a), + u16::from_be(b), + u16::from_be(c), + u16::from_be(d), + u16::from_be(e), + u16::from_be(f), + u16::from_be(g), + u16::from_be(h), + ] + } + + /// Private function to convert to a `core::net::Ipv6Addr` + /// in a const context as `From` implementations are not + /// yet allowed in const contexts. + /// This can be simplied once [`from_octets` and `from_segements`] is + /// stabilized. + /// + /// [`from_octets` and `from_segements`]: https://github.com/rust-lang/rust/issues/131360 + #[inline] + const fn into_core(self) -> core::net::Ipv6Addr { + let segments = self.segments(); + core::net::Ipv6Addr::new( + segments[0], + segments[1], + segments[2], + segments[3], + segments[4], + segments[5], + segments[6], + segments[7], + ) + } + + /// Returns true if the address is a multicast address. + #[inline] + pub const fn is_multicast(&self) -> bool { + self.into_core().is_multicast() + } + + /// Returns true if the address is a loopback address. + #[inline] + pub const fn is_loopback(&self) -> bool { + self.into_core().is_loopback() + } + + /// Returns true if the address is a unicast address. + #[inline] + pub const fn is_unicast(&self) -> bool { + !self.is_multicast() + } + + /// Returns true if the address is a unicast link-local address. + /// + /// Note: The underlying `core::net` version is not yet stable as + /// of Rust 1.84.1. + #[inline] + pub const fn is_unicast_link_local(&self) -> bool { + (self.segments()[0] & 0xffc0) == 0xfe80 + } + + /// Returns true if the address is a unique local address. + /// + /// Note: The underlying `core::net` version is not yet stable as + /// of Rust 1.84.1. + #[inline] + pub const fn is_unique_local(&self) -> bool { + (self.segments()[0] & 0xfe00) == 0xfc00 + } + + /// Returns true if the address is a global unicast address. + #[inline] + pub const fn is_unicast_global(&self) -> bool { + !self.is_multicast() + && !self.is_unicast_link_local() + && !self.is_unique_local() + } + + /// Returns true if the address is a documentation address. + /// + /// Defined in [IETF RFC 3849]. + /// + /// Note: The underlying `core::net` version is not yet stable as + /// of Rust 1.84.1. + /// + /// [IETF RFC 3849]: https://tools.ietf.org/html/rfc3849 + #[inline] + pub const fn is_documentation(&self) -> bool { + let segments = self.segments(); + (segments[0] == 0x2001) && (segments[1] == 0xdb8) + } } impl From for Ipv6Addr { @@ -107,3 +306,76 @@ impl From for core::net::Ipv6Addr { Self::from(ip6.inner) } } + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn ipv4() { + let addr = Ipv4Addr::from_octets([192, 168, 1, 1]); + assert!(addr.is_private()); + assert!(!addr.is_global()); + assert!(!addr.is_multicast()); + assert!(!addr.is_broadcast()); + assert!(!addr.is_loopback()); + assert!(addr.is_unicast()); + assert!(!addr.is_link_local()); + assert!(!addr.is_documentation()); + assert!(!addr.is_reserved()); + } + + #[test] + fn ipv4_broadcast() { + let addr = Ipv4Addr::from_octets([255, 255, 255, 255]); + assert!(!addr.is_private()); + assert!(!addr.is_global()); + assert!(!addr.is_multicast()); + assert!(addr.is_broadcast()); + assert!(!addr.is_unicast()); + assert!(!addr.is_loopback()); + assert!(!addr.is_link_local()); + assert!(!addr.is_documentation()); + assert!(!addr.is_reserved()); + } + + #[test] + fn ipv4_loopback() { + let addr = Ipv4Addr::from_octets([127, 0, 0, 1]); + assert!(!addr.is_private()); + assert!(!addr.is_global()); + assert!(!addr.is_multicast()); + assert!(!addr.is_broadcast()); + assert!(addr.is_loopback()); + assert!(addr.is_unicast()); + assert!(!addr.is_link_local()); + assert!(!addr.is_documentation()); + assert!(!addr.is_reserved()); + } + + #[test] + fn ipv6() { + let addr = Ipv6Addr::from_octets([ + 0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, + ]); + assert!(!addr.is_multicast()); + assert!(addr.is_unicast()); + assert!(!addr.is_unicast_link_local()); + assert!(!addr.is_unique_local()); + assert!(addr.is_documentation()); + assert!(addr.is_unicast_global()); + } + + #[test] + fn ipv6_link_local() { + let addr = Ipv6Addr::from_octets([ + 0xfe, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xde, 0xad, 0xbe, 0xef, + ]); + assert!(!addr.is_multicast()); + assert!(addr.is_unicast()); + assert!(addr.is_unicast_link_local()); + assert!(!addr.is_unique_local()); + assert!(!addr.is_documentation()); + assert!(!addr.is_unicast_global()); + } +} diff --git a/ingot/src/igmp.rs b/ingot/src/igmp.rs index 30221f6..47de395 100644 --- a/ingot/src/igmp.rs +++ b/ingot/src/igmp.rs @@ -136,7 +136,36 @@ pub struct IgmpV2LeaveGroup { #[allow(clippy::unusual_byte_groupings)] mod test { use super::*; - use crate::types::{Header, HeaderParse}; + use crate::types::{Emit, Header, HeaderParse}; + use ingot_types::HeaderLen; + + impl IgmpV3GroupRecord { + fn bytes_len(&self) -> usize { + 8 + self.source_addrs.len() * 4 + self.auxiliary_data.len() + } + } + + impl IgmpV3MembershipReport { + fn bytes_len(&self) -> usize { + 8 + self.group_records.iter().map(|r| r.bytes_len()).sum::() + } + } + + fn compute_checksum(bytes: &[u8]) -> u16 { + let mut sum: u32 = 0; + for chunk in bytes.chunks(2) { + let word = match chunk { + [a, b] => ((*a as u16) << 8) | (*b as u16), + [a] => (*a as u16) << 8, + _ => unreachable!(), + }; + sum = sum.wrapping_add(word as u32); + } + while (sum >> 16) != 0 { + sum = (sum & 0xFFFF) + (sum >> 16); + } + !sum as u16 + } #[test] fn parse() { @@ -253,4 +282,261 @@ mod test { let (_r, _, rest) = ValidIgmpMembershipQuery::parse(bytes).unwrap(); assert_eq!(rest, &[1, 2, 3, 4, 5]); } + + #[test] + fn generate_v3_membership_report() { + let igmp_default = IgmpV3MembershipReport { + ty: IgmpMessageType::V3_MEMBERSHIP_REPORT, + ..Default::default() + }; + let bytes = igmp_default.emit_vec(); + assert_eq!(bytes.len(), igmp_default.bytes_len()); + + let igmp_with_group_addrs = IgmpV3MembershipReport { + ty: IgmpMessageType::V3_MEMBERSHIP_REPORT, + group_records: Repeated::new(vec![ + IgmpV3GroupRecord { + record_type: IgmpV3RecordType::MODE_IS_INCLUDE, + aux_data_len: 0, + num_sources: 0, + multicast_addr: Ipv4Addr::from_octets([239, 1, 2, 3]), + source_addrs: Vec::new(), + auxiliary_data: Vec::new(), + }, + IgmpV3GroupRecord { + record_type: IgmpV3RecordType::MODE_IS_EXCLUDE, + aux_data_len: 0, + num_sources: 0, + multicast_addr: Ipv4Addr::from_octets([239, 1, 2, 4]), + source_addrs: Vec::new(), + auxiliary_data: Vec::new(), + }, + ]), + ..Default::default() + }; + + for record in igmp_with_group_addrs.group_records.iter() { + assert_eq!(record.bytes_len(), 8); + assert!(record.multicast_addr.is_multicast()); + } + + let bytes = igmp_with_group_addrs.emit_vec(); + assert_eq!(bytes.len(), igmp_with_group_addrs.bytes_len()); + } + + #[test] + fn generate_v3_membership_queries() { + // Test case with specific resv/s/qrv values + let query_with_resv = IgmpMembershipQuery { + ty: IgmpMessageType::MEMBERSHIP_QUERY, + max_resp: 100, + group_address: Ipv4Addr::from_octets([239, 1, 2, 3]), + s: 1, // Suppress router-side processing + qrv: 2, // Robustness Variable + qqic: 125, + num_sources: 2, + source_addrs: vec![ + Ipv4Addr::from_octets([192, 168, 1, 10]), + Ipv4Addr::from_octets([192, 168, 1, 11]), + ], + ..Default::default() + }; + + let bytes = query_with_resv.emit_vec(); + // Check the byte containing resv(4 bits)|s(1 bit)|qrv(3 bits) + assert_eq!(bytes[8] & 0xf0, 0); // Upper 4 bits (resv) should be zero + assert_eq!(bytes[8] & 0x08, 0x08); // Next bit (s) should be 1 + assert_eq!(bytes[8] & 0x07, 0x02); // Last 3 bits (qrv) should be 2 + + // Test with different s/qrv combinations + let query_variations = [ + (0, 7, 0x07), // s=0, qrv=7 -> 0000_0111 + (1, 7, 0x0f), // s=1, qrv=7 -> 0000_1111 + (0, 0, 0x00), // s=0, qrv=0 -> 0000_0000 + (1, 0, 0x08), // s=1, qrv=0 -> 0000_1000 + ]; + + for (s, qrv, expected) in query_variations { + let query = IgmpMembershipQuery { + ty: IgmpMessageType::MEMBERSHIP_QUERY, + max_resp: 100, + group_address: Ipv4Addr::from_octets([239, 1, 2, 3]), + s, + qrv, + qqic: 125, + num_sources: 0, + source_addrs: vec![], + ..Default::default() + }; + + let bytes = query.emit_vec(); + assert_eq!( + bytes[8], expected, + "Failed for s={}, qrv={}, got {:08b}, expected {:08b}", + s, qrv, bytes[8], expected + ); + } + } + + #[test] + fn generate_v2_membership_report() { + let report = IgmpV2MembershipReport { + ty: IgmpMessageType::V2_MEMBERSHIP_REPORT, + max_resp: 0, // Should be zero in transmission + checksum: 0, // Will be computed by higher layer + group_address: Ipv4Addr::from_octets([239, 1, 2, 3]), // Valid multicast address + }; + + let bytes = report.emit_vec(); + assert_eq!(bytes.len(), 8); // V2 messages are fixed size + assert_eq!(bytes[0], 0x16); // V2 report type + assert_eq!(bytes[4..8], [239, 1, 2, 3]); // Group address + } + + #[test] + fn generate_invalid_mcast_v2_membership_report() { + let report = IgmpV2MembershipReport { + ty: IgmpMessageType::V2_MEMBERSHIP_REPORT, + group_address: Ipv4Addr::from_octets([192, 0, 0, 1]), // Invalid multicast address + ..Default::default() + }; + + assert!(!report.group_address.is_multicast()); + } + + #[test] + fn generate_v2_leave_group() { + let leave = IgmpV2LeaveGroup { + ty: IgmpMessageType::V2_LEAVE_GROUP, + group_address: Ipv4Addr::from_octets([239, 1, 2, 3]), + ..Default::default() + }; + + let bytes = leave.emit_vec(); + assert_eq!(bytes.len(), 8); // V2 messages are fixed size + assert_eq!(bytes[0], 0x17); // Leave group type + assert_eq!(bytes[4..8], [239, 1, 2, 3]); // Group address + } + + #[test] + fn test_max_size_messages() { + let max_sources = 100; // Use a reasonable maximum for testing + let max_query = IgmpMembershipQuery { + ty: IgmpMessageType::MEMBERSHIP_QUERY, + max_resp: 100, + checksum: 0, + group_address: Ipv4Addr::from_octets([224, 0, 0, 1]), + s: 1, + qrv: 2, + qqic: 125, + num_sources: max_sources, + source_addrs: vec![Ipv4Addr::UNSPECIFIED; max_sources as usize], + ..Default::default() + }; + + let bytes = max_query.emit_vec(); + assert_eq!(bytes.len(), 12 + (max_sources as usize * 4)); + } + + #[test] + fn test_round_trip_v3_membership_query() { + let original_query = IgmpMembershipQuery { + ty: IgmpMessageType::MEMBERSHIP_QUERY, + max_resp: 100, + group_address: Ipv4Addr::from_octets([224, 0, 0, 1]), + s: 1, + qrv: 2, + qqic: 125, + num_sources: 2, + source_addrs: vec![ + Ipv4Addr::from_octets([192, 168, 1, 10]), + Ipv4Addr::from_octets([192, 168, 1, 11]), + ], + ..Default::default() + }; + + let bytes = original_query.emit_vec(); + let (parsed_query, ..) = + ValidIgmpMembershipQuery::parse(&*bytes).unwrap(); + + assert_eq!(parsed_query.ty(), original_query.ty); + assert_eq!(parsed_query.max_resp(), original_query.max_resp); + assert_eq!(parsed_query.group_address(), original_query.group_address); + assert_eq!(parsed_query.qrv(), original_query.qrv); + assert_eq!(parsed_query.qqic(), original_query.qqic); + assert_eq!(parsed_query.num_sources(), original_query.num_sources); + } + + #[test] + fn test_round_trip_v3_membership_report() { + let original_report = IgmpV3MembershipReport { + ty: IgmpMessageType::V3_MEMBERSHIP_REPORT, + resv1: 0, + checksum: 0, + resv2: 0, + num_group_records: 2, + group_records: Repeated::new(vec![ + IgmpV3GroupRecord { + record_type: IgmpV3RecordType::MODE_IS_INCLUDE, + aux_data_len: 0, + num_sources: 0, + multicast_addr: Ipv4Addr::from_octets([1, 2, 3, 4]), + source_addrs: Vec::new(), + auxiliary_data: Vec::new(), + }, + IgmpV3GroupRecord { + record_type: IgmpV3RecordType::MODE_IS_EXCLUDE, + aux_data_len: 0, + num_sources: 0, + multicast_addr: Ipv4Addr::from_octets([5, 6, 7, 8]), + source_addrs: Vec::new(), + auxiliary_data: Vec::new(), + }, + ]), + }; + + let bytes = original_report.emit_vec(); + let (parsed_report, ..) = + ValidIgmpV3MembershipReport::parse(&*bytes).unwrap(); + + assert_eq!(parsed_report.ty(), original_report.ty); + assert_eq!(parsed_report.resv1(), original_report.resv1); + assert_eq!(parsed_report.checksum(), original_report.checksum); + assert_eq!(parsed_report.resv2(), original_report.resv2); + assert_eq!( + parsed_report.num_group_records(), + original_report.num_group_records + ); + + let original_records = original_report.group_records.iter(); + let original_records_plen = + original_records.map(|r| r.bytes_len()).sum::(); + let parsed_records_plen = + parsed_report.group_records_ref().packet_length(); + assert_eq!(original_records_plen, parsed_records_plen); + } + + #[test] + fn test_checksum_verification() { + let mut query = IgmpMembershipQuery { + ty: IgmpMessageType::MEMBERSHIP_QUERY, + max_resp: 100, + checksum: 0, // Initially zero + group_address: Ipv4Addr::from_octets([224, 0, 0, 1]), + s: 1, + qrv: 2, + qqic: 125, + num_sources: 1, + source_addrs: vec![Ipv4Addr::from_octets([192, 168, 1, 1])], + ..Default::default() + }; + + let mut bytes = query.emit_vec(); + let computed_checksum = compute_checksum(&bytes); + query.checksum = computed_checksum; + + // Re-emit with correct checksum + bytes = query.emit_vec(); + assert_eq!(compute_checksum(&bytes), 0); + } }