diff --git a/protocol/tests/round_trips.rs b/protocol/tests/round_trips.rs index 206aac3..d1e6aa0 100644 --- a/protocol/tests/round_trips.rs +++ b/protocol/tests/round_trips.rs @@ -154,6 +154,105 @@ fn hello_world_happy_path() { assert_eq!(message, decrypted_message[1..].to_vec()); // Skip header byte } +#[test] +#[cfg(feature = "std")] +fn pingpong_with_closed_connection_sync() { + use bip324::io::{Payload, Protocol}; + use bitcoin::consensus; + use p2p::message::{NetworkMessage, V2NetworkMessage}; + use std::net::{TcpListener, TcpStream}; + use std::thread; + + let listener = TcpListener::bind("127.0.0.1:0").unwrap(); + let addr = listener.local_addr().unwrap(); + + let server = thread::spawn(move || { + let (stream, _) = listener.accept().unwrap(); + let reader = stream.try_clone().unwrap(); + let writer = stream; + + let mut protocol = Protocol::new( + p2p::Magic::REGTEST, + bip324::Role::Responder, + None, + None, + reader, + writer, + ) + .expect("Failed to create protocol"); + + // Read one message + let payload = protocol.read().expect("Failed to read payload"); + let received_message = consensus::deserialize::(payload.contents()) + .expect("Failed to deserialize"); + + if let NetworkMessage::Ping(x) = received_message.payload() { + let pong = V2NetworkMessage::new(NetworkMessage::Pong(*x)); + let message = consensus::serialize(&pong); + protocol + .write(&Payload::genuine(message)) + .expect("Failed to write pong"); + println!("Pong sent, stopping server."); + } else { + panic!("Expected Ping, but received: {received_message:?}"); + } + }); + + let stream = TcpStream::connect(addr).unwrap(); + let reader = stream.try_clone().unwrap(); + let writer = stream; + + println!("Starting sync BIP-324 handshake"); + let mut protocol = Protocol::new( + p2p::Magic::REGTEST, + bip324::Role::Initiator, + None, + None, + reader, + writer, + ) + .expect("Failed to create protocol"); + + println!("Sending Ping using sync Protocol::write()"); + let ping = V2NetworkMessage::new(NetworkMessage::Ping(45324)); + let message = consensus::serialize(&ping); + protocol + .write(&Payload::genuine(message)) + .expect("Failed to write ping"); + + println!("Reading response using sync Protocol::read()"); + let payload = protocol.read().expect("Failed to read response"); + let response_message = consensus::deserialize::(payload.contents()) + .expect("Failed to deserialize response"); + + assert_eq!(NetworkMessage::Pong(45324), *response_message.payload()); + + println!("Successfully ping-pong'ed using sync Protocol API!"); + server.join().unwrap(); + + println!( + "Trying to read another message from the server, while the connection is already closed." + ); + assert!(protocol.read().is_err()); + + println!("Writing to the closed socket for the first time should succeed."); + protocol + .write(&Payload::genuine(consensus::serialize( + &V2NetworkMessage::new(NetworkMessage::Ping(1)), + ))) + .expect("first write to closed socket should succeed!"); + + println!("Writing to the closed socket for the second time should fail."); + assert!( + protocol + .write(&Payload::genuine(consensus::serialize( + &V2NetworkMessage::new(NetworkMessage::Ping(2)) + ))) + .is_err(), + "second write to closed socket should fail!" + ); +} + #[tokio::test] #[cfg(feature = "tokio")] async fn pingpong_with_closed_connection_async() { @@ -229,6 +328,25 @@ async fn pingpong_with_closed_connection_async() { "Trying to read another message from the server, while the connection is already closed." ); assert!(protocol.read().await.is_err()); + + println!("Writing to the closed socket for the first time should succeed."); + protocol + .write(&Payload::genuine(consensus::serialize( + &V2NetworkMessage::new(NetworkMessage::Ping(1)), + ))) + .await + .expect("first write to closed socket should succeed!"); + + println!("Writing to the closed socket for the second time should fail."); + assert!( + protocol + .write(&Payload::genuine(consensus::serialize( + &V2NetworkMessage::new(NetworkMessage::Ping(2)) + ))) + .await + .is_err(), + "second write to closed socket should fail!" + ); } #[test]