Skip to content

Fix last handshake time panic #12

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Apr 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 23 additions & 22 deletions examples/demo_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::io::{Read, Write};
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use std::time::{Duration, SystemTime};

use ipnet::{Ipv4Net, Ipv6Net};
use log::*;
Expand Down Expand Up @@ -70,31 +70,32 @@ fn main() {
}
assert!(adapter.up());

// Go to http://demo.wireguard.com/ and see the bandwidth numbers change!
println!("Printing peer bandwidth statistics");
println!("Press enter to exit");
let done = Arc::new(AtomicBool::new(false));
let done2 = Arc::clone(&done);
let thread = std::thread::spawn(move || {
'outer: loop {
for _ in 0..10 {
if done2.load(Ordering::Relaxed) {
break 'outer;
}
std::thread::sleep(Duration::from_millis(100));
}
let stats = adapter.get_config();
for peer in stats.peers {
let handshake_age = Instant::now().duration_since(peer.last_handshake);
println!(
" {:?}, up: {}, down: {}, handsake: {}s ago",
peer.allowed_ips,
peer.tx_bytes,
peer.rx_bytes,
handshake_age.as_secs_f32()
);
let thread = std::thread::spawn(move || 'outer: loop {
let stats = adapter.get_config();
for peer in stats.peers {
let handshake_age = peer
.last_handshake
.map(|h| SystemTime::now().duration_since(h).unwrap_or_default());
let handshake_msg = match handshake_age {
Some(age) => format!("handshake performed {:.2}s ago", age.as_secs_f32()),
None => format!("no active handshake"),
};

println!(
" {:?}, {} bytes up, {} bytes down, {handshake_msg}",
peer.allowed_ips, peer.tx_bytes, peer.rx_bytes
);
}
for _ in 0..10 {
if done2.load(Ordering::Relaxed) {
break 'outer;
}
// Go to 163.172.161.0 in your browser to see bandwidth numbers here change
// because only traffic to that ip is routed through the interface
std::thread::sleep(Duration::from_millis(100));
}
});

Expand All @@ -116,7 +117,7 @@ fn get_demo_server_config(pub_key: &[u8]) -> Result<(Vec<u8>, Ipv4Addr, SocketAd
.collect();

let mut s: TcpStream = TcpStream::connect_timeout(
addrs.get(0).expect("Failed to resolve demo server DNS"),
addrs.first().expect("Failed to resolve demo server DNS"),
Duration::from_secs(5),
)
.expect("Failed to open connection to demo server");
Expand Down
66 changes: 31 additions & 35 deletions src/adapter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::util::{StructReader, UnsafeHandle};
use crate::wireguard_nt_raw;
use crate::WireGuardError;
use std::mem::{align_of, size_of};
use std::time::{Duration, Instant, SystemTime};
use std::time::{Duration, SystemTime};

use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
use std::ptr;
Expand Down Expand Up @@ -263,7 +263,7 @@ impl Adapter {
// `align_of::<WIREGUARD_INTERFACE` is 8, WIREGUARD_PEER has no special alignment
// requirements, and writer is already aligned to hold `WIREGUARD_INTERFACE` structs,
// therefore we uphold the alignment requirements of `write`
let mut wg_peer: &mut WIREGUARD_PEER = unsafe { writer.write() };
let wg_peer: &mut WIREGUARD_PEER = unsafe { writer.write() };

wg_peer.Flags = {
let mut flags = PeerFlags::HAS_ENDPOINT;
Expand Down Expand Up @@ -304,7 +304,7 @@ impl Adapter {
for allowed_ip in &peer.allowed_ips {
// Safety:
// Same as above, `writer` is aligned because it was aligned before
let mut wg_allowed_ip: &mut WIREGUARD_ALLOWED_IP = unsafe { writer.write() };
let wg_allowed_ip: &mut WIREGUARD_ALLOWED_IP = unsafe { writer.write() };
match allowed_ip {
IpNet::V4(v4) => {
let addr = unsafe { std::mem::transmute(v4.addr().octets()) };
Expand Down Expand Up @@ -357,7 +357,7 @@ impl Adapter {
use winapi::shared::winerror::{ERROR_OBJECT_ALREADY_EXISTS, ERROR_SUCCESS};
use winapi::shared::ws2def::{AF_INET, AF_INET6};

for allowed_ip in config.peers.iter().map(|p| p.allowed_ips.iter()).flatten() {
for allowed_ip in config.peers.iter().flat_map(|p| p.allowed_ips.iter()) {
use winapi::shared::netioapi::{InitializeIpForwardEntry, MIB_IPFORWARD_ROW2};
let mut default_route: MIB_IPFORWARD_ROW2 = std::mem::zeroed();
InitializeIpForwardEntry(&mut default_route);
Expand Down Expand Up @@ -464,12 +464,10 @@ impl Adapter {
/// Returns the adapter's LUID.
/// This is a 64bit unique identifier that windows uses when referencing this adapter
pub fn get_luid(&self) -> u64 {
let mut x = 0u64;
unsafe {
self.wireguard
.WireGuardGetAdapterLUID(self.adapter.0, std::mem::transmute(&mut x))
};
x
let mut luid = 0u64;
let ptr = &mut luid as *mut u64 as *mut wireguard_nt_raw::_NET_LUID_LH;
unsafe { self.wireguard.WireGuardGetAdapterLUID(self.adapter.0, ptr) };
luid
}

/// Sets the logging level of this adapter
Expand Down Expand Up @@ -500,15 +498,20 @@ impl Adapter {
&mut size as _,
)
};
assert_eq!(res, 0);
assert_eq!(unsafe { GetLastError() }, ERROR_MORE_DATA);
assert_ne!(size, 0); // size has been updated
// Should never fail since we
assert_eq!(res, 0, "Failed to query size of wireguard configuration");
assert_eq!(
unsafe { GetLastError() },
ERROR_MORE_DATA,
"WireGuardGetConfiguration returned invalid error for size request"
);
assert_ne!(size, 0, "Wireguard config is zero bytes"); // size has been updated
let align = align_of::<WIREGUARD_INTERFACE>();
let mut reader = StructReader::new(size as usize, align);
let res = unsafe {
self.wireguard.WireGuardGetConfiguration(
self.adapter.0,
reader.ptr() as _,
reader.ptr_mut().cast(),
&mut size as _,
)
};
Expand All @@ -520,7 +523,7 @@ impl Adapter {
// 3. We calculate the size of `reader` with the first call to `WireGuardGetConfiguration`. Wireguard writes at
// least one `WIREGUARD_INTERFACE`, and size is updated accordingly, therefore `reader`'s allocation is at least
// the size of a `WIREGUARD_INTERFACE`
let wireguard_interface: WIREGUARD_INTERFACE = unsafe { reader.read() };
let wireguard_interface: &WIREGUARD_INTERFACE = unsafe { reader.read() };
let mut wg_interface = WireguardInterface {
flags: wireguard_interface.Flags as u32,
listen_port: wireguard_interface.ListenPort,
Expand All @@ -529,21 +532,11 @@ impl Adapter {
peers: Vec::with_capacity(wireguard_interface.PeersCount as usize),
};

let now = SystemTime::now();
let now_instant = Instant::now();
let unix_duration = now
.duration_since(SystemTime::UNIX_EPOCH)
.expect("Time set before unix epoch");

// The number of 100ns intervals between 1-1-1600 and 1-1-1970
const UNIX_EPOCH_FROM_1_1_1600: u64 = 116444736000000000;
//calculate now based on the number of 100ns intervals since 1-1-1600
let now_since_1600 = UNIX_EPOCH_FROM_1_1_1600 + (unix_duration.as_nanos() / 100u128) as u64;
for _ in 0..wireguard_interface.PeersCount {
// # Safety:
// 1. `WireGuardGetConfiguration` writes a `WIREGUARD_PEER` immediately after the WIREGUARD_INTERFACE we read above.
// 2. We rely on Wireguard-NT to specify the number of peers written, and therefore we never read too many times unless Wireguard-NT (wrongly) tells us to
let peer: WIREGUARD_PEER = unsafe { reader.read() };
let peer: &WIREGUARD_PEER = unsafe { reader.read() };
let endpoint = peer.Endpoint;
let address_family = unsafe { endpoint.si_family } as i32;
let endpoint = match address_family {
Expand All @@ -568,12 +561,15 @@ impl Adapter {
panic!("Illegal address family {}", address_family);
}
};

//Calculate the difference in 100ns steps between the last handshake and now
let handshake_delta = now_since_1600 - peer.LastHandshake;

//The time of the lash handshake is now - the delta
let last_handshake = now_instant - Duration::from_nanos(handshake_delta * 100);
let last_handshake = if peer.LastHandshake == 0 {
None
} else {
// The number of 100ns intervals between 1-1-1600 and 1-1-1970
const UNIX_EPOCH_FROM_1_1_1600: u64 = 116444736000000000;
let ns_from_unix_epoch =
peer.LastHandshake.saturating_sub(UNIX_EPOCH_FROM_1_1_1600) * 100;
Some(SystemTime::UNIX_EPOCH + Duration::from_nanos(ns_from_unix_epoch))
};

let mut wg_peer = WireguardPeer {
flags: peer.Flags as u32,
Expand All @@ -590,7 +586,7 @@ impl Adapter {
// # Safety:
// 1. `WireGuardGetConfiguration` writes zero or more `WIREGUARD_ALLOWED_IP`s immediately after the WIREGUARD_PEER we read above.
// 2. We rely on Wireguard-NT to specify the number of allowed ips written, and therefore we never read too many times unless Wireguard-NT (wrongly) tells us to
let allowed_ip: WIREGUARD_ALLOWED_IP = unsafe { reader.read() };
let allowed_ip: &WIREGUARD_ALLOWED_IP = unsafe { reader.read() };
let prefix_length = allowed_ip.Cidr;
let allowed_ip = match allowed_ip.AddressFamily as i32 {
winapi::shared::ws2def::AF_INET => {
Expand Down Expand Up @@ -632,8 +628,8 @@ pub struct WireguardPeer {
pub tx_bytes: u64,
/// Number of bytes received
pub rx_bytes: u64,
/// Time of the last handshake
pub last_handshake: Instant,
/// Time of the last handshake, `None` if no handshake has occured
pub last_handshake: Option<SystemTime>,
/// Number of allowed IP structs following this struct
pub allowed_ips: Vec<IpNet>,
}
Expand Down
19 changes: 11 additions & 8 deletions src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ use std::{alloc::Layout, sync::Arc};
/// A wrapper struct that allows a type to be Send and Sync
pub(crate) struct UnsafeHandle<T>(pub T);

/// We never read from the pointer. It only serves as a handle we pass to the kernel or C code that
/// doesn't have the same mutable aliasing restrictions we have in Rust
/// We never read from the pointer. It only serves as a handle we pass to the kernel or C code
/// (where locks are used internally)
unsafe impl<T> Send for UnsafeHandle<T> {}
unsafe impl<T> Sync for UnsafeHandle<T> {}

Expand Down Expand Up @@ -43,12 +43,14 @@ impl StructWriter {
/// Returns a reference of the desired type, which can be used to write a T into the
/// buffer at the internal pointer. The internal pointer will be advanced by `size_of::<T>()` so that
/// the next call to [`write`] will return a reference to an adjacent memory location.
/// The returned refrence will be the zero bit pattern initially.
///
/// # Safety:
/// The caller must ensure the internal pointer is aligned suitably for writing to a T.
/// 1. The caller must ensure the internal pointer is aligned suitably for writing to a T.
/// In most C APIs (like Wireguard NT) the structs are setup in such a way that calling write
/// repeatedly to pack data into the buffer always yields a struct that is aligned because the
/// previous struct was aligned.
/// 2. The caller must ensure that the zero bit pattern is valid for type T
///
/// # Panics
/// 1. If writing a struct of size T would overflow the buffer.
Expand Down Expand Up @@ -123,7 +125,7 @@ impl StructReader {
/// # Panics
/// 1. If reading a struct of size T would overflow the buffer.
/// 2. If the internal pointer does not meet the alignment requirements of T.
pub unsafe fn read<T>(&mut self) -> T {
pub unsafe fn read<T>(&mut self) -> &T {
let size = std::mem::size_of::<T>();
if size + self.offset > self.layout.size() {
panic!(
Expand All @@ -139,14 +141,15 @@ impl StructReader {
self.offset += size;
assert_eq!(ptr as usize % std::mem::align_of::<T>(), 0);

std::ptr::read(ptr as _)
unsafe { &*ptr.cast::<T>() }
}

pub fn ptr(&self) -> *const u8 {
pub fn ptr_mut(&self) -> *mut u8 {
self.start
}

/// Returns true if this reader's capacity is full, false otherwise
#[allow(dead_code)]
pub fn is_full(&self) -> bool {
self.layout.size() == self.offset
}
Expand Down Expand Up @@ -177,13 +180,13 @@ mod tests {
};
let mut reader =
StructReader::new(size_of_val(&expected_data), align_of_val(&expected_data));
let byte_buffer: &mut [u8; 8] = unsafe { std::mem::transmute(reader.ptr()) };
let byte_buffer: &mut [u8; 8] = unsafe { &mut *(reader.ptr_mut() as *mut [u8; 8]) };
byte_buffer[0] = 0b10000001;
byte_buffer[4] = 0x0;
byte_buffer[5] = 0xFF;
byte_buffer[6] = 0xFF;
byte_buffer[7] = 0x0;
let actual_data: Data = unsafe { reader.read() };
let actual_data: &Data = unsafe { reader.read() };
assert_eq!(actual_data.field_a, expected_data.field_a);
assert_eq!(actual_data.field_b, expected_data.field_b);
}
Expand Down