From 7afe69066f0c64ad754d44495d4183a7f903c56a Mon Sep 17 00:00:00 2001 From: Pierre Krieger <pierre.krieger1708@gmail.com> Date: Thu, 7 Feb 2019 17:16:41 +0100 Subject: [PATCH] Fix panic in custom protocol handler (#1723) --- .../src/custom_proto/behaviour.rs | 6 ++ .../src/custom_proto/handler.rs | 62 ++++++++++++++----- 2 files changed, 54 insertions(+), 14 deletions(-) diff --git a/substrate/core/network-libp2p/src/custom_proto/behaviour.rs b/substrate/core/network-libp2p/src/custom_proto/behaviour.rs index b7575524c1e..66af5e048c5 100644 --- a/substrate/core/network-libp2p/src/custom_proto/behaviour.rs +++ b/substrate/core/network-libp2p/src/custom_proto/behaviour.rs @@ -542,6 +542,9 @@ where } } CustomProtosHandlerOut::CustomMessage { protocol_id, data } => { + debug_assert!(self.open_protocols.iter().any(|(s, p)| + s == &source && p == &protocol_id + )); let event = CustomProtosOut::CustomMessage { peer_id: source, protocol_id, @@ -551,6 +554,9 @@ where self.events.push(NetworkBehaviourAction::GenerateEvent(event)); } CustomProtosHandlerOut::Clogged { protocol_id, messages } => { + debug_assert!(self.open_protocols.iter().any(|(s, p)| + s == &source && p == &protocol_id + )); warn!(target: "sub-libp2p", "Queue of packets to send to {:?} (protocol: {:?}) is \ pretty large", source, protocol_id); self.events.push(NetworkBehaviourAction::GenerateEvent(CustomProtosOut::Clogged { diff --git a/substrate/core/network-libp2p/src/custom_proto/handler.rs b/substrate/core/network-libp2p/src/custom_proto/handler.rs index 198e51d7d06..51ff2408ad8 100644 --- a/substrate/core/network-libp2p/src/custom_proto/handler.rs +++ b/substrate/core/network-libp2p/src/custom_proto/handler.rs @@ -41,7 +41,7 @@ pub struct CustomProtosHandler<TSubstream> { protocols: RegisteredProtocols, /// See the documentation of `State`. - state: State, + state: State<TSubstream>, /// Value to be returned by `connection_keep_alive()`. keep_alive: KeepAlive, @@ -54,8 +54,12 @@ pub struct CustomProtosHandler<TSubstream> { } /// State of the handler. -#[derive(Debug, Copy, Clone, PartialEq, Eq)] -enum State { +enum State<TSubstream> { + /// Waiting for the behaviour to tell the handler whether it is enabled or disabled. + /// Contains a list of substreams opened by the remote and that we will integrate to + /// `substreams` only if we get enabled. + Init(SmallVec<[RegisteredProtocolSubstream<TSubstream>; 6]>), + /// Normal functionning. Normal, @@ -136,7 +140,7 @@ where protocols, // We keep the connection alive for at least 5 seconds, waiting for what happens. keep_alive: KeepAlive::Until(Instant::now() + Duration::from_secs(5)), - state: State::Normal, + state: State::Init(SmallVec::new()), substreams: SmallVec::new(), events_queue: SmallVec::new(), } @@ -147,23 +151,31 @@ where &mut self, proto: RegisteredProtocolSubstream<TSubstream>, ) { + if self.substreams.iter().any(|p| p.protocol_id() == proto.protocol_id()) { + // Skipping protocol that's already open. + return + } + match self.state { + State::Init(ref mut pending) => { + if pending.iter().all(|p| p.protocol_id() != proto.protocol_id()) { + pending.push(proto); + } + return + }, // TODO: we should shut down refused substreams gracefully; this should be fixed // at the same time as https://github.com/paritytech/substrate/issues/1517 State::Disabled | State::ShuttingDown => return, State::Normal => () } - if self.substreams.iter().any(|p| p.protocol_id() == proto.protocol_id()) { - // Skipping protocol that's already open. - return - } - let event = CustomProtosHandlerOut::CustomProtocolOpen { protocol_id: proto.protocol_id(), version: proto.protocol_version(), }; + self.keep_alive = KeepAlive::Forever; + self.substreams.push(proto); self.events_queue.push(ProtocolsHandlerEvent::Custom(event)); } @@ -206,7 +218,7 @@ where match message { CustomProtosHandlerIn::Disable => { match self.state { - State::Normal => self.state = State::Disabled, + State::Init(_) | State::Normal => self.state = State::Disabled, State::Disabled | State::ShuttingDown => (), } @@ -217,6 +229,19 @@ where }, CustomProtosHandlerIn::EnableActive | CustomProtosHandlerIn::EnablePassive => { match self.state { + State::Init(ref mut list) => { + for proto in list.drain() { + let event = CustomProtosHandlerOut::CustomProtocolOpen { + protocol_id: proto.protocol_id(), + version: proto.protocol_version(), + }; + + self.substreams.push(proto); + self.events_queue.push(ProtocolsHandlerEvent::Custom(event)); + } + + self.state = State::Normal; + } State::Disabled => self.state = State::Normal, State::Normal | State::ShuttingDown => (), } @@ -277,7 +302,7 @@ where fn shutdown(&mut self) { match self.state { - State::Normal | State::Disabled => self.state = State::ShuttingDown, + State::Init(_) | State::Normal | State::Disabled => self.state = State::ShuttingDown, State::ShuttingDown => (), } @@ -297,8 +322,10 @@ where return Ok(Async::Ready(event)) } - if self.state == State::ShuttingDown && self.substreams.is_empty() { - return Ok(Async::Ready(ProtocolsHandlerEvent::Shutdown)) + if let State::ShuttingDown = self.state { + if self.substreams.is_empty() { + return Ok(Async::Ready(ProtocolsHandlerEvent::Shutdown)) + } } for n in (0..self.substreams.len()).rev() { @@ -323,6 +350,10 @@ where Ok(Async::NotReady) => self.substreams.push(substream), Ok(Async::Ready(None)) => { + // Close the connection as soon as possible. + if self.substreams.is_empty() { + self.keep_alive = KeepAlive::Now; + } let event = CustomProtosHandlerOut::CustomProtocolClosed { protocol_id: substream.protocol_id(), result: Ok(()) @@ -330,6 +361,10 @@ where return Ok(Async::Ready(ProtocolsHandlerEvent::Custom(event))) }, Err(err) => { + // Close the connection as soon as possible. + if self.substreams.is_empty() { + self.keep_alive = KeepAlive::Now; + } let event = CustomProtosHandlerOut::CustomProtocolClosed { protocol_id: substream.protocol_id(), result: Err(err) @@ -350,7 +385,6 @@ where fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { f.debug_struct("CustomProtosHandler") .field("protocols", &self.protocols.len()) - .field("state", &self.state) .field("substreams", &self.substreams.len()) .finish() } -- GitLab