From 1ac108a7c52664d8ddd893685dd87d90b9cc5a2e Mon Sep 17 00:00:00 2001 From: Juyeong Maing Date: Fri, 13 Jun 2025 16:34:30 +0900 Subject: [PATCH] Add test for parsing concatenated messages --- client/src/main.rs | 27 ++++++++++++++++-- messages/Cargo.toml | 3 ++ messages/tests/partial_stream.rs | 47 ++++++++++++++++++++++++++++++++ 3 files changed, 74 insertions(+), 3 deletions(-) create mode 100644 messages/tests/partial_stream.rs diff --git a/client/src/main.rs b/client/src/main.rs index 35b35e5..77c0846 100644 --- a/client/src/main.rs +++ b/client/src/main.rs @@ -42,15 +42,21 @@ impl EventLoopWrapper { /// 서버 메시지 수신 태스크 (읽기 전용) async fn network_listener(mut read_half: OwnedReadHalf, tx: mpsc::Sender) { use tokio::io::AsyncReadExt; - let mut buffer = vec![0u8; 1024]; + + let mut read_buf = vec![0u8; 1024]; + let mut buf = Vec::::new(); + loop { - match read_half.read(&mut buffer).await { + match read_half.read(&mut read_buf).await { Ok(0) => { println!("Server disconnected."); break; } Ok(n) => { - if let Ok(msg) = bincode::deserialize::(&buffer[..n]) { + buf.extend_from_slice(&read_buf[..n]); + + while let Ok((msg, consumed)) = try_deser::(&buf) { + buf.drain(..consumed); if tx.send(msg).await.is_err() { eprintln!("Failed to send server message to channel"); } @@ -65,6 +71,21 @@ async fn network_listener(mut read_half: OwnedReadHalf, tx: mpsc::Sender(buf: &[u8]) -> Result<(T, usize), bincode::Error> { + let mut cur = std::io::Cursor::new(buf); + + match bincode::deserialize_from(&mut cur) { + Ok(m) => Ok((m, cur.position() as usize)), + + Err(e) => match *e { + bincode::ErrorKind::Io(ref io_err) if io_err.kind() == std::io::ErrorKind::UnexpectedEof => { + Err(e) + } + _ => panic!("Invalid data during deserialization: {:?}", e), + }, + } +} + #[tokio::main] async fn main() { let server_addr = env::args().nth(1).unwrap_or_else(|| { diff --git a/messages/Cargo.toml b/messages/Cargo.toml index 22dbd57..96c1d86 100644 --- a/messages/Cargo.toml +++ b/messages/Cargo.toml @@ -6,3 +6,6 @@ edition = "2021" [dependencies] serde = { version = "1.0.217", features = ["derive"] } map-types = { version = "0.1.0", path = "../map_types" } + +[dev-dependencies] +bincode = "1.3" diff --git a/messages/tests/partial_stream.rs b/messages/tests/partial_stream.rs new file mode 100644 index 0000000..79ee3e8 --- /dev/null +++ b/messages/tests/partial_stream.rs @@ -0,0 +1,47 @@ +use messages::{ServerMessage, PlayerAction, PlayerPosition}; +use bincode; + +fn try_deser(buf: &[u8]) -> Result<(T, usize), bincode::Error> { + let mut cur = std::io::Cursor::new(buf); + match bincode::deserialize_from(&mut cur) { + Ok(m) => Ok((m, cur.position() as usize)), + Err(e) => match *e { + bincode::ErrorKind::Io(ref io_err) if io_err.kind() == std::io::ErrorKind::UnexpectedEof => Err(e), + _ => panic!("Invalid data during deserialization: {:?}", e), + }, + } +} + +#[test] +fn parse_two_messages_from_partial_third() { + // prepare three messages + let msg1 = ServerMessage::Init { your_player_id: 1, your_position: PlayerPosition::NotInWorld }; + let msg2 = ServerMessage::PlayerAction { action: PlayerAction::DestroyBlock }; + let msg3 = ServerMessage::Init { your_player_id: 2, your_position: PlayerPosition::NotInWorld }; + + let mut buf = Vec::new(); + bincode::serialize_into(&mut buf, &msg1).unwrap(); + bincode::serialize_into(&mut buf, &msg2).unwrap(); + let partial = bincode::serialize(&msg3).unwrap(); + let cut = partial.len() / 2; + buf.extend_from_slice(&partial[..cut]); + + // first message + let (parsed1, consumed1) = try_deser::(&buf).unwrap(); + match parsed1 { + ServerMessage::Init { your_player_id, .. } => assert_eq!(your_player_id, 1), + _ => panic!("Unexpected variant for first message"), + } + buf.drain(..consumed1); + + // second message + let (parsed2, consumed2) = try_deser::(&buf).unwrap(); + match parsed2 { + ServerMessage::PlayerAction { action: PlayerAction::DestroyBlock } => {}, + _ => panic!("Unexpected variant for second message"), + } + buf.drain(..consumed2); + + // third message should be incomplete + assert!(try_deser::(&buf).is_err()); +}