1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68
use std::io; use std::marker::PhantomData; use futures::{Poll, Future}; use tokio_io::AsyncRead; use tokio_io::io::{read_exact, ReadExact}; use bytes::Bytes; use hash::H32; use crypto::checksum; use message::{Error, MessageResult, Payload, deserialize_payload}; pub fn read_payload<M, A>(a: A, version: u32, len: usize, checksum: H32) -> ReadPayload<M, A> where A: AsyncRead, M: Payload { ReadPayload { reader: read_exact(a, Bytes::new_with_len(len)), version: version, checksum: checksum, payload_type: PhantomData, } } pub struct ReadPayload<M, A> { reader: ReadExact<A, Bytes>, version: u32, checksum: H32, payload_type: PhantomData<M>, } impl<M, A> Future for ReadPayload<M, A> where A: AsyncRead, M: Payload { type Item = (A, MessageResult<M>); type Error = io::Error; fn poll(&mut self) -> Poll<Self::Item, Self::Error> { let (read, data) = try_ready!(self.reader.poll()); if checksum(&data) != self.checksum { return Ok((read, Err(Error::InvalidChecksum)).into()); } let payload = deserialize_payload(&data, self.version); Ok((read, payload).into()) } } #[cfg(test)] mod tests { use futures::Future; use bytes::Bytes; use message::Error; use message::types::Ping; use super::read_payload; #[test] fn test_read_payload() { let raw: Bytes = "5845303b6da97786".into(); let ping = Ping::new(u64::from_str_radix("8677a96d3b304558", 16).unwrap()); assert_eq!(read_payload(raw.as_ref(), 0, 8, "83c00c76".into()).wait().unwrap().1, Ok(ping)); } #[test] fn test_read_payload_with_invalid_checksum() { let raw: Bytes = "5845303b6da97786".into(); assert_eq!(read_payload::<Ping, _>(raw.as_ref(), 0, 8, "83c00c75".into()).wait().unwrap().1, Err(Error::InvalidChecksum)); } #[test] fn test_read_too_short_payload() { let raw: Bytes = "5845303b6da977".into(); assert!(read_payload::<Ping, _>(raw.as_ref(), 0, 8, "83c00c76".into()).wait().is_err()); } }