Skip to content
Snippets Groups Projects
Commit 62d551a0 authored by Max Inden's avatar Max Inden Committed by GitHub
Browse files

client/network/req-resp: Prevent request id collision (#7957)

* client/network/req-resp: Add unit test for request id collision

* client/network/req-resp: Prevent request id collision

`RequestId` is a monotonically increasing integer, starting at
`1`. A `RequestId` is unique for a single `RequestResponse`
behaviour, but not across multiple `RequestResponse` behaviours. Thus
when handling `RequestId` in the context of multiple
`RequestResponse` behaviours, one needs to couple the protocol name
with the `RequestId` to get a unique request identifier.

This commit ensures that pending requests (`pending_requests`) and
pending responses (`pending_response_arrival_time`) are tracked both by
their protocol name and `RequestId`.

* client/network/req-resp: Remove unused import

* client/network/req-resp: Introduce ProtocolRequestId struct

* client/network/req-resp: Update test doc comment

Treat `RequestId` as an opaque type.

* client/network/req-resp: Improve expect proof
parent 48810cd7
No related merge requests found
......@@ -152,6 +152,24 @@ pub enum Event {
},
}
/// Combination of a protocol name and a request id.
///
/// Uniquely identifies an inbound or outbound request among all handled protocols. Note however
/// that uniqueness is only guaranteed between two inbound and likewise between two outbound
/// requests. There is no uniqueness guarantee in a set of both inbound and outbound
/// [`ProtocolRequestId`]s.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct ProtocolRequestId {
protocol: Cow<'static, str>,
request_id: RequestId,
}
impl From<(Cow<'static, str>, RequestId)> for ProtocolRequestId {
fn from((protocol, request_id): (Cow<'static, str>, RequestId)) -> Self {
Self { protocol, request_id }
}
}
/// Implementation of `NetworkBehaviour` that provides support for request-response protocols.
pub struct RequestResponsesBehaviour {
/// The multiple sub-protocols, by name.
......@@ -163,7 +181,10 @@ pub struct RequestResponsesBehaviour {
>,
/// Pending requests, passed down to a [`RequestResponse`] behaviour, awaiting a reply.
pending_requests: HashMap<RequestId, (Instant, oneshot::Sender<Result<Vec<u8>, RequestFailure>>)>,
pending_requests: HashMap<
ProtocolRequestId,
(Instant, oneshot::Sender<Result<Vec<u8>, RequestFailure>>),
>,
/// Whenever an incoming request arrives, a `Future` is added to this list and will yield the
/// start time and the response to send back to the remote.
......@@ -172,7 +193,7 @@ pub struct RequestResponsesBehaviour {
>,
/// Whenever an incoming request arrives, the arrival [`Instant`] is recorded here.
pending_responses_arrival_time: HashMap<RequestId, Instant>,
pending_responses_arrival_time: HashMap<ProtocolRequestId, Instant>,
}
/// Generated by the response builder and waiting to be processed.
......@@ -226,14 +247,17 @@ impl RequestResponsesBehaviour {
pub fn send_request(
&mut self,
target: &PeerId,
protocol: &str,
protocol_name: &str,
request: Vec<u8>,
pending_response: oneshot::Sender<Result<Vec<u8>, RequestFailure>>,
) {
if let Some((protocol, _)) = self.protocols.get_mut(protocol) {
if let Some((protocol, _)) = self.protocols.get_mut(protocol_name) {
if protocol.is_connected(target) {
let request_id = protocol.send_request(target, request);
self.pending_requests.insert(request_id, (Instant::now(), pending_response));
self.pending_requests.insert(
(protocol_name.to_string().into(), request_id).into(),
(Instant::now(), pending_response),
);
} else {
if pending_response.send(Err(RequestFailure::NotConnected)).is_err() {
log::debug!(
......@@ -250,7 +274,7 @@ impl RequestResponsesBehaviour {
target: "sub-libp2p",
"Unknown protocol {:?}. At the same time local \
node is no longer interested in the result.",
protocol,
protocol_name,
);
};
}
......@@ -453,7 +477,7 @@ impl NetworkBehaviour for RequestResponsesBehaviour {
message: RequestResponseMessage::Request { request_id, request, channel, .. },
} => {
self.pending_responses_arrival_time.insert(
request_id.clone(),
(protocol.clone(), request_id.clone()).into(),
Instant::now(),
);
......@@ -502,7 +526,9 @@ impl NetworkBehaviour for RequestResponsesBehaviour {
},
..
} => {
let (started, delivered) = match self.pending_requests.remove(&request_id) {
let (started, delivered) = match self.pending_requests.remove(
&(protocol.clone(), request_id).into(),
) {
Some((started, pending_response)) => {
let delivered = pending_response.send(
response.map_err(|()| RequestFailure::Refused),
......@@ -537,7 +563,7 @@ impl NetworkBehaviour for RequestResponsesBehaviour {
error,
..
} => {
let started = match self.pending_requests.remove(&request_id) {
let started = match self.pending_requests.remove(&(protocol.clone(), request_id).into()) {
Some((started, pending_response)) => {
if pending_response.send(
Err(RequestFailure::Network(error.clone())),
......@@ -575,7 +601,9 @@ impl NetworkBehaviour for RequestResponsesBehaviour {
// An inbound request failed, either while reading the request or due to failing
// to send a response.
RequestResponseEvent::InboundFailure { request_id, peer, error, .. } => {
self.pending_responses_arrival_time.remove(&request_id);
self.pending_responses_arrival_time.remove(
&(protocol.clone(), request_id).into(),
);
let out = Event::InboundRequest {
peer,
protocol: protocol.clone(),
......@@ -583,10 +611,20 @@ impl NetworkBehaviour for RequestResponsesBehaviour {
};
return Poll::Ready(NetworkBehaviourAction::GenerateEvent(out));
}
// A response to an inbound request has been sent.
RequestResponseEvent::ResponseSent { request_id, peer } => {
let arrival_time = self.pending_responses_arrival_time.remove(&request_id)
let arrival_time = self.pending_responses_arrival_time.remove(
&(protocol.clone(), request_id).into(),
)
.map(|t| t.elapsed())
.expect("To find request arrival time for answered request.");
.expect(
"Time is added for each inbound request on arrival and only \
removed on success (`ResponseSent`) or failure \
(`InboundFailure`). One can not receive a success event for a \
request that either never arrived, or that has previously \
failed; qed.",
);
let out = Event::InboundRequest {
peer,
......@@ -765,9 +803,10 @@ impl RequestResponseCodec for GenericCodec {
#[cfg(test)]
mod tests {
use super::*;
use futures::channel::{mpsc, oneshot};
use futures::executor::LocalPool;
use futures::prelude::*;
use futures::task::Spawn;
use libp2p::identity::Keypair;
use libp2p::Multiaddr;
......@@ -777,6 +816,28 @@ mod tests {
use libp2p::swarm::{Swarm, SwarmEvent};
use std::{iter, time::Duration};
fn build_swarm(list: impl Iterator<Item = ProtocolConfig>) -> (Swarm<RequestResponsesBehaviour>, Multiaddr) {
let keypair = Keypair::generate_ed25519();
let noise_keys = noise::Keypair::<noise::X25519Spec>::new()
.into_authentic(&keypair)
.unwrap();
let transport = MemoryTransport
.upgrade(upgrade::Version::V1)
.authenticate(noise::NoiseConfig::xx(noise_keys).into_authenticated())
.multiplex(libp2p::yamux::YamuxConfig::default())
.boxed();
let behaviour = RequestResponsesBehaviour::new(list).unwrap();
let mut swarm = Swarm::new(transport, behaviour, keypair.public().into_peer_id());
let listen_addr: Multiaddr = format!("/memory/{}", rand::random::<u64>()).parse().unwrap();
Swarm::listen_on(&mut swarm, listen_addr.clone()).unwrap();
(swarm, listen_addr)
}
#[test]
fn basic_request_response_works() {
let protocol_name = "/test/req-resp/1";
......@@ -785,44 +846,24 @@ mod tests {
// Build swarms whose behaviour is `RequestResponsesBehaviour`.
let mut swarms = (0..2)
.map(|_| {
let keypair = Keypair::generate_ed25519();
let noise_keys = noise::Keypair::<noise::X25519Spec>::new()
.into_authentic(&keypair)
.unwrap();
let transport = MemoryTransport
.upgrade(upgrade::Version::V1)
.authenticate(noise::NoiseConfig::xx(noise_keys).into_authenticated())
.multiplex(libp2p::yamux::YamuxConfig::default())
.boxed();
let behaviour = {
let (tx, mut rx) = mpsc::channel(64);
let b = super::RequestResponsesBehaviour::new(iter::once(super::ProtocolConfig {
name: From::from(protocol_name),
max_request_size: 1024,
max_response_size: 1024 * 1024,
request_timeout: Duration::from_secs(30),
inbound_queue: Some(tx),
})).unwrap();
pool.spawner().spawn_obj(async move {
while let Some(rq) = rx.next().await {
assert_eq!(rq.payload, b"this is a request");
let _ = rq.pending_response.send(b"this is a response".to_vec());
}
}.boxed().into()).unwrap();
let (tx, mut rx) = mpsc::channel::<IncomingRequest>(64);
b
pool.spawner().spawn_obj(async move {
while let Some(rq) = rx.next().await {
assert_eq!(rq.payload, b"this is a request");
let _ = rq.pending_response.send(b"this is a response".to_vec());
}
}.boxed().into()).unwrap();
let protocol_config = ProtocolConfig {
name: From::from(protocol_name),
max_request_size: 1024,
max_response_size: 1024 * 1024,
request_timeout: Duration::from_secs(30),
inbound_queue: Some(tx),
};
let mut swarm = Swarm::new(transport, behaviour, keypair.public().into_peer_id());
let listen_addr: Multiaddr = format!("/memory/{}", rand::random::<u64>()).parse().unwrap();
Swarm::listen_on(&mut swarm, listen_addr.clone()).unwrap();
(swarm, listen_addr)
build_swarm(iter::once(protocol_config))
})
.collect::<Vec<_>>();
......@@ -839,7 +880,7 @@ mod tests {
async move {
loop {
match swarm.next_event().await {
SwarmEvent::Behaviour(super::Event::InboundRequest { result, .. }) => {
SwarmEvent::Behaviour(Event::InboundRequest { result, .. }) => {
result.unwrap();
},
_ => {}
......@@ -866,7 +907,7 @@ mod tests {
assert!(response_receiver.is_none());
response_receiver = Some(receiver);
}
SwarmEvent::Behaviour(super::Event::RequestFinished {
SwarmEvent::Behaviour(Event::RequestFinished {
result, ..
}) => {
result.unwrap();
......@@ -888,44 +929,24 @@ mod tests {
// Build swarms whose behaviour is `RequestResponsesBehaviour`.
let mut swarms = (0..2)
.map(|_| {
let keypair = Keypair::generate_ed25519();
let noise_keys = noise::Keypair::<noise::X25519Spec>::new()
.into_authentic(&keypair)
.unwrap();
let transport = MemoryTransport
.upgrade(upgrade::Version::V1)
.authenticate(noise::NoiseConfig::xx(noise_keys).into_authenticated())
.multiplex(libp2p::yamux::YamuxConfig::default())
.boxed();
let behaviour = {
let (tx, mut rx) = mpsc::channel(64);
let b = super::RequestResponsesBehaviour::new(iter::once(super::ProtocolConfig {
name: From::from(protocol_name),
max_request_size: 1024,
max_response_size: 8, // <-- important for the test
request_timeout: Duration::from_secs(30),
inbound_queue: Some(tx),
})).unwrap();
pool.spawner().spawn_obj(async move {
while let Some(rq) = rx.next().await {
assert_eq!(rq.payload, b"this is a request");
let _ = rq.pending_response.send(b"this response exceeds the limit".to_vec());
}
}.boxed().into()).unwrap();
let (tx, mut rx) = mpsc::channel::<IncomingRequest>(64);
b
pool.spawner().spawn_obj(async move {
while let Some(rq) = rx.next().await {
assert_eq!(rq.payload, b"this is a request");
let _ = rq.pending_response.send(b"this response exceeds the limit".to_vec());
}
}.boxed().into()).unwrap();
let protocol_config = ProtocolConfig {
name: From::from(protocol_name),
max_request_size: 1024,
max_response_size: 8, // <-- important for the test
request_timeout: Duration::from_secs(30),
inbound_queue: Some(tx),
};
let mut swarm = Swarm::new(transport, behaviour, keypair.public().into_peer_id());
let listen_addr: Multiaddr = format!("/memory/{}", rand::random::<u64>()).parse().unwrap();
Swarm::listen_on(&mut swarm, listen_addr.clone()).unwrap();
(swarm, listen_addr)
build_swarm(iter::once(protocol_config))
})
.collect::<Vec<_>>();
......@@ -943,7 +964,7 @@ mod tests {
async move {
loop {
match swarm.next_event().await {
SwarmEvent::Behaviour(super::Event::InboundRequest { result, .. }) => {
SwarmEvent::Behaviour(Event::InboundRequest { result, .. }) => {
assert!(result.is_ok());
break
},
......@@ -971,7 +992,7 @@ mod tests {
assert!(response_receiver.is_none());
response_receiver = Some(receiver);
}
SwarmEvent::Behaviour(super::Event::RequestFinished {
SwarmEvent::Behaviour(Event::RequestFinished {
result, ..
}) => {
assert!(result.is_err());
......@@ -982,9 +1003,153 @@ mod tests {
}
match response_receiver.unwrap().await.unwrap().unwrap_err() {
super::RequestFailure::Network(super::OutboundFailure::ConnectionClosed) => {},
RequestFailure::Network(OutboundFailure::ConnectionClosed) => {},
_ => panic!()
}
});
}
/// A [`RequestId`] is a unique identifier among either all inbound or all outbound requests for
/// a single [`RequestResponse`] behaviour. It is not guaranteed to be unique across multiple
/// [`RequestResponse`] behaviours. Thus when handling [`RequestId`] in the context of multiple
/// [`RequestResponse`] behaviours, one needs to couple the protocol name with the [`RequestId`]
/// to get a unique request identifier.
///
/// This test ensures that two requests on different protocols can be handled concurrently
/// without a [`RequestId`] collision.
///
/// See [`ProtocolRequestId`] for additional information.
#[test]
fn request_id_collision() {
let protocol_name_1 = "/test/req-resp-1/1";
let protocol_name_2 = "/test/req-resp-2/1";
let mut pool = LocalPool::new();
let mut swarm_1 = {
let protocol_configs = vec![
ProtocolConfig {
name: From::from(protocol_name_1),
max_request_size: 1024,
max_response_size: 1024 * 1024,
request_timeout: Duration::from_secs(30),
inbound_queue: None,
},
ProtocolConfig {
name: From::from(protocol_name_2),
max_request_size: 1024,
max_response_size: 1024 * 1024,
request_timeout: Duration::from_secs(30),
inbound_queue: None,
},
];
build_swarm(protocol_configs.into_iter()).0
};
let (mut swarm_2, mut swarm_2_handler_1, mut swarm_2_handler_2, listen_add_2) = {
let (tx_1, rx_1) = mpsc::channel(64);
let (tx_2, rx_2) = mpsc::channel(64);
let protocol_configs = vec![
ProtocolConfig {
name: From::from(protocol_name_1),
max_request_size: 1024,
max_response_size: 1024 * 1024,
request_timeout: Duration::from_secs(30),
inbound_queue: Some(tx_1),
},
ProtocolConfig {
name: From::from(protocol_name_2),
max_request_size: 1024,
max_response_size: 1024 * 1024,
request_timeout: Duration::from_secs(30),
inbound_queue: Some(tx_2),
},
];
let (swarm, listen_addr) = build_swarm(protocol_configs.into_iter());
(swarm, rx_1, rx_2, listen_addr)
};
// Ask swarm 1 to dial swarm 2. There isn't any discovery mechanism in place in this test,
// so they wouldn't connect to each other.
Swarm::dial_addr(&mut swarm_1, listen_add_2).unwrap();
// Run swarm 2 in the background, receiving two requests.
pool.spawner().spawn_obj(
async move {
loop {
match swarm_2.next_event().await {
SwarmEvent::Behaviour(Event::InboundRequest { result, .. }) => {
result.unwrap();
},
_ => {}
}
}
}.boxed().into()
).unwrap();
// Handle both requests sent by swarm 1 to swarm 2 in the background.
//
// Make sure both requests overlap, by answering the first only after receiving the
// second.
pool.spawner().spawn_obj(async move {
let protocol_1_request = swarm_2_handler_1.next().await;
let protocol_2_request = swarm_2_handler_2.next().await;
protocol_1_request.unwrap()
.pending_response
.send(b"this is a response".to_vec())
.unwrap();
protocol_2_request.unwrap()
.pending_response
.send(b"this is a response".to_vec())
.unwrap();
}.boxed().into()).unwrap();
// Have swarm 1 send two requests to swarm 2 and await responses.
pool.run_until(
async move {
let mut response_receivers = None;
let mut num_responses = 0;
loop {
match swarm_1.next_event().await {
SwarmEvent::ConnectionEstablished { peer_id, .. } => {
let (sender_1, receiver_1) = oneshot::channel();
let (sender_2, receiver_2) = oneshot::channel();
swarm_1.send_request(
&peer_id,
protocol_name_1,
b"this is a request".to_vec(),
sender_1,
);
swarm_1.send_request(
&peer_id,
protocol_name_2,
b"this is a request".to_vec(),
sender_2,
);
assert!(response_receivers.is_none());
response_receivers = Some((receiver_1, receiver_2));
}
SwarmEvent::Behaviour(Event::RequestFinished {
result, ..
}) => {
num_responses += 1;
result.unwrap();
if num_responses == 2 {
break;
}
}
_ => {}
}
}
let (response_receiver_1, response_receiver_2) = response_receivers.unwrap();
assert_eq!(response_receiver_1.await.unwrap().unwrap(), b"this is a response");
assert_eq!(response_receiver_2.await.unwrap().unwrap(), b"this is a response");
}
);
}
}
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment