diff --git a/substrate/core/network-libp2p/src/custom_proto.rs b/substrate/core/network-libp2p/src/custom_proto.rs index 9aa3d03e9e854b3b376b27123b0bff58090f2b69..6d0a5d5b1e8505f26f64a5a33c55ce35dbc3523e 100644 --- a/substrate/core/network-libp2p/src/custom_proto.rs +++ b/substrate/core/network-libp2p/src/custom_proto.rs @@ -15,22 +15,22 @@ // along with Substrate. If not, see <http://www.gnu.org/licenses/>. use bytes::{Bytes, BytesMut}; -use ProtocolId; use libp2p::core::{Multiaddr, ConnectionUpgrade, Endpoint}; -use PacketId; +use libp2p::tokio_codec::Framed; +use std::collections::VecDeque; use std::io::Error as IoError; use std::vec::IntoIter as VecIntoIter; -use futures::{future, Future, stream, Stream, Sink}; -use futures::sync::mpsc; +use futures::{prelude::*, future, stream, task}; use tokio_io::{AsyncRead, AsyncWrite}; use unsigned_varint::codec::UviBytes; +use ProtocolId; /// Connection upgrade for a single protocol. /// /// Note that "a single protocol" here refers to `par` for example. However /// each protocol can have multiple different versions for networking purposes. #[derive(Clone)] -pub struct RegisteredProtocol<T> { +pub struct RegisteredProtocol<TUserData> { /// Id of the protocol for API purposes. id: ProtocolId, /// Base name of the protocol as advertised on the network. @@ -41,67 +41,202 @@ pub struct RegisteredProtocol<T> { /// The packet count is used to filter out invalid messages. supported_versions: Vec<(u8, u8)>, /// Custom data. - custom_data: T, + custom_data: TUserData, } -/// Output of a `RegisteredProtocol` upgrade. -pub struct RegisteredProtocolOutput<T> { - /// Data passed to `RegisteredProtocol::new`. - pub custom_data: T, - - /// Id of the protocol. - pub protocol_id: ProtocolId, - - /// Endpoint of the connection. - pub endpoint: Endpoint, - - /// Version of the protocol that was negotiated. - pub protocol_version: u8, - - /// Channel to sender outgoing messages to. - // TODO: consider assembling packet_id here - pub outgoing: mpsc::UnboundedSender<Bytes>, - - /// Stream where incoming messages are received. The stream ends whenever - /// either side is closed. - pub incoming: Box<Stream<Item = (PacketId, Bytes), Error = IoError> + Send>, -} - -impl<T> RegisteredProtocol<T> { +impl<TUserData> RegisteredProtocol<TUserData> { /// Creates a new `RegisteredProtocol`. The `custom_data` parameter will be /// passed inside the `RegisteredProtocolOutput`. - pub fn new(custom_data: T, protocol: ProtocolId, versions: &[(u8, u8)]) + pub fn new(custom_data: TUserData, protocol: ProtocolId, versions: &[(u8, u8)]) -> Self { - let mut proto_name = Bytes::from_static(b"/substrate/"); - proto_name.extend_from_slice(&protocol); - proto_name.extend_from_slice(b"/"); + let mut base_name = Bytes::from_static(b"/substrate/"); + base_name.extend_from_slice(&protocol); + base_name.extend_from_slice(b"/"); RegisteredProtocol { - base_name: proto_name, + base_name, id: protocol, supported_versions: { let mut tmp: Vec<_> = versions.iter().rev().cloned().collect(); tmp.sort_unstable_by(|a, b| b.1.cmp(&a.1)); tmp }, - custom_data: custom_data, + custom_data, } } /// Returns the ID of the protocol. + #[inline] pub fn id(&self) -> ProtocolId { self.id } /// Returns the custom data that was passed to `new`. - pub fn custom_data(&self) -> &T { + #[inline] + pub fn custom_data(&self) -> &TUserData { &self.custom_data } } -// `Maf` is short for `MultiaddressFuture` -impl<T, C> ConnectionUpgrade<C> for RegisteredProtocol<T> -where C: AsyncRead + AsyncWrite + Send + 'static, // TODO: 'static :-/ +/// Output of a `RegisteredProtocol` upgrade. +pub struct RegisteredProtocolSubstream<TSubstream> { + /// If true, we are in the process of closing the sink. + is_closing: bool, + /// Buffer of packets to send. + send_queue: VecDeque<Bytes>, + /// If true, we should call `poll_complete` on the inner sink. + requires_poll_complete: bool, + /// The underlying substream. + inner: stream::Fuse<Framed<TSubstream, UviBytes<Bytes>>>, + /// Maximum packet id. + packet_count: u8, + /// Id of the protocol. + protocol_id: ProtocolId, + /// Version of the protocol that was negotiated. + protocol_version: u8, + /// Task to notify when something is changed and we need to be polled. + to_notify: Option<task::Task>, +} + +/// Packet of data that can be sent or received. +#[derive(Debug, Clone)] +pub struct Packet { + /// Identifier of the packet. + pub id: u8, + /// The raw data. + pub data: Bytes, +} + +impl<TSubstream> RegisteredProtocolSubstream<TSubstream> { + /// Returns the protocol id. + #[inline] + pub fn protocol_id(&self) -> ProtocolId { + self.protocol_id + } + + /// Returns the version of the protocol that was negotiated. + #[inline] + pub fn protocol_version(&self) -> u8 { + self.protocol_version + } + + /// Starts a graceful shutdown process on this substream. + /// + /// Note that "graceful" means that we sent a closing message. We don't wait for any + /// confirmation from the remote. + /// + /// After calling this, the stream is guaranteed to finish soon-ish. + pub fn shutdown(&mut self) { + self.is_closing = true; + if let Some(task) = self.to_notify.take() { + task.notify(); + } + } + + /// Sends a message to the substream. + pub fn send_message(&mut self, Packet { id: packet_id, data }: Packet) { + if packet_id >= self.packet_count { + error!(target: "sub-libp2p", "Tried to send a packet with an invalid ID {}", packet_id); + return; + } + + let mut message = Bytes::with_capacity(1 + data.len()); + message.extend_from_slice(&[packet_id]); + message.extend_from_slice(&data); + self.send_queue.push_back(message); + + // If the length of the queue goes over a certain arbitrary threshold, we print a warning. + // TODO: figure out a good threshold + if self.send_queue.len() >= 2048 { + warn!(target: "sub-libp2p", "Queue of packets to send over substream is pretty \ + large: {}", self.send_queue.len()); + } + + if let Some(task) = self.to_notify.take() { + task.notify(); + } + } + + /// Turns raw data into a packet and checks whether it is valid. + fn data_to_packet(&self, mut data: BytesMut) -> Result<Packet, ()> { + // The `data` should be prefixed by the packet ID, therefore an empty packet is invalid. + if data.is_empty() { + debug!(target: "sub-libp2p", "ignoring incoming packet because it was empty"); + return Err(()); + } + + let packet = { + let id = data[0]; + let data = data.split_off(1); + Packet { id, data: data.freeze() } + }; + + if packet.id >= self.packet_count { + debug!(target: "sub-libp2p", "ignoring incoming packet because packet_id {} is \ + too large", packet.id); + return Err(()) + } + + Ok(packet) + } +} + +impl<TSubstream> Stream for RegisteredProtocolSubstream<TSubstream> +where TSubstream: AsyncRead + AsyncWrite, +{ + type Item = Packet; + type Error = IoError; + + fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> { + // If we are closing, close as soon as the Sink is closed. + if self.is_closing { + return Ok(self.inner.close()?.map(|()| None)); + } + + // Flushing the local queue. + while let Some(packet) = self.send_queue.pop_front() { + match self.inner.start_send(packet)? { + AsyncSink::NotReady(packet) => { + self.send_queue.push_front(packet); + break; + }, + AsyncSink::Ready => self.requires_poll_complete = true, + } + } + + // Flushing if necessary. + if self.requires_poll_complete { + if let Async::Ready(()) = self.inner.poll_complete()? { + self.requires_poll_complete = false; + } + } + + // Receiving incoming packets. + // Note that `inner` is wrapped in a `Fuse`, therefore we can poll it forever. + loop { + match self.inner.poll()? { + Async::Ready(Some(data)) => + if let Ok(packet) = self.data_to_packet(data) { + return Ok(Async::Ready(Some(packet))) + }, + Async::Ready(None) => + if !self.requires_poll_complete && self.send_queue.is_empty() { + return Ok(Async::Ready(None)) + } else { + break + }, + Async::NotReady => break, + } + } + + self.to_notify = Some(task::current()); + Ok(Async::NotReady) + } +} + +impl<TSubstream, TUserData> ConnectionUpgrade<TSubstream> for RegisteredProtocol<TUserData> +where TSubstream: AsyncRead + AsyncWrite, + TUserData: Clone, { type NamesIter = VecIntoIter<(Bytes, Self::UpgradeIdentifier)>; type UpgradeIdentifier = u8; // Protocol version @@ -117,15 +252,15 @@ where C: AsyncRead + AsyncWrite + Send + 'static, // TODO: 'static :-/ }).collect::<Vec<_>>().into_iter() } - type Output = RegisteredProtocolOutput<T>; + type Output = RegisteredProtocolSubstream<TSubstream>; type Future = future::FutureResult<Self::Output, IoError>; #[allow(deprecated)] fn upgrade( self, - socket: C, + socket: TSubstream, protocol_version: Self::UpgradeIdentifier, - endpoint: Endpoint, + _: Endpoint, _: &Multiaddr ) -> Self::Future { let packet_count = self.supported_versions @@ -134,103 +269,27 @@ where C: AsyncRead + AsyncWrite + Send + 'static, // TODO: 'static :-/ .expect("negotiated protocol version that wasn't advertised ; \ programmer error") .1; - - // This function is called whenever we successfully negotiated a - // protocol with a remote (both if initiated by us or by the remote) - - // This channel is used to send outgoing packets to the custom_data - // for this open substream. - let (msg_tx, msg_rx) = mpsc::unbounded(); - - // Build the sink for outgoing network bytes, and the stream for - // incoming instructions. `stream` implements `Stream<Item = Message>`. - enum Message { - /// Received data from the network. - RecvSocket(BytesMut), - /// Data to send to the network. - /// The packet_id must already be inside the `Bytes`. - SendReq(Bytes), - /// The socket has been closed. - Finished, - } - - let (sink, stream) = { - let framed = AsyncRead::framed(socket, UviBytes::default()); - let msg_rx = msg_rx.map(Message::SendReq) - .map_err(|()| unreachable!("mpsc::UnboundedReceiver never errors")); - let (sink, stream) = framed.split(); - let stream = stream.map(Message::RecvSocket) - .chain(stream::once(Ok(Message::Finished))); - (sink, msg_rx.select(stream)) - }; - - let incoming = stream::unfold((sink, stream, false), move |(sink, stream, finished)| { - if finished { - return None - } - - Some(stream - .into_future() - .map_err(|(err, _)| err) - .and_then(move |(message, stream)| - match message { - Some(Message::RecvSocket(mut data)) => { - // The `data` should be prefixed by the packet ID, - // therefore an empty packet is invalid. - if data.is_empty() { - debug!(target: "sub-libp2p", "ignoring incoming \ - packet because it was empty"); - let f = future::ok((None, (sink, stream, false))); - return future::Either::A(f) - } - - let packet_id = data[0]; - let data = data.split_off(1); - - if packet_id >= packet_count { - debug!(target: "sub-libp2p", "ignoring incoming packet \ - because packet_id {} is too large", packet_id); - let f = future::ok((None, (sink, stream, false))); - future::Either::A(f) - } else { - let out = Some((packet_id, data.freeze())); - let f = future::ok((out, (sink, stream, false))); - future::Either::A(f) - } - }, - - Some(Message::SendReq(data)) => { - let fut = sink.send(data) - .map(move |sink| (None, (sink, stream, false))); - future::Either::B(fut) - }, - - Some(Message::Finished) | None => { - let f = future::ok((None, (sink, stream, true))); - future::Either::A(f) - }, - } - )) - }).filter_map(|v| v); - - let out = RegisteredProtocolOutput { - custom_data: self.custom_data, + + let framed = Framed::new(socket, UviBytes::default()); + + future::ok(RegisteredProtocolSubstream { + is_closing: false, + send_queue: VecDeque::new(), + requires_poll_complete: false, + inner: framed.fuse(), + packet_count, protocol_id: self.id, - endpoint, - protocol_version: protocol_version, - outgoing: msg_tx, - incoming: Box::new(incoming), - }; - - future::ok(out) + protocol_version, + to_notify: None, + }) } } // Connection upgrade for all the protocols contained in it. #[derive(Clone)] -pub struct RegisteredProtocols<T>(pub Vec<RegisteredProtocol<T>>); +pub struct RegisteredProtocols<TUserData>(pub Vec<RegisteredProtocol<TUserData>>); -impl<T> RegisteredProtocols<T> { +impl<TUserData> RegisteredProtocols<TUserData> { /// Returns the number of protocols. #[inline] pub fn len(&self) -> usize { @@ -239,7 +298,7 @@ impl<T> RegisteredProtocols<T> { /// Finds a protocol in the list by its id. pub fn find_protocol(&self, protocol: ProtocolId) - -> Option<&RegisteredProtocol<T>> { + -> Option<&RegisteredProtocol<TUserData>> { self.0.iter().find(|p| p.id == protocol) } @@ -249,35 +308,36 @@ impl<T> RegisteredProtocols<T> { } } -impl<T> Default for RegisteredProtocols<T> { +impl<TUserData> Default for RegisteredProtocols<TUserData> { fn default() -> Self { RegisteredProtocols(Vec::new()) } } -impl<T, C> ConnectionUpgrade<C> for RegisteredProtocols<T> -where C: AsyncRead + AsyncWrite + Send + 'static, // TODO: 'static :-/ +impl<TSubstream, TUserData> ConnectionUpgrade<TSubstream> for RegisteredProtocols<TUserData> +where TSubstream: AsyncRead + AsyncWrite, + TUserData: Clone, { type NamesIter = VecIntoIter<(Bytes, Self::UpgradeIdentifier)>; type UpgradeIdentifier = (usize, - <RegisteredProtocol<T> as ConnectionUpgrade<C>>::UpgradeIdentifier); + <RegisteredProtocol<TUserData> as ConnectionUpgrade<TSubstream>>::UpgradeIdentifier); fn protocol_names(&self) -> Self::NamesIter { // We concat the lists of `RegisteredProtocol::protocol_names` for // each protocol. self.0.iter().enumerate().flat_map(|(n, proto)| - ConnectionUpgrade::<C>::protocol_names(proto) + ConnectionUpgrade::<TSubstream>::protocol_names(proto) .map(move |(name, id)| (name, (n, id))) ).collect::<Vec<_>>().into_iter() } - type Output = <RegisteredProtocol<T> as ConnectionUpgrade<C>>::Output; - type Future = <RegisteredProtocol<T> as ConnectionUpgrade<C>>::Future; + type Output = <RegisteredProtocol<TUserData> as ConnectionUpgrade<TSubstream>>::Output; + type Future = <RegisteredProtocol<TUserData> as ConnectionUpgrade<TSubstream>>::Future; #[inline] fn upgrade( self, - socket: C, + socket: TSubstream, upgrade_identifier: Self::UpgradeIdentifier, endpoint: Endpoint, remote_addr: &Multiaddr diff --git a/substrate/core/network-libp2p/src/node_handler.rs b/substrate/core/network-libp2p/src/node_handler.rs index f0bed2f977bba82973f4830037ec751bac390793..8e6c50ae80bfc2682039f6f4b569538d1fb8cff3 100644 --- a/substrate/core/network-libp2p/src/node_handler.rs +++ b/substrate/core/network-libp2p/src/node_handler.rs @@ -15,7 +15,7 @@ // along with Substrate. If not, see <http://www.gnu.org/licenses/>. use bytes::Bytes; -use custom_proto::{RegisteredProtocols, RegisteredProtocolOutput}; +use custom_proto::{Packet, RegisteredProtocols, RegisteredProtocolSubstream}; use futures::{prelude::*, task}; use libp2p::core::{ConnectionUpgrade, Endpoint, PeerId, PublicKey, upgrade}; use libp2p::core::nodes::handled_node::{NodeHandler, NodeHandlerEndpoint, NodeHandlerEvent}; @@ -52,7 +52,7 @@ pub struct SubstrateNodeHandler<TSubstream, TUserData> { /// List of registered custom protocols. registered_custom: Arc<RegisteredProtocols<TUserData>>, /// Substreams open for "custom" protocols (eg. dot). - custom_protocols_substreams: Vec<RegisteredProtocolOutput<TUserData>>, + custom_protocols_substreams: Vec<RegisteredProtocolSubstream<TSubstream>>, /// Address of the node. address: Multiaddr, @@ -81,10 +81,10 @@ pub struct SubstrateNodeHandler<TSubstream, TUserData> { next_identify: Interval, /// Substreams being upgraded on the listening side. - upgrades_in_progress_listen: Vec<Box<Future<Item = FinalUpgrade<TSubstream, TUserData>, Error = IoError> + Send>>, + upgrades_in_progress_listen: Vec<Box<Future<Item = FinalUpgrade<TSubstream>, Error = IoError> + Send>>, /// Substreams being upgraded on the dialing side. Contrary to `upgrades_in_progress_listen`, /// these have a known purpose. - upgrades_in_progress_dial: Vec<(UpgradePurpose, Box<Future<Item = FinalUpgrade<TSubstream, TUserData>, Error = IoError> + Send>)>, + upgrades_in_progress_dial: Vec<(UpgradePurpose, Box<Future<Item = FinalUpgrade<TSubstream>, Error = IoError> + Send>)>, /// The substreams we want to open. queued_dial_upgrades: Vec<UpgradePurpose>, /// Number of outbound substreams the outside should open for us. @@ -396,6 +396,11 @@ where TSubstream: AsyncRead + AsyncWrite + Send + 'static, fn shutdown(&mut self) { // TODO: close gracefully self.is_shutting_down = true; + + for custom_proto in &mut self.custom_protocols_substreams { + custom_proto.shutdown(); + } + if let Some(to_notify) = self.to_notify.take() { to_notify.notify(); } @@ -403,6 +408,7 @@ where TSubstream: AsyncRead + AsyncWrite + Send + 'static, fn poll(&mut self) -> Poll<Option<NodeHandlerEvent<Self::OutboundOpenInfo, Self::OutEvent>>, IoError> { if self.is_shutting_down { + // TODO: finish only when everything is closed return Ok(Async::Ready(None)); } @@ -456,7 +462,7 @@ where TSubstream: AsyncRead + AsyncWrite + Send + 'static, ) { debug_assert!(self.registered_custom.has_protocol(protocol), "invalid protocol id requested in the API of the libp2p networking"); - let proto = match self.custom_protocols_substreams.iter().find(|p| p.protocol_id == protocol) { + let proto = match self.custom_protocols_substreams.iter_mut().find(|p| p.protocol_id() == protocol) { Some(proto) => proto, None => { // We are processing a message event before we could report to the outside that @@ -465,13 +471,7 @@ where TSubstream: AsyncRead + AsyncWrite + Send + 'static, }, }; - let mut message = Bytes::with_capacity(1 + data.len()); - message.extend_from_slice(&[packet_id]); - message.extend_from_slice(&data); - - if let Err(_) = proto.outgoing.unbounded_send(message) { - error!(target: "sub-libp2p", "Error while sending custom message to channel"); - } + proto.send_message(Packet { id: packet_id, data: data.into() }); } /// The node will try to open a Kademlia substream and produce a `KadOpen` event containing the @@ -521,7 +521,7 @@ where TSubstream: AsyncRead + AsyncWrite + Send + 'static, /// Optionally produces an event to dispatch. fn inject_fully_negotiated( &mut self, - upgrade: FinalUpgrade<TSubstream, TUserData> + upgrade: FinalUpgrade<TSubstream> ) -> Option<SubstrateOutEvent<TSubstream>> { match upgrade { FinalUpgrade::IdentifyListener(sender) => @@ -561,15 +561,15 @@ where TSubstream: AsyncRead + AsyncWrite + Send + 'static, } }, FinalUpgrade::Custom(proto) => { - self.cancel_dial_upgrade(&UpgradePurpose::Custom(proto.protocol_id)); - if self.custom_protocols_substreams.iter().any(|p| p.protocol_id == proto.protocol_id) { + self.cancel_dial_upgrade(&UpgradePurpose::Custom(proto.protocol_id())); + if self.custom_protocols_substreams.iter().any(|p| p.protocol_id() == proto.protocol_id()) { // Skipping protocol that's already open. return None; } let event = SubstrateOutEvent::CustomProtocolOpen { - protocol_id: proto.protocol_id, - version: proto.protocol_version, + protocol_id: proto.protocol_id(), + version: proto.protocol_version(), }; self.custom_protocols_substreams.push(proto); @@ -686,32 +686,32 @@ where TSubstream: AsyncRead + AsyncWrite + Send + 'static, // Poll for messages on the custom protocol stream. for n in (0 .. self.custom_protocols_substreams.len()).rev() { let mut custom_proto = self.custom_protocols_substreams.swap_remove(n); - match custom_proto.incoming.poll() { + match custom_proto.poll() { Ok(Async::NotReady) => self.custom_protocols_substreams.push(custom_proto), - Ok(Async::Ready(Some((packet_id, data)))) => { - let protocol_id = custom_proto.protocol_id; + Ok(Async::Ready(Some(Packet { id, data }))) => { + let protocol_id = custom_proto.protocol_id(); self.custom_protocols_substreams.push(custom_proto); return Ok(Async::Ready(Some(SubstrateOutEvent::CustomMessage { protocol_id, - packet_id, + packet_id: id, data, }))); }, Ok(Async::Ready(None)) => { // Trying to reopen the protocol. - self.queued_dial_upgrades.push(UpgradePurpose::Custom(custom_proto.protocol_id)); + self.queued_dial_upgrades.push(UpgradePurpose::Custom(custom_proto.protocol_id())); self.num_out_user_must_open += 1; return Ok(Async::Ready(Some(SubstrateOutEvent::CustomProtocolClosed { - protocol_id: custom_proto.protocol_id, + protocol_id: custom_proto.protocol_id(), result: Ok(()), }))) }, Err(err) => { // Trying to reopen the protocol. - self.queued_dial_upgrades.push(UpgradePurpose::Custom(custom_proto.protocol_id)); + self.queued_dial_upgrades.push(UpgradePurpose::Custom(custom_proto.protocol_id())); self.num_out_user_must_open += 1; return Ok(Async::Ready(Some(SubstrateOutEvent::CustomProtocolClosed { - protocol_id: custom_proto.protocol_id, + protocol_id: custom_proto.protocol_id(), result: Err(err), }))) }, @@ -851,16 +851,16 @@ where TSubstream: AsyncRead + AsyncWrite + Send + 'static, } /// Enum of all the possible protocols our service handles. -enum FinalUpgrade<TSubstream, TUserData> { +enum FinalUpgrade<TSubstream> { Kad(KadConnecController, Box<Stream<Item = KadIncomingRequest, Error = IoError> + Send>), IdentifyListener(identify::IdentifySender<TSubstream>), IdentifyDialer(identify::IdentifyInfo, Multiaddr), PingDialer(ping::PingDialer<TSubstream, Instant>), PingListener(ping::PingListener<TSubstream>), - Custom(RegisteredProtocolOutput<TUserData>), + Custom(RegisteredProtocolSubstream<TSubstream>), } -impl<TSubstream, TUserData> From<ping::PingOutput<TSubstream, Instant>> for FinalUpgrade<TSubstream, TUserData> { +impl<TSubstream> From<ping::PingOutput<TSubstream, Instant>> for FinalUpgrade<TSubstream> { fn from(out: ping::PingOutput<TSubstream, Instant>) -> Self { match out { ping::PingOutput::Ponger(ponger) => FinalUpgrade::PingListener(ponger), @@ -869,7 +869,7 @@ impl<TSubstream, TUserData> From<ping::PingOutput<TSubstream, Instant>> for Fina } } -impl<TSubstream, TUserData> From<identify::IdentifyOutput<TSubstream>> for FinalUpgrade<TSubstream, TUserData> { +impl<TSubstream> From<identify::IdentifyOutput<TSubstream>> for FinalUpgrade<TSubstream> { fn from(out: identify::IdentifyOutput<TSubstream>) -> Self { match out { identify::IdentifyOutput::RemoteInfo { info, observed_addr } =>