use std::{io, cmp}; use futures::{Future, Poll, Async}; use tokio_io::{AsyncRead, AsyncWrite}; use message::{Message, MessageResult, Error}; use message::types::{Version, Verack}; use network::Magic; use io::{write_message, WriteMessage, ReadMessage, read_message}; pub fn handshake(a: A, flags: u32, magic: Magic, version: Version, min_version: u32) -> Handshake where A: AsyncWrite + AsyncRead { Handshake { version: version.version(), nonce: version.nonce(), state: HandshakeState::SendVersion(write_message(a, version_message(magic, version))), magic: magic, min_version: min_version, flags: flags, } } pub fn accept_handshake(a: A, flags: u32, magic: Magic, version: Version, min_version: u32) -> AcceptHandshake where A: AsyncWrite + AsyncRead { AcceptHandshake { version: version.version(), nonce: version.nonce(), state: AcceptHandshakeState::ReceiveVersion { local_version: Some(version), future: read_message(a, flags, magic, 0), }, magic: magic, min_version: min_version, flags: flags, } } pub fn negotiate_version(local: u32, other: u32) -> u32 { cmp::min(local, other) } #[derive(Debug, PartialEq)] pub struct HandshakeResult { pub version: Version, pub negotiated_version: u32, } fn version_message(magic: Magic, version: Version) -> Message { Message::new(magic, version.version(), &version).expect("version message should always be serialized correctly") } fn verack_message(magic: Magic) -> Message { Message::new(magic, 0, &Verack).expect("verack message should always be serialized correctly") } enum HandshakeState { SendVersion(WriteMessage), ReceiveVersion(ReadMessage), SendVerack { version: Option, future: WriteMessage, }, ReceiveVerack { version: Option, future: ReadMessage, }, } enum AcceptHandshakeState { ReceiveVersion { local_version: Option, future: ReadMessage }, SendVersion { version: Option, future: WriteMessage, }, SendVerack { version: Option, future: WriteMessage, }, } pub struct Handshake { state: HandshakeState, magic: Magic, version: u32, nonce: Option, min_version: u32, flags: u32, } pub struct AcceptHandshake { state: AcceptHandshakeState, magic: Magic, version: u32, nonce: Option, min_version: u32, flags: u32, } impl Future for Handshake where A: AsyncRead + AsyncWrite { type Item = (A, MessageResult); type Error = io::Error; fn poll(&mut self) -> Poll { loop { let next_state = match self.state { HandshakeState::SendVersion(ref mut future) => { let (stream, _) = try_ready!(future.poll()); HandshakeState::ReceiveVersion(read_message(stream, self.flags, self.magic, 0)) }, HandshakeState::ReceiveVersion(ref mut future) => { let (stream, version) = try_ready!(future.poll()); let version = match version { Ok(version) => version, Err(err) => return Ok((stream, Err(err.into())).into()), }; if version.version() < self.min_version { return Ok((stream, Err(Error::InvalidVersion)).into()); } if let (Some(self_nonce), Some(nonce)) = (self.nonce, version.nonce()) { if self_nonce == nonce { return Ok((stream, Err(Error::InvalidVersion)).into()); } } HandshakeState::SendVerack { version: Some(version), future: write_message(stream, verack_message(self.magic)), } }, HandshakeState::SendVerack { ref mut version, ref mut future } => { let (stream, _) = try_ready!(future.poll()); let version = version.take().expect("verack must be preceded by version"); HandshakeState::ReceiveVerack { version: Some(version), future: read_message(stream, self.flags, self.magic, 0), } }, HandshakeState::ReceiveVerack { ref mut version, ref mut future } => { let (stream, _verack) = try_ready!(future.poll()); let version = version.take().expect("verack must be preceded by version"); let result = HandshakeResult { negotiated_version: negotiate_version(self.version, version.version()), version: version, }; return Ok(Async::Ready((stream, Ok(result)))); }, }; self.state = next_state; } } } impl Future for AcceptHandshake where A: AsyncRead + AsyncWrite { type Item = (A, MessageResult); type Error = io::Error; fn poll(&mut self) -> Poll { loop { let next_state = match self.state { AcceptHandshakeState::ReceiveVersion { ref mut local_version, ref mut future } => { let (stream, version) = try_ready!(future.poll()); let version = match version { Ok(version) => version, Err(err) => return Ok((stream, Err(err.into())).into()), }; if version.version() < self.min_version { return Ok((stream, Err(Error::InvalidVersion)).into()); } if let (Some(self_nonce), Some(nonce)) = (self.nonce, version.nonce()) { if self_nonce == nonce { return Ok((stream, Err(Error::InvalidVersion)).into()); } } let local_version = local_version.take().expect("local version must be set"); AcceptHandshakeState::SendVersion { version: Some(version), future: write_message(stream, version_message(self.magic, local_version)), } }, AcceptHandshakeState::SendVersion { ref mut version, ref mut future } => { let (stream, _) = try_ready!(future.poll()); AcceptHandshakeState::SendVerack { version: version.take(), future: write_message(stream, verack_message(self.magic)), } }, AcceptHandshakeState::SendVerack { ref mut version, ref mut future } => { let (stream, _) = try_ready!(future.poll()); let version = version.take().expect("verack must be preceded by version"); let result = HandshakeResult { negotiated_version: negotiate_version(self.version, version.version()), version: version, }; return Ok(Async::Ready((stream, Ok(result)))); }, }; self.state = next_state; } } } #[cfg(test)] mod tests { use std::io; use futures::{Future, Poll}; use tokio_io::{AsyncRead, AsyncWrite}; use bytes::Bytes; use ser::Stream; use network::{Network, ConsensusFork, BitcoinCashConsensusParams}; use message::{Message, Error}; use message::types::Verack; use message::types::version::{Version, V0, V106, V70001}; use super::{handshake, accept_handshake, HandshakeResult}; pub struct TestIo { read: io::Cursor, write: Bytes, } impl io::Read for TestIo { fn read(&mut self, buf: &mut [u8]) -> io::Result { io::Read::read(&mut self.read, buf) } } impl AsyncRead for TestIo {} impl io::Write for TestIo { fn write(&mut self, buf: &[u8]) -> io::Result { io::Write::write(&mut self.write, buf) } fn flush(&mut self) -> io::Result<()> { io::Write::flush(&mut self.write) } } impl AsyncWrite for TestIo { fn shutdown(&mut self) -> Poll<(), io::Error> { Ok(().into()) } } fn local_version() -> Version { Version::V70001(V0 { version: 70001, services: 1u64.into(), timestamp: 0x4d1015e6, // address and port of remote // services set to 0, cause we know nothing about the node receiver: "00000000000000000000000000000000000000002f5a0808208d".into(), }, V106 { // our local address (not sure if it is valid, or if it is checked at all // services set to 0, because we support nothing from: "00000000000000000000000000000000000000007f000001208d".into(), nonce: 0x3c76a409eb48a227, user_agent: "pbtc".into(), start_height: 0, }, V70001 { relay: true, }) } fn remote_version() -> Version { Version::V70001(V0 { version: 70012, services: 1u64.into(), timestamp: 0x4d1015e6, // services set to 1, house receiver supports at least the network receiver: "010000000000000000000000000000000000ffffc2b5936adde9".into(), }, V106 { // remote address, port // and supported protocols from: "050000000000000000000000000000000000ffff2f5a0808208d".into(), nonce: 0x3c76a409eb48a228, user_agent: "/Satoshi:0.12.1/".into(), start_height: 0, }, V70001 { relay: true, }) } #[test] fn test_handshake() { let magic = Network::Mainnet.magic(&ConsensusFork::BitcoinCore); let version = 70012; let local_version = local_version(); let remote_version = remote_version(); let mut remote_stream = Stream::new(); remote_stream.append_slice(Message::new(magic, version, &remote_version).unwrap().as_ref()); remote_stream.append_slice(Message::new(magic, version, &Verack).unwrap().as_ref()); let expected = HandshakeResult { version: remote_version, negotiated_version: 70001, }; let mut expected_stream = Stream::new(); expected_stream.append_slice(Message::new(magic, version, &local_version).unwrap().as_ref()); expected_stream.append_slice(Message::new(magic, version, &Verack).unwrap().as_ref()); let test_io = TestIo { read: io::Cursor::new(remote_stream.out()), write: Bytes::default(), }; let hs = handshake(test_io, 0, magic, local_version, 0).wait().unwrap(); assert_eq!(hs.0.write, expected_stream.out()); assert_eq!(hs.1.unwrap(), expected); } #[test] fn test_accept_handshake() { let magic = Network::Mainnet.magic(&ConsensusFork::BitcoinCore); let version = 70012; let local_version = local_version(); let remote_version = remote_version(); let mut remote_stream = Stream::new(); remote_stream.append_slice(Message::new(magic, version, &remote_version).unwrap().as_ref()); let test_io = TestIo { read: io::Cursor::new(remote_stream.out()), write: Bytes::default(), }; let expected = HandshakeResult { version: remote_version, negotiated_version: 70001, }; let mut expected_stream = Stream::new(); expected_stream.append_slice(Message::new(magic, version, &local_version).unwrap().as_ref()); expected_stream.append_slice(Message::new(magic, version, &Verack).unwrap().as_ref()); let hs = accept_handshake(test_io, 0, magic, local_version, 0).wait().unwrap(); assert_eq!(hs.0.write, expected_stream.out()); assert_eq!(hs.1.unwrap(), expected); } #[test] fn test_self_handshake() { let magic = Network::Mainnet.magic(&ConsensusFork::BitcoinCore); let version = 70012; let remote_version = local_version(); let local_version = local_version(); let mut remote_stream = Stream::new(); remote_stream.append_slice(Message::new(magic, version, &remote_version).unwrap().as_ref()); let test_io = TestIo { read: io::Cursor::new(remote_stream.out()), write: Bytes::default(), }; let expected = Error::InvalidVersion; let hs = handshake(test_io, 0, magic, local_version, 0).wait().unwrap(); assert_eq!(hs.1.unwrap_err(), expected); } #[test] fn test_accept_self_handshake() { let magic = Network::Mainnet.magic(&ConsensusFork::BitcoinCore); let version = 70012; let remote_version = local_version(); let local_version = local_version(); let mut remote_stream = Stream::new(); remote_stream.append_slice(Message::new(magic, version, &remote_version).unwrap().as_ref()); let test_io = TestIo { read: io::Cursor::new(remote_stream.out()), write: Bytes::default(), }; let expected = Error::InvalidVersion; let hs = accept_handshake(test_io, 0, magic, local_version, 0).wait().unwrap(); assert_eq!(hs.1.unwrap_err(), expected); } #[test] fn test_fails_to_accept_other_fork_node() { let magic1 = Network::Mainnet.magic(&ConsensusFork::BitcoinCore); let magic2 = Network::Mainnet.magic(&ConsensusFork::BitcoinCash(BitcoinCashConsensusParams::new(Network::Mainnet))); let version = 70012; let local_version = local_version(); let remote_version = remote_version(); let mut remote_stream = Stream::new(); remote_stream.append_slice(Message::new(magic2, version, &remote_version).unwrap().as_ref()); let test_io = TestIo { read: io::Cursor::new(remote_stream.out()), write: Bytes::default(), }; let expected = Error::InvalidMagic; let hs = accept_handshake(test_io, 0, magic1, local_version, 0).wait().unwrap(); assert_eq!(hs.1.unwrap_err(), expected); } }