diff --git a/build.rs b/build.rs index 6d72c66b..66ac3968 100644 --- a/build.rs +++ b/build.rs @@ -1,14 +1,14 @@ fn main() { - prost_build::compile_protos( - &[ - "src/schema/keys.proto", - "src/schema/noise.proto", - "src/schema/webrtc.proto", - "src/protocol/libp2p/schema/identify.proto", - "src/protocol/libp2p/schema/kademlia.proto", - "src/protocol/libp2p/schema/bitswap.proto", - ], - &["src"], - ) - .unwrap(); + prost_build::compile_protos( + &[ + "src/schema/keys.proto", + "src/schema/noise.proto", + "src/schema/webrtc.proto", + "src/protocol/libp2p/schema/identify.proto", + "src/protocol/libp2p/schema/kademlia.proto", + "src/protocol/libp2p/schema/bitswap.proto", + ], + &["src"], + ) + .unwrap(); } diff --git a/examples/custom_executor.rs b/examples/custom_executor.rs index 415ec0e0..34ab4748 100644 --- a/examples/custom_executor.rs +++ b/examples/custom_executor.rs @@ -28,11 +28,11 @@ //! Run: `RUST_LOG=info cargo run --example custom_executor` use litep2p::{ - config::ConfigBuilder, - executor::Executor, - protocol::libp2p::ping::{Config as PingConfig, PingEvent}, - transport::tcp::config::Config as TcpConfig, - Litep2p, + config::ConfigBuilder, + executor::Executor, + protocol::libp2p::ping::{Config as PingConfig, PingEvent}, + transport::tcp::config::Config as TcpConfig, + Litep2p, }; use futures::{future::BoxFuture, stream::FuturesUnordered, Stream, StreamExt}; @@ -44,102 +44,112 @@ use std::{future::Future, pin::Pin, sync::Arc}; /// /// Just a wrapper around `FuturesUnordered` which receives the futures over `mpsc::Receiver`. struct TaskExecutor { - rx: Receiver + Send>>>, - futures: FuturesUnordered>, + rx: Receiver + Send>>>, + futures: FuturesUnordered>, } impl TaskExecutor { - /// Create new [`TaskExecutor`]. - fn new() -> (Self, Sender + Send>>>) { - let (tx, rx) = channel(64); - - (Self { rx, futures: FuturesUnordered::new() }, tx) - } - - /// Drive the futures forward and poll the receiver for any new futures. - async fn next(&mut self) { - loop { - tokio::select! { - future = self.rx.recv() => self.futures.push(future.unwrap()), - _ = self.futures.next(), if !self.futures.is_empty() => {} - } - } - } + /// Create new [`TaskExecutor`]. + fn new() -> (Self, Sender + Send>>>) { + let (tx, rx) = channel(64); + + ( + Self { + rx, + futures: FuturesUnordered::new(), + }, + tx, + ) + } + + /// Drive the futures forward and poll the receiver for any new futures. + async fn next(&mut self) { + loop { + tokio::select! { + future = self.rx.recv() => self.futures.push(future.unwrap()), + _ = self.futures.next(), if !self.futures.is_empty() => {} + } + } + } } struct TaskExecutorHandle { - tx: Sender + Send>>>, + tx: Sender + Send>>>, } impl Executor for TaskExecutorHandle { - fn run(&self, future: Pin + Send>>) { - let _ = self.tx.try_send(future); - } + fn run(&self, future: Pin + Send>>) { + let _ = self.tx.try_send(future); + } - fn run_with_name(&self, _: &'static str, future: Pin + Send>>) { - let _ = self.tx.try_send(future); - } + fn run_with_name(&self, _: &'static str, future: Pin + Send>>) { + let _ = self.tx.try_send(future); + } } -fn make_litep2p() -> (Litep2p, TaskExecutor, Box + Send + Unpin>) { - let (executor, sender) = TaskExecutor::new(); - let (ping_config, ping_event_stream) = PingConfig::default(); - - let litep2p = Litep2p::new( - ConfigBuilder::new() - .with_executor(Arc::new(TaskExecutorHandle { tx: sender.clone() })) - .with_tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }) - .with_libp2p_ping(ping_config) - .build(), - ) - .unwrap(); - - (litep2p, executor, ping_event_stream) +fn make_litep2p() -> ( + Litep2p, + TaskExecutor, + Box + Send + Unpin>, +) { + let (executor, sender) = TaskExecutor::new(); + let (ping_config, ping_event_stream) = PingConfig::default(); + + let litep2p = Litep2p::new( + ConfigBuilder::new() + .with_executor(Arc::new(TaskExecutorHandle { tx: sender.clone() })) + .with_tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }) + .with_libp2p_ping(ping_config) + .build(), + ) + .unwrap(); + + (litep2p, executor, ping_event_stream) } #[tokio::main] async fn main() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - // create two identical litep2ps - let (mut litep2p1, mut executor1, mut ping_event_stream1) = make_litep2p(); - let (mut litep2p2, mut executor2, mut ping_event_stream2) = make_litep2p(); - - // dial `litep2p1` - litep2p2 - .dial_address(litep2p1.listen_addresses().next().unwrap().clone()) - .await - .unwrap(); - - tokio::spawn(async move { - loop { - tokio::select! { - _ = executor1.next() => {} - _ = litep2p1.next_event() => {}, - _ = ping_event_stream1.next() => {}, - } - } - }); - - // poll litep2p, task executor and ping event stream all together - // - // since a custom task executor was provided, it's now the user's responsibility - // to actually make sure to poll those futures so that litep2p can make progress - loop { - tokio::select! { - _ = executor2.next() => {} - _ = litep2p2.next_event() => {}, - event = ping_event_stream2.next() => match event { - Some(PingEvent::Ping { peer, ping }) => tracing::info!( - "ping time with {peer:?}: {ping:?}" - ), - _ => {} - } - } - } + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + // create two identical litep2ps + let (mut litep2p1, mut executor1, mut ping_event_stream1) = make_litep2p(); + let (mut litep2p2, mut executor2, mut ping_event_stream2) = make_litep2p(); + + // dial `litep2p1` + litep2p2 + .dial_address(litep2p1.listen_addresses().next().unwrap().clone()) + .await + .unwrap(); + + tokio::spawn(async move { + loop { + tokio::select! { + _ = executor1.next() => {} + _ = litep2p1.next_event() => {}, + _ = ping_event_stream1.next() => {}, + } + } + }); + + // poll litep2p, task executor and ping event stream all together + // + // since a custom task executor was provided, it's now the user's responsibility + // to actually make sure to poll those futures so that litep2p can make progress + loop { + tokio::select! { + _ = executor2.next() => {} + _ = litep2p2.next_event() => {}, + event = ping_event_stream2.next() => match event { + Some(PingEvent::Ping { peer, ping }) => tracing::info!( + "ping time with {peer:?}: {ping:?}" + ), + _ => {} + } + } + } } diff --git a/examples/custom_protocol.rs b/examples/custom_protocol.rs index c9cee2dc..6c5b3fe4 100644 --- a/examples/custom_protocol.rs +++ b/examples/custom_protocol.rs @@ -21,11 +21,11 @@ //! This example demonstrates how to implement a custom protocol for litep2p. use litep2p::{ - codec::ProtocolCodec, - config::ConfigBuilder, - protocol::{Direction, TransportEvent, TransportService, UserProtocol}, - types::protocol::ProtocolName, - Litep2p, PeerId, + codec::ProtocolCodec, + config::ConfigBuilder, + protocol::{Direction, TransportEvent, TransportService, UserProtocol}, + types::protocol::ProtocolName, + Litep2p, PeerId, }; use bytes::{Buf, BufMut, BytesMut}; @@ -39,268 +39,276 @@ use std::collections::{hash_map::Entry, HashMap}; struct CustomCodec; impl Decoder for CustomCodec { - type Item = BytesMut; - type Error = litep2p::Error; + type Item = BytesMut; + type Error = litep2p::Error; - fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { - if src.is_empty() { - return Ok(None); - } + fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { + if src.is_empty() { + return Ok(None); + } - let len = src.get_u8() as usize; - if src.len() >= len { - let mut out = BytesMut::with_capacity(len); - out.put_slice(&src[..len]); - src.advance(len); + let len = src.get_u8() as usize; + if src.len() >= len { + let mut out = BytesMut::with_capacity(len); + out.put_slice(&src[..len]); + src.advance(len); - return Ok(Some(out)); - } + return Ok(Some(out)); + } - Ok(None) - } + Ok(None) + } } impl Encoder for CustomCodec { - type Error = std::io::Error; + type Error = std::io::Error; - fn encode(&mut self, item: BytesMut, dst: &mut BytesMut) -> Result<(), Self::Error> { - if item.len() > u8::MAX as usize { - return Err(std::io::ErrorKind::PermissionDenied.into()); - } + fn encode(&mut self, item: BytesMut, dst: &mut BytesMut) -> Result<(), Self::Error> { + if item.len() > u8::MAX as usize { + return Err(std::io::ErrorKind::PermissionDenied.into()); + } - dst.put_u8(item.len() as u8); - dst.extend(&item); + dst.put_u8(item.len() as u8); + dst.extend(&item); - Ok(()) - } + Ok(()) + } } /// Events received from the protocol. #[derive(Debug)] enum CustomProtocolEvent { - /// Received `message` from `peer`. - MessageReceived { - /// Peer ID. - peer: PeerId, - - /// Message. - message: Vec, - }, + /// Received `message` from `peer`. + MessageReceived { + /// Peer ID. + peer: PeerId, + + /// Message. + message: Vec, + }, } /// Commands sent to the protocol. #[derive(Debug)] enum CustomProtocolCommand { - /// Send `message` to `peer`. - SendMessage { - /// Peer ID. - peer: PeerId, - - /// Message. - message: Vec, - }, + /// Send `message` to `peer`. + SendMessage { + /// Peer ID. + peer: PeerId, + + /// Message. + message: Vec, + }, } /// Handle for communicating with the protocol. #[derive(Debug)] struct CustomProtocolHandle { - cmd_tx: Sender, - event_rx: Receiver, + cmd_tx: Sender, + event_rx: Receiver, } #[derive(Debug)] struct CustomProtocol { - /// Channel for receiving commands from user. - cmd_rx: Receiver, + /// Channel for receiving commands from user. + cmd_rx: Receiver, - /// Channel for sending events to user. - event_tx: Sender, + /// Channel for sending events to user. + event_tx: Sender, - /// Connected peers. - peers: HashMap>>, + /// Connected peers. + peers: HashMap>>, - /// Active inbound substreams. - inbound: FuturesUnordered>)>>, + /// Active inbound substreams. + inbound: FuturesUnordered>)>>, - /// Active outbound substreams. - outbound: FuturesUnordered>>, + /// Active outbound substreams. + outbound: FuturesUnordered>>, } impl CustomProtocol { - /// Create new [`CustomProtocol`]. - pub fn new() -> (Self, CustomProtocolHandle) { - let (event_tx, event_rx) = channel(64); - let (cmd_tx, cmd_rx) = channel(64); - - ( - Self { - cmd_rx, - event_tx, - peers: HashMap::new(), - inbound: FuturesUnordered::new(), - outbound: FuturesUnordered::new(), - }, - CustomProtocolHandle { cmd_tx, event_rx }, - ) - } + /// Create new [`CustomProtocol`]. + pub fn new() -> (Self, CustomProtocolHandle) { + let (event_tx, event_rx) = channel(64); + let (cmd_tx, cmd_rx) = channel(64); + + ( + Self { + cmd_rx, + event_tx, + peers: HashMap::new(), + inbound: FuturesUnordered::new(), + outbound: FuturesUnordered::new(), + }, + CustomProtocolHandle { cmd_tx, event_rx }, + ) + } } #[async_trait::async_trait] impl UserProtocol for CustomProtocol { - fn protocol(&self) -> ProtocolName { - ProtocolName::from("/custom-protocol/1") - } - - // Protocol code is set to `Unspecified` which means that `litep2p` won't provide - // `Sink + Stream` for the protocol and instead only `AsyncWrite + AsyncRead` are provided. - // User must implement their custom codec on top of `Substream` using, e.g., - // `tokio_codec::Framed` if they want to have message framing. - fn codec(&self) -> ProtocolCodec { - ProtocolCodec::Unspecified - } - - /// Start running event loop for [`CustomProtocol`]. - async fn run(mut self: Box, mut service: TransportService) -> litep2p::Result<()> { - loop { - tokio::select! { - cmd = self.cmd_rx.recv() => match cmd { - Some(CustomProtocolCommand::SendMessage { peer, message }) => { - match self.peers.entry(peer) { - // peer doens't exist so dial them and save the message - Entry::Vacant(entry) => match service.dial(&peer) { - Ok(()) => { - entry.insert(Some(message)); - } - Err(error) => { - eprintln!("failed to dial {peer:?}: {error:?}"); - } - } - // peer exists so open a new substream - Entry::Occupied(mut entry) => match service.open_substream(peer) { - Ok(_) => { - entry.insert(Some(message)); - } - Err(error) => { - eprintln!("failed to open substream to {peer:?}: {error:?}"); - } - } - } - } - None => return Err(litep2p::Error::EssentialTaskClosed), - }, - event = service.next() => match event { - // connection established to peer - // - // check if the peer already exist in the protocol with a pending message - // and if yes, open substream to the peer. - Some(TransportEvent::ConnectionEstablished { peer, .. }) => { - match self.peers.get(&peer) { - Some(Some(_)) => { - if let Err(error) = service.open_substream(peer) { - println!("failed to open substream to {peer:?}: {error:?}"); - } - } - Some(None) => {} - None => { - self.peers.insert(peer, None); - } - } - } - // substream opened - // - // for inbound substreams, move the substream to `self.inbound` and poll them for messages - // - // for outbound substreams, move the substream to `self.outbound` and send the saved message to remote peer - Some(TransportEvent::SubstreamOpened { peer, substream, direction, .. }) => { - match direction { - Direction::Inbound => { - self.inbound.push(Box::pin(async move { - (peer, Framed::new(substream, CustomCodec).next().await) - })); - } - Direction::Outbound(_) => { - let message = self.peers.get_mut(&peer).expect("peer to exist").take().unwrap(); - - self.outbound.push(Box::pin(async move { - let mut framed = Framed::new(substream, CustomCodec); - framed.send(BytesMut::from(&message[..])).await.map_err(From::from) - })); - } - } - } - // connection closed, remove all peer context - Some(TransportEvent::ConnectionClosed { peer }) => { - self.peers.remove(&peer); - } - None => return Err(litep2p::Error::EssentialTaskClosed), - _ => {}, - }, - // poll inbound substreams for messages - event = self.inbound.next(), if !self.inbound.is_empty() => match event { - Some((peer, Some(Ok(message)))) => { - self.event_tx.send(CustomProtocolEvent::MessageReceived { - peer, - message: message.into(), - }).await.unwrap(); - } - event => eprintln!("failed to read message from an inbound substream: {event:?}"), - }, - // poll outbound substreams so that they can make progress - _ = self.outbound.next(), if !self.outbound.is_empty() => {} - } - } - } + fn protocol(&self) -> ProtocolName { + ProtocolName::from("/custom-protocol/1") + } + + // Protocol code is set to `Unspecified` which means that `litep2p` won't provide + // `Sink + Stream` for the protocol and instead only `AsyncWrite + AsyncRead` are provided. + // User must implement their custom codec on top of `Substream` using, e.g., + // `tokio_codec::Framed` if they want to have message framing. + fn codec(&self) -> ProtocolCodec { + ProtocolCodec::Unspecified + } + + /// Start running event loop for [`CustomProtocol`]. + async fn run(mut self: Box, mut service: TransportService) -> litep2p::Result<()> { + loop { + tokio::select! { + cmd = self.cmd_rx.recv() => match cmd { + Some(CustomProtocolCommand::SendMessage { peer, message }) => { + match self.peers.entry(peer) { + // peer doens't exist so dial them and save the message + Entry::Vacant(entry) => match service.dial(&peer) { + Ok(()) => { + entry.insert(Some(message)); + } + Err(error) => { + eprintln!("failed to dial {peer:?}: {error:?}"); + } + } + // peer exists so open a new substream + Entry::Occupied(mut entry) => match service.open_substream(peer) { + Ok(_) => { + entry.insert(Some(message)); + } + Err(error) => { + eprintln!("failed to open substream to {peer:?}: {error:?}"); + } + } + } + } + None => return Err(litep2p::Error::EssentialTaskClosed), + }, + event = service.next() => match event { + // connection established to peer + // + // check if the peer already exist in the protocol with a pending message + // and if yes, open substream to the peer. + Some(TransportEvent::ConnectionEstablished { peer, .. }) => { + match self.peers.get(&peer) { + Some(Some(_)) => { + if let Err(error) = service.open_substream(peer) { + println!("failed to open substream to {peer:?}: {error:?}"); + } + } + Some(None) => {} + None => { + self.peers.insert(peer, None); + } + } + } + // substream opened + // + // for inbound substreams, move the substream to `self.inbound` and poll them for messages + // + // for outbound substreams, move the substream to `self.outbound` and send the saved message to remote peer + Some(TransportEvent::SubstreamOpened { peer, substream, direction, .. }) => { + match direction { + Direction::Inbound => { + self.inbound.push(Box::pin(async move { + (peer, Framed::new(substream, CustomCodec).next().await) + })); + } + Direction::Outbound(_) => { + let message = self.peers.get_mut(&peer).expect("peer to exist").take().unwrap(); + + self.outbound.push(Box::pin(async move { + let mut framed = Framed::new(substream, CustomCodec); + framed.send(BytesMut::from(&message[..])).await.map_err(From::from) + })); + } + } + } + // connection closed, remove all peer context + Some(TransportEvent::ConnectionClosed { peer }) => { + self.peers.remove(&peer); + } + None => return Err(litep2p::Error::EssentialTaskClosed), + _ => {}, + }, + // poll inbound substreams for messages + event = self.inbound.next(), if !self.inbound.is_empty() => match event { + Some((peer, Some(Ok(message)))) => { + self.event_tx.send(CustomProtocolEvent::MessageReceived { + peer, + message: message.into(), + }).await.unwrap(); + } + event => eprintln!("failed to read message from an inbound substream: {event:?}"), + }, + // poll outbound substreams so that they can make progress + _ = self.outbound.next(), if !self.outbound.is_empty() => {} + } + } + } } fn make_litep2p() -> (Litep2p, CustomProtocolHandle) { - let (custom_protocol, handle) = CustomProtocol::new(); - - ( - Litep2p::new( - ConfigBuilder::new() - .with_tcp(Default::default()) - .with_user_protocol(Box::new(custom_protocol)) - .build(), - ) - .unwrap(), - handle, - ) + let (custom_protocol, handle) = CustomProtocol::new(); + + ( + Litep2p::new( + ConfigBuilder::new() + .with_tcp(Default::default()) + .with_user_protocol(Box::new(custom_protocol)) + .build(), + ) + .unwrap(), + handle, + ) } #[tokio::main] async fn main() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (mut litep2p1, handle1) = make_litep2p(); - let (mut litep2p2, mut handle2) = make_litep2p(); - - let peer2 = *litep2p2.local_peer_id(); - let listen_address = litep2p2.listen_addresses().next().unwrap().clone(); - litep2p1.add_known_address(peer2, std::iter::once(listen_address)); - - tokio::spawn(async move { - loop { - tokio::select! { - _ = litep2p1.next_event() => {} - _ = litep2p2.next_event() => {} - } - } - }); - - for message in - vec![b"hello, world".to_vec(), b"testing 123".to_vec(), b"goodbye, world".to_vec()] - { - handle1 - .cmd_tx - .send(CustomProtocolCommand::SendMessage { peer: peer2, message }) - .await - .unwrap(); - - let CustomProtocolEvent::MessageReceived { peer, message } = - handle2.event_rx.recv().await.unwrap(); - - println!("received message from {peer:?}: {:?}", std::str::from_utf8(&message)); - } + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (mut litep2p1, handle1) = make_litep2p(); + let (mut litep2p2, mut handle2) = make_litep2p(); + + let peer2 = *litep2p2.local_peer_id(); + let listen_address = litep2p2.listen_addresses().next().unwrap().clone(); + litep2p1.add_known_address(peer2, std::iter::once(listen_address)); + + tokio::spawn(async move { + loop { + tokio::select! { + _ = litep2p1.next_event() => {} + _ = litep2p2.next_event() => {} + } + } + }); + + for message in vec![ + b"hello, world".to_vec(), + b"testing 123".to_vec(), + b"goodbye, world".to_vec(), + ] { + handle1 + .cmd_tx + .send(CustomProtocolCommand::SendMessage { + peer: peer2, + message, + }) + .await + .unwrap(); + + let CustomProtocolEvent::MessageReceived { peer, message } = + handle2.event_rx.recv().await.unwrap(); + + println!( + "received message from {peer:?}: {:?}", + std::str::from_utf8(&message) + ); + } } diff --git a/examples/echo_notification.rs b/examples/echo_notification.rs index 8a492a91..11842257 100644 --- a/examples/echo_notification.rs +++ b/examples/echo_notification.rs @@ -24,14 +24,14 @@ //! Run: `cargo run --example echo_notification` use litep2p::{ - config::ConfigBuilder, - protocol::notification::{ - ConfigBuilder as NotificationConfigBuilder, NotificationEvent, NotificationHandle, - ValidationResult, - }, - transport::quic::config::Config as QuicConfig, - types::protocol::ProtocolName, - Litep2p, PeerId, + config::ConfigBuilder, + protocol::notification::{ + ConfigBuilder as NotificationConfigBuilder, NotificationEvent, NotificationHandle, + ValidationResult, + }, + transport::quic::config::Config as QuicConfig, + types::protocol::ProtocolName, + Litep2p, PeerId, }; use futures::StreamExt; @@ -40,101 +40,101 @@ use std::time::Duration; /// event loop for the client async fn client_event_loop(mut litep2p: Litep2p, mut handle: NotificationHandle, peer: PeerId) { - // open substream to `peer` - // - // if `litep2p` is not connected to `peer` but it has at least one known address, - // `NotifcationHandle::open_substream()` will automatically dial `peer` - handle.open_substream(peer).await.unwrap(); + // open substream to `peer` + // + // if `litep2p` is not connected to `peer` but it has at least one known address, + // `NotifcationHandle::open_substream()` will automatically dial `peer` + handle.open_substream(peer).await.unwrap(); - // wait until the substream is opened - loop { - tokio::select! { - _ = litep2p.next_event() => {} - event = handle.next() => match event.unwrap() { - NotificationEvent::NotificationStreamOpened { .. } => break, - _ => {}, - } - } - } + // wait until the substream is opened + loop { + tokio::select! { + _ = litep2p.next_event() => {} + event = handle.next() => match event.unwrap() { + NotificationEvent::NotificationStreamOpened { .. } => break, + _ => {}, + } + } + } - // after the substream is open, send notification to server and print the response to stdout - loop { - tokio::select! { - _ = litep2p.next_event() => {} - event = handle.next() => match event.unwrap() { - NotificationEvent::NotificationReceived { peer, notification } => { - println!("received response from server ({peer:?}): {notification:?}"); - } - _ => {}, - }, - _ = tokio::time::sleep(Duration::from_secs(3)) => { - handle.send_sync_notification(peer, vec![1, 3, 3, 7]).unwrap(); - } - } - } + // after the substream is open, send notification to server and print the response to stdout + loop { + tokio::select! { + _ = litep2p.next_event() => {} + event = handle.next() => match event.unwrap() { + NotificationEvent::NotificationReceived { peer, notification } => { + println!("received response from server ({peer:?}): {notification:?}"); + } + _ => {}, + }, + _ = tokio::time::sleep(Duration::from_secs(3)) => { + handle.send_sync_notification(peer, vec![1, 3, 3, 7]).unwrap(); + } + } + } } /// event loop for the server async fn server_event_loop(mut litep2p: Litep2p, mut handle: NotificationHandle) { - loop { - tokio::select! { - _ = litep2p.next_event() => {} - event = handle.next() => match event.unwrap() { - NotificationEvent::ValidateSubstream { peer, .. } => { - handle.send_validation_result(peer, ValidationResult::Accept); - } - NotificationEvent::NotificationReceived { peer, notification } => { - handle.send_async_notification(peer, notification.freeze().into()).await.unwrap(); - } - _ => {}, - }, - } - } + loop { + tokio::select! { + _ = litep2p.next_event() => {} + event = handle.next() => match event.unwrap() { + NotificationEvent::ValidateSubstream { peer, .. } => { + handle.send_validation_result(peer, ValidationResult::Accept); + } + NotificationEvent::NotificationReceived { peer, notification } => { + handle.send_async_notification(peer, notification.freeze().into()).await.unwrap(); + } + _ => {}, + }, + } + } } /// helper function for creating `Litep2p` object fn make_litep2p() -> (Litep2p, NotificationHandle) { - // build notification config for the notification protocol - let (echo_config, echo_handle) = NotificationConfigBuilder::new(ProtocolName::from("/echo/1")) - .with_max_size(256) - .with_auto_accept_inbound(true) - .with_handshake(vec![1, 3, 3, 7]) - .build(); + // build notification config for the notification protocol + let (echo_config, echo_handle) = NotificationConfigBuilder::new(ProtocolName::from("/echo/1")) + .with_max_size(256) + .with_auto_accept_inbound(true) + .with_handshake(vec![1, 3, 3, 7]) + .build(); - // build `Litep2p` object and return it + notification handle - ( - Litep2p::new( - ConfigBuilder::new() - .with_quic(QuicConfig { - listen_addresses: vec!["/ip4/127.0.0.1/udp/0/quic-v1".parse().unwrap()], - ..Default::default() - }) - .with_notification_protocol(echo_config) - .build(), - ) - .unwrap(), - echo_handle, - ) + // build `Litep2p` object and return it + notification handle + ( + Litep2p::new( + ConfigBuilder::new() + .with_quic(QuicConfig { + listen_addresses: vec!["/ip4/127.0.0.1/udp/0/quic-v1".parse().unwrap()], + ..Default::default() + }) + .with_notification_protocol(echo_config) + .build(), + ) + .unwrap(), + echo_handle, + ) } #[tokio::main] async fn main() { - // build `Litep2p` objects for both peers - let (mut litep2p1, echo_handle1) = make_litep2p(); - let (litep2p2, echo_handle2) = make_litep2p(); + // build `Litep2p` objects for both peers + let (mut litep2p1, echo_handle1) = make_litep2p(); + let (litep2p2, echo_handle2) = make_litep2p(); - // get the first (and only) listen address for the second peer - // and add it as a known address for `litep2p1` - let listen_address = litep2p2.listen_addresses().next().unwrap().clone(); - let peer = *litep2p2.local_peer_id(); + // get the first (and only) listen address for the second peer + // and add it as a known address for `litep2p1` + let listen_address = litep2p2.listen_addresses().next().unwrap().clone(); + let peer = *litep2p2.local_peer_id(); - litep2p1.add_known_address(peer, vec![listen_address].into_iter()); + litep2p1.add_known_address(peer, vec![listen_address].into_iter()); - // start event loops for client and server - tokio::spawn(client_event_loop(litep2p1, echo_handle1, peer)); - tokio::spawn(server_event_loop(litep2p2, echo_handle2)); + // start event loops for client and server + tokio::spawn(client_event_loop(litep2p1, echo_handle1, peer)); + tokio::spawn(server_event_loop(litep2p2, echo_handle2)); - loop { - tokio::time::sleep(Duration::from_secs(10)).await; - } + loop { + tokio::time::sleep(Duration::from_secs(10)).await; + } } diff --git a/examples/gossiping.rs b/examples/gossiping.rs index e4a6e658..ca01d978 100644 --- a/examples/gossiping.rs +++ b/examples/gossiping.rs @@ -23,244 +23,263 @@ //! Run: `RUST_LOG=gossiping=info cargo run --example gossiping` use litep2p::{ - config::ConfigBuilder, - protocol::notification::{ - Config as NotificationConfig, ConfigBuilder as NotificationConfigBuilder, - NotificationEvent, NotificationHandle, ValidationResult, - }, - types::protocol::ProtocolName, - Litep2p, PeerId, + config::ConfigBuilder, + protocol::notification::{ + Config as NotificationConfig, ConfigBuilder as NotificationConfigBuilder, + NotificationEvent, NotificationHandle, ValidationResult, + }, + types::protocol::ProtocolName, + Litep2p, PeerId, }; use futures::StreamExt; use tokio::sync::mpsc::{channel, Receiver, Sender}; use std::{ - collections::{HashMap, HashSet}, - time::Duration, + collections::{HashMap, HashSet}, + time::Duration, }; /// Dummy transaction. #[derive(Debug, Hash, PartialEq, Eq, Clone)] struct Transaction { - tx: Vec, + tx: Vec, } /// Handle which allows communicating with [`TransactionProtocol`]. struct TransactionProtocolHandle { - tx: Sender, + tx: Sender, } impl TransactionProtocolHandle { - /// Create new [`TransactionProtocolHandle`]. - fn new() -> (Self, Receiver) { - let (tx, rx) = channel(64); - - (Self { tx }, rx) - } - - /// Announce transaction by sending it to the [`TransactionProtocol`] which will send - /// it to all peers who don't have it yet. - async fn announce_transaction(&self, tx: Transaction) { - self.tx.send(tx).await.unwrap(); - } + /// Create new [`TransactionProtocolHandle`]. + fn new() -> (Self, Receiver) { + let (tx, rx) = channel(64); + + (Self { tx }, rx) + } + + /// Announce transaction by sending it to the [`TransactionProtocol`] which will send + /// it to all peers who don't have it yet. + async fn announce_transaction(&self, tx: Transaction) { + self.tx.send(tx).await.unwrap(); + } } /// Transaction protocol. struct TransactionProtocol { - /// Notification handle used to send and receive notifications. - tx_handle: NotificationHandle, + /// Notification handle used to send and receive notifications. + tx_handle: NotificationHandle, - /// Handle for receiving transactions from user that should be sent to connected peers. - rx: Receiver, + /// Handle for receiving transactions from user that should be sent to connected peers. + rx: Receiver, - /// Connected peers. - peers: HashMap>, + /// Connected peers. + peers: HashMap>, - /// Seen transactions. - seen: HashSet, + /// Seen transactions. + seen: HashSet, } impl TransactionProtocol { - fn new() -> (Self, NotificationConfig, TransactionProtocolHandle) { - let (tx_config, tx_handle) = Self::init_tx_announce(); - let (handle, rx) = TransactionProtocolHandle::new(); - - (Self { tx_handle, rx, peers: HashMap::new(), seen: HashSet::new() }, tx_config, handle) - } - - /// Initialize notification protocol for transactions. - fn init_tx_announce() -> (NotificationConfig, NotificationHandle) { - NotificationConfigBuilder::new(ProtocolName::from("/notif/tx/1")) - .with_max_size(1024usize) - .with_handshake(vec![1, 2, 3, 4]) - .build() - } - - /// Poll next transaction from the protocol. - async fn next(&mut self) -> Option<(PeerId, Transaction)> { - loop { - tokio::select! { - event = self.tx_handle.next() => match event? { - NotificationEvent::ValidateSubstream { peer, .. } => { - tracing::info!("inbound substream received from {peer}"); - - self.tx_handle.send_validation_result(peer, ValidationResult::Accept); - } - NotificationEvent::NotificationStreamOpened { peer, .. } => { - tracing::info!("substream opened for {peer}"); - - self.peers.insert(peer, HashSet::new()); - } - NotificationEvent::NotificationStreamClosed { peer } => { - tracing::info!("substream closed for {peer}"); - - self.peers.remove(&peer); - } - NotificationEvent::NotificationReceived { peer, notification } => { - tracing::info!("transaction received from {peer}: {notification:?}"); - - // send transaction to all peers who don't have it yet - let notification = notification.freeze(); - - for (connected, txs) in &mut self.peers { - let not_seen = txs.insert(Transaction { tx: notification.clone().into() }); - if connected != &peer && not_seen { - self.tx_handle.send_sync_notification( - *connected, - notification.clone().into(), - ).unwrap(); - } - } - - if self.seen.insert(Transaction { tx: notification.clone().into() }) { - return Some((peer, Transaction { tx: notification.clone().into() })) - } - } - _ => {} - }, - tx = self.rx.recv() => match tx { - None => return None, - Some(transaction) => { - // send transaction to all peers who don't have it yet - self.seen.insert(transaction.clone()); - - for (peer, txs) in &mut self.peers { - if txs.insert(transaction.clone()) { - self.tx_handle.send_sync_notification( - *peer, - transaction.tx.clone(), - ).unwrap(); - } - } - } - } - } - } - } - - /// Start event loop for [`TransactionProtocol`]. - async fn run(mut self) { - loop { - match self.next().await { - Some((peer, tx)) => { - tracing::info!("received transaction from {peer}: {tx:?}"); - }, - None => return, - } - } - } + fn new() -> (Self, NotificationConfig, TransactionProtocolHandle) { + let (tx_config, tx_handle) = Self::init_tx_announce(); + let (handle, rx) = TransactionProtocolHandle::new(); + + ( + Self { + tx_handle, + rx, + peers: HashMap::new(), + seen: HashSet::new(), + }, + tx_config, + handle, + ) + } + + /// Initialize notification protocol for transactions. + fn init_tx_announce() -> (NotificationConfig, NotificationHandle) { + NotificationConfigBuilder::new(ProtocolName::from("/notif/tx/1")) + .with_max_size(1024usize) + .with_handshake(vec![1, 2, 3, 4]) + .build() + } + + /// Poll next transaction from the protocol. + async fn next(&mut self) -> Option<(PeerId, Transaction)> { + loop { + tokio::select! { + event = self.tx_handle.next() => match event? { + NotificationEvent::ValidateSubstream { peer, .. } => { + tracing::info!("inbound substream received from {peer}"); + + self.tx_handle.send_validation_result(peer, ValidationResult::Accept); + } + NotificationEvent::NotificationStreamOpened { peer, .. } => { + tracing::info!("substream opened for {peer}"); + + self.peers.insert(peer, HashSet::new()); + } + NotificationEvent::NotificationStreamClosed { peer } => { + tracing::info!("substream closed for {peer}"); + + self.peers.remove(&peer); + } + NotificationEvent::NotificationReceived { peer, notification } => { + tracing::info!("transaction received from {peer}: {notification:?}"); + + // send transaction to all peers who don't have it yet + let notification = notification.freeze(); + + for (connected, txs) in &mut self.peers { + let not_seen = txs.insert(Transaction { tx: notification.clone().into() }); + if connected != &peer && not_seen { + self.tx_handle.send_sync_notification( + *connected, + notification.clone().into(), + ).unwrap(); + } + } + + if self.seen.insert(Transaction { tx: notification.clone().into() }) { + return Some((peer, Transaction { tx: notification.clone().into() })) + } + } + _ => {} + }, + tx = self.rx.recv() => match tx { + None => return None, + Some(transaction) => { + // send transaction to all peers who don't have it yet + self.seen.insert(transaction.clone()); + + for (peer, txs) in &mut self.peers { + if txs.insert(transaction.clone()) { + self.tx_handle.send_sync_notification( + *peer, + transaction.tx.clone(), + ).unwrap(); + } + } + } + } + } + } + } + + /// Start event loop for [`TransactionProtocol`]. + async fn run(mut self) { + loop { + match self.next().await { + Some((peer, tx)) => { + tracing::info!("received transaction from {peer}: {tx:?}"); + } + None => return, + } + } + } } async fn await_substreams( - tx1: &mut TransactionProtocol, - tx2: &mut TransactionProtocol, - tx3: &mut TransactionProtocol, - tx4: &mut TransactionProtocol, + tx1: &mut TransactionProtocol, + tx2: &mut TransactionProtocol, + tx3: &mut TransactionProtocol, + tx4: &mut TransactionProtocol, ) { - loop { - tokio::select! { - _ = tx1.next() => {} - _ = tx2.next() => {} - _ = tx3.next() => {} - _ = tx4.next() => {} - _ = tokio::time::sleep(Duration::from_secs(2)) => { - if tx1.peers.len() == 1 && tx2.peers.len() == 3 && tx3.peers.len() == 1 && tx4.peers.len() == 1 { - return - } - } - } - } + loop { + tokio::select! { + _ = tx1.next() => {} + _ = tx2.next() => {} + _ = tx3.next() => {} + _ = tx4.next() => {} + _ = tokio::time::sleep(Duration::from_secs(2)) => { + if tx1.peers.len() == 1 && tx2.peers.len() == 3 && tx3.peers.len() == 1 && tx4.peers.len() == 1 { + return + } + } + } + } } /// Initialize peer with transaction protocol enabled. fn tx_peer() -> (Litep2p, TransactionProtocol, TransactionProtocolHandle) { - // initialize `TransctionProtocol` - let (tx, tx_announce_config, tx_handle) = TransactionProtocol::new(); + // initialize `TransctionProtocol` + let (tx, tx_announce_config, tx_handle) = TransactionProtocol::new(); - // build `Litep2pConfig` - let config = ConfigBuilder::new() - .with_tcp(Default::default()) - .with_notification_protocol(tx_announce_config) - .build(); + // build `Litep2pConfig` + let config = ConfigBuilder::new() + .with_tcp(Default::default()) + .with_notification_protocol(tx_announce_config) + .build(); - // create `Litep2p` object and start internal protocol handlers and the QUIC transport - let litep2p = Litep2p::new(config).unwrap(); + // create `Litep2p` object and start internal protocol handlers and the QUIC transport + let litep2p = Litep2p::new(config).unwrap(); - (litep2p, tx, tx_handle) + (litep2p, tx, tx_handle) } #[tokio::main] async fn main() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (mut litep2p1, mut tx1, tx_handle1) = tx_peer(); - let (mut litep2p2, mut tx2, _tx_handle2) = tx_peer(); - let (mut litep2p3, mut tx3, tx_handle3) = tx_peer(); - let (mut litep2p4, mut tx4, tx_handle4) = tx_peer(); - - tracing::info!("litep2p1: {}", litep2p1.local_peer_id()); - tracing::info!("litep2p2: {}", litep2p2.local_peer_id()); - tracing::info!("litep2p3: {}", litep2p3.local_peer_id()); - tracing::info!("litep2p4: {}", litep2p4.local_peer_id()); - - // establish connection to litep2p for all other litep2ps - let peer2 = *litep2p2.local_peer_id(); - let listen_address = litep2p2.listen_addresses().next().unwrap().clone(); - - litep2p1.add_known_address(peer2, vec![listen_address.clone()].into_iter()); - litep2p3.add_known_address(peer2, vec![listen_address.clone()].into_iter()); - litep2p4.add_known_address(peer2, vec![listen_address].into_iter()); - - tokio::spawn(async move { while let Some(_) = litep2p1.next_event().await {} }); - tokio::spawn(async move { while let Some(_) = litep2p2.next_event().await {} }); - tokio::spawn(async move { while let Some(_) = litep2p3.next_event().await {} }); - tokio::spawn(async move { while let Some(_) = litep2p4.next_event().await {} }); - - // open substreams - tx1.tx_handle.open_substream(peer2).await.unwrap(); - tx3.tx_handle.open_substream(peer2).await.unwrap(); - tx4.tx_handle.open_substream(peer2).await.unwrap(); - - // wait a moment for substream to open and start `TransactionProtocol` event loops - await_substreams(&mut tx1, &mut tx2, &mut tx3, &mut tx4).await; - - tokio::spawn(tx1.run()); - tokio::spawn(tx2.run()); - tokio::spawn(tx3.run()); - tokio::spawn(tx4.run()); - - // annouce three transactions over three different handles - tx_handle1.announce_transaction(Transaction { tx: vec![1, 2, 3, 4] }).await; - - tx_handle3.announce_transaction(Transaction { tx: vec![1, 3, 3, 7] }).await; - - tx_handle4 - .announce_transaction(Transaction { tx: vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9] }) - .await; - - // allow protocols to process announced transactions before exiting - tokio::time::sleep(Duration::from_secs(3)).await; + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (mut litep2p1, mut tx1, tx_handle1) = tx_peer(); + let (mut litep2p2, mut tx2, _tx_handle2) = tx_peer(); + let (mut litep2p3, mut tx3, tx_handle3) = tx_peer(); + let (mut litep2p4, mut tx4, tx_handle4) = tx_peer(); + + tracing::info!("litep2p1: {}", litep2p1.local_peer_id()); + tracing::info!("litep2p2: {}", litep2p2.local_peer_id()); + tracing::info!("litep2p3: {}", litep2p3.local_peer_id()); + tracing::info!("litep2p4: {}", litep2p4.local_peer_id()); + + // establish connection to litep2p for all other litep2ps + let peer2 = *litep2p2.local_peer_id(); + let listen_address = litep2p2.listen_addresses().next().unwrap().clone(); + + litep2p1.add_known_address(peer2, vec![listen_address.clone()].into_iter()); + litep2p3.add_known_address(peer2, vec![listen_address.clone()].into_iter()); + litep2p4.add_known_address(peer2, vec![listen_address].into_iter()); + + tokio::spawn(async move { while let Some(_) = litep2p1.next_event().await {} }); + tokio::spawn(async move { while let Some(_) = litep2p2.next_event().await {} }); + tokio::spawn(async move { while let Some(_) = litep2p3.next_event().await {} }); + tokio::spawn(async move { while let Some(_) = litep2p4.next_event().await {} }); + + // open substreams + tx1.tx_handle.open_substream(peer2).await.unwrap(); + tx3.tx_handle.open_substream(peer2).await.unwrap(); + tx4.tx_handle.open_substream(peer2).await.unwrap(); + + // wait a moment for substream to open and start `TransactionProtocol` event loops + await_substreams(&mut tx1, &mut tx2, &mut tx3, &mut tx4).await; + + tokio::spawn(tx1.run()); + tokio::spawn(tx2.run()); + tokio::spawn(tx3.run()); + tokio::spawn(tx4.run()); + + // annouce three transactions over three different handles + tx_handle1 + .announce_transaction(Transaction { + tx: vec![1, 2, 3, 4], + }) + .await; + + tx_handle3 + .announce_transaction(Transaction { + tx: vec![1, 3, 3, 7], + }) + .await; + + tx_handle4 + .announce_transaction(Transaction { + tx: vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9], + }) + .await; + + // allow protocols to process announced transactions before exiting + tokio::time::sleep(Duration::from_secs(3)).await; } diff --git a/examples/mnds_and_ping.rs b/examples/mnds_and_ping.rs index cc38740f..246e553b 100644 --- a/examples/mnds_and_ping.rs +++ b/examples/mnds_and_ping.rs @@ -22,12 +22,12 @@ //! calculating their PING time. use litep2p::{ - config::ConfigBuilder, - protocol::{ - libp2p::ping::{Config as PingConfig, PingEvent}, - mdns::{Config as MdnsConfig, MdnsEvent}, - }, - Litep2p, + config::ConfigBuilder, + protocol::{ + libp2p::ping::{Config as PingConfig, PingEvent}, + mdns::{Config as MdnsConfig, MdnsEvent}, + }, + Litep2p, }; use futures::{Stream, StreamExt}; @@ -37,60 +37,72 @@ use std::time::Duration; /// simple event loop which discovers peers over mDNS, /// establishes a connection to them and calculates the PING time async fn peer_event_loop( - mut litep2p: Litep2p, - mut ping_event_stream: Box + Send + Unpin>, - mut mdns_event_stream: Box + Send + Unpin>, + mut litep2p: Litep2p, + mut ping_event_stream: Box + Send + Unpin>, + mut mdns_event_stream: Box + Send + Unpin>, ) { - loop { - tokio::select! { - _ = litep2p.next_event() => {} - event = ping_event_stream.next() => match event.unwrap() { - PingEvent::Ping { peer, ping } => { - println!("ping received from {peer:?}: {ping:?}"); - } - }, - event = mdns_event_stream.next() => match event.unwrap() { - MdnsEvent::Discovered(addresses) => { - litep2p.dial_address(addresses[0].clone()).await.unwrap(); - } - } - } - } + loop { + tokio::select! { + _ = litep2p.next_event() => {} + event = ping_event_stream.next() => match event.unwrap() { + PingEvent::Ping { peer, ping } => { + println!("ping received from {peer:?}: {ping:?}"); + } + }, + event = mdns_event_stream.next() => match event.unwrap() { + MdnsEvent::Discovered(addresses) => { + litep2p.dial_address(addresses[0].clone()).await.unwrap(); + } + } + } + } } /// helper function for creating `Litep2p` object fn make_litep2p() -> ( - Litep2p, - Box + Send + Unpin>, - Box + Send + Unpin>, + Litep2p, + Box + Send + Unpin>, + Box + Send + Unpin>, ) { - // initialize IPFS ping and mDNS - let (ping_config, ping_event_stream) = PingConfig::default(); - let (mdns_config, mdns_event_stream) = MdnsConfig::new(Duration::from_secs(30)); + // initialize IPFS ping and mDNS + let (ping_config, ping_event_stream) = PingConfig::default(); + let (mdns_config, mdns_event_stream) = MdnsConfig::new(Duration::from_secs(30)); - // build `Litep2p`, passing in configurations for IPFS and mDNS - let litep2p_config = ConfigBuilder::new() - // `litep2p` will bind to `/ip6/::1/tcp/0` by default - .with_tcp(Default::default()) - .with_libp2p_ping(ping_config) - .with_mdns(mdns_config) - .build(); + // build `Litep2p`, passing in configurations for IPFS and mDNS + let litep2p_config = ConfigBuilder::new() + // `litep2p` will bind to `/ip6/::1/tcp/0` by default + .with_tcp(Default::default()) + .with_libp2p_ping(ping_config) + .with_mdns(mdns_config) + .build(); - // build `Litep2p` and return it + event streams - (Litep2p::new(litep2p_config).unwrap(), ping_event_stream, mdns_event_stream) + // build `Litep2p` and return it + event streams + ( + Litep2p::new(litep2p_config).unwrap(), + ping_event_stream, + mdns_event_stream, + ) } #[tokio::main] async fn main() { - // initialize `Litep2p` objects for the peers - let (litep2p1, ping_event_stream1, mdns_event_stream1) = make_litep2p(); - let (litep2p2, ping_event_stream2, mdns_event_stream2) = make_litep2p(); + // initialize `Litep2p` objects for the peers + let (litep2p1, ping_event_stream1, mdns_event_stream1) = make_litep2p(); + let (litep2p2, ping_event_stream2, mdns_event_stream2) = make_litep2p(); - // starts separate tasks for the first and second peer - tokio::spawn(peer_event_loop(litep2p1, ping_event_stream1, mdns_event_stream1)); - tokio::spawn(peer_event_loop(litep2p2, ping_event_stream2, mdns_event_stream2)); + // starts separate tasks for the first and second peer + tokio::spawn(peer_event_loop( + litep2p1, + ping_event_stream1, + mdns_event_stream1, + )); + tokio::spawn(peer_event_loop( + litep2p2, + ping_event_stream2, + mdns_event_stream2, + )); - loop { - tokio::time::sleep(Duration::from_secs(10)).await; - } + loop { + tokio::time::sleep(Duration::from_secs(10)).await; + } } diff --git a/examples/syncing.rs b/examples/syncing.rs index faa72ee0..9d79f9c9 100644 --- a/examples/syncing.rs +++ b/examples/syncing.rs @@ -22,112 +22,121 @@ //! to implement, e.g, a syncing protocol using notification and request-response protocols use litep2p::{ - config::ConfigBuilder, - protocol::{ - notification::{ - Config as NotificationConfig, ConfigBuilder as NotificationConfigBuilder, - NotificationHandle, - }, - request_response::{ - Config as RequestResponseConfig, ConfigBuilder as RequestResponseConfigBuilder, - RequestResponseHandle, - }, - }, - transport::quic::config::Config as QuicConfig, - types::protocol::ProtocolName, - Litep2p, + config::ConfigBuilder, + protocol::{ + notification::{ + Config as NotificationConfig, ConfigBuilder as NotificationConfigBuilder, + NotificationHandle, + }, + request_response::{ + Config as RequestResponseConfig, ConfigBuilder as RequestResponseConfigBuilder, + RequestResponseHandle, + }, + }, + transport::quic::config::Config as QuicConfig, + types::protocol::ProtocolName, + Litep2p, }; use futures::StreamExt; /// Object responsible for syncing the blockchain. struct SyncingEngine { - /// Notification handle used to send and receive notifications. - block_announce_handle: NotificationHandle, + /// Notification handle used to send and receive notifications. + block_announce_handle: NotificationHandle, - /// Request-response handle used to send and receive block requests/responses. - block_sync_handle: RequestResponseHandle, + /// Request-response handle used to send and receive block requests/responses. + block_sync_handle: RequestResponseHandle, - /// Request-response handle used to send and receive state requests/responses. - state_sync_handle: RequestResponseHandle, + /// Request-response handle used to send and receive state requests/responses. + state_sync_handle: RequestResponseHandle, } impl SyncingEngine { - /// Create new [`SyncingEngine`]. - fn new() -> (Self, NotificationConfig, RequestResponseConfig, RequestResponseConfig) { - let (block_announce_config, block_announce_handle) = Self::init_block_announce(); - let (block_sync_config, block_sync_handle) = Self::init_block_sync(); - let (state_sync_config, state_sync_handle) = Self::init_state_sync(); - - ( - Self { block_announce_handle, block_sync_handle, state_sync_handle }, - block_announce_config, - block_sync_config, - state_sync_config, - ) - } - - /// Initialize notification protocol for block announcements - fn init_block_announce() -> (NotificationConfig, NotificationHandle) { - NotificationConfigBuilder::new(ProtocolName::from("/notif/block-announce/1")) - .with_max_size(1024usize) - .with_handshake(vec![1, 2, 3, 4]) - .build() - } - - /// Initialize request-response protocol for block syncing. - fn init_block_sync() -> (RequestResponseConfig, RequestResponseHandle) { - RequestResponseConfigBuilder::new(ProtocolName::from("/sync/block/1")) - .with_max_size(1024 * 1024) - .build() - } - - /// Initialize request-response protocol for state syncing. - fn init_state_sync() -> (RequestResponseConfig, RequestResponseHandle) { - RequestResponseConfigBuilder::new(ProtocolName::from("/sync/state/1")) - .with_max_size(1024 * 1024) - .build() - } - - /// Start event loop for [`SyncingEngine`]. - async fn run(mut self) { - loop { - tokio::select! { - _ = self.block_announce_handle.next() => {} - _ = self.block_sync_handle.next() => {} - _ = self.state_sync_handle.next() => {} - } - } - } + /// Create new [`SyncingEngine`]. + fn new() -> ( + Self, + NotificationConfig, + RequestResponseConfig, + RequestResponseConfig, + ) { + let (block_announce_config, block_announce_handle) = Self::init_block_announce(); + let (block_sync_config, block_sync_handle) = Self::init_block_sync(); + let (state_sync_config, state_sync_handle) = Self::init_state_sync(); + + ( + Self { + block_announce_handle, + block_sync_handle, + state_sync_handle, + }, + block_announce_config, + block_sync_config, + state_sync_config, + ) + } + + /// Initialize notification protocol for block announcements + fn init_block_announce() -> (NotificationConfig, NotificationHandle) { + NotificationConfigBuilder::new(ProtocolName::from("/notif/block-announce/1")) + .with_max_size(1024usize) + .with_handshake(vec![1, 2, 3, 4]) + .build() + } + + /// Initialize request-response protocol for block syncing. + fn init_block_sync() -> (RequestResponseConfig, RequestResponseHandle) { + RequestResponseConfigBuilder::new(ProtocolName::from("/sync/block/1")) + .with_max_size(1024 * 1024) + .build() + } + + /// Initialize request-response protocol for state syncing. + fn init_state_sync() -> (RequestResponseConfig, RequestResponseHandle) { + RequestResponseConfigBuilder::new(ProtocolName::from("/sync/state/1")) + .with_max_size(1024 * 1024) + .build() + } + + /// Start event loop for [`SyncingEngine`]. + async fn run(mut self) { + loop { + tokio::select! { + _ = self.block_announce_handle.next() => {} + _ = self.block_sync_handle.next() => {} + _ = self.state_sync_handle.next() => {} + } + } + } } #[tokio::main] async fn main() { - // create `SyncingEngine` and get configs for the protocols that it will use. - let (engine, block_announce_config, block_sync_config, state_sync_config) = - SyncingEngine::new(); - - // build `Litep2pConfig` - let config = ConfigBuilder::new() - .with_quic(QuicConfig { - listen_addresses: vec!["/ip4/127.0.0.1/udp/0/quic-v1".parse().unwrap()], - ..Default::default() - }) - .with_notification_protocol(block_announce_config) - .with_request_response_protocol(block_sync_config) - .with_request_response_protocol(state_sync_config) - .build(); - - // create `Litep2p` object and start internal protocol handlers and the QUIC transport - let mut litep2p = Litep2p::new(config).unwrap(); - - // spawn `SyncingEngine` in the background - tokio::spawn(engine.run()); - - // poll `litep2p` to allow connection-related activity to make progress - loop { - match litep2p.next_event().await.unwrap() { - _ => {}, - } - } + // create `SyncingEngine` and get configs for the protocols that it will use. + let (engine, block_announce_config, block_sync_config, state_sync_config) = + SyncingEngine::new(); + + // build `Litep2pConfig` + let config = ConfigBuilder::new() + .with_quic(QuicConfig { + listen_addresses: vec!["/ip4/127.0.0.1/udp/0/quic-v1".parse().unwrap()], + ..Default::default() + }) + .with_notification_protocol(block_announce_config) + .with_request_response_protocol(block_sync_config) + .with_request_response_protocol(state_sync_config) + .build(); + + // create `Litep2p` object and start internal protocol handlers and the QUIC transport + let mut litep2p = Litep2p::new(config).unwrap(); + + // spawn `SyncingEngine` in the background + tokio::spawn(engine.run()); + + // poll `litep2p` to allow connection-related activity to make progress + loop { + match litep2p.next_event().await.unwrap() { + _ => {} + } + } } diff --git a/rustfmt.toml b/rustfmt.toml index 2778ce43..30af9121 100644 --- a/rustfmt.toml +++ b/rustfmt.toml @@ -1,8 +1,6 @@ # Basic edition = "2021" -hard_tabs = true max_width = 100 -use_small_heuristics = "Max" # Imports imports_granularity = "Crate" @@ -14,14 +12,8 @@ newline_style = "Unix" # Misc chain_width = 80 spaces_around_ranges = false -binop_separator = "Back" -reorder_impl_items = false -match_arm_leading_pipes = "Preserve" match_arm_blocks = false -match_block_trailing_comma = true trailing_comma = "Vertical" -trailing_semicolon = false -use_field_init_shorthand = true # Format comments comment_width = 100 diff --git a/src/bandwidth.rs b/src/bandwidth.rs index aa28dfbd..4895ad20 100644 --- a/src/bandwidth.rs +++ b/src/bandwidth.rs @@ -21,18 +21,18 @@ //! Bandwidth sinks for metering inbound/outbound bytes. use std::sync::{ - atomic::{AtomicUsize, Ordering}, - Arc, + atomic::{AtomicUsize, Ordering}, + Arc, }; /// Inner bandwidth sink #[derive(Debug)] struct InnerBandwidthSink { - /// Number of inbound bytes. - inbound: AtomicUsize, + /// Number of inbound bytes. + inbound: AtomicUsize, - /// Number of outbound bytes. - outbound: AtomicUsize, + /// Number of outbound bytes. + outbound: AtomicUsize, } /// Bandwidth sink which provides metering for inbound/outbound byte usage. @@ -44,47 +44,47 @@ struct InnerBandwidthSink { pub struct BandwidthSink(Arc); impl BandwidthSink { - /// Create new [`BandwidthSink`]. - pub(crate) fn new() -> Self { - Self(Arc::new(InnerBandwidthSink { - inbound: AtomicUsize::new(0usize), - outbound: AtomicUsize::new(0usize), - })) - } + /// Create new [`BandwidthSink`]. + pub(crate) fn new() -> Self { + Self(Arc::new(InnerBandwidthSink { + inbound: AtomicUsize::new(0usize), + outbound: AtomicUsize::new(0usize), + })) + } - /// Increase the amount of inbound bytes. - pub(crate) fn increase_inbound(&self, bytes: usize) { - let _ = self.0.inbound.fetch_add(bytes, Ordering::Relaxed); - } + /// Increase the amount of inbound bytes. + pub(crate) fn increase_inbound(&self, bytes: usize) { + let _ = self.0.inbound.fetch_add(bytes, Ordering::Relaxed); + } - /// Increse the amount of outbound bytes. - pub(crate) fn increase_outbound(&self, bytes: usize) { - let _ = self.0.outbound.fetch_add(bytes, Ordering::Relaxed); - } + /// Increse the amount of outbound bytes. + pub(crate) fn increase_outbound(&self, bytes: usize) { + let _ = self.0.outbound.fetch_add(bytes, Ordering::Relaxed); + } - /// Get total the number of bytes received. - pub fn inbound(&self) -> usize { - self.0.inbound.load(Ordering::Relaxed) - } + /// Get total the number of bytes received. + pub fn inbound(&self) -> usize { + self.0.inbound.load(Ordering::Relaxed) + } - /// Get total the nubmer of bytes sent. - pub fn outbound(&self) -> usize { - self.0.outbound.load(Ordering::Relaxed) - } + /// Get total the nubmer of bytes sent. + pub fn outbound(&self) -> usize { + self.0.outbound.load(Ordering::Relaxed) + } } #[cfg(test)] mod tests { - use super::*; + use super::*; - #[test] - fn verify_bandwidth() { - let sink = BandwidthSink::new(); + #[test] + fn verify_bandwidth() { + let sink = BandwidthSink::new(); - sink.increase_inbound(1337usize); - sink.increase_outbound(1338usize); + sink.increase_inbound(1337usize); + sink.increase_outbound(1338usize); - assert_eq!(sink.inbound(), 1337usize); - assert_eq!(sink.outbound(), 1338usize); - } + assert_eq!(sink.inbound(), 1337usize); + assert_eq!(sink.outbound(), 1338usize); + } } diff --git a/src/codec/identity.rs b/src/codec/identity.rs index c4ab9832..10ed8fa6 100644 --- a/src/codec/identity.rs +++ b/src/codec/identity.rs @@ -27,97 +27,100 @@ use tokio_util::codec::{Decoder, Encoder}; /// Identity codec. pub struct Identity { - payload_len: usize, + payload_len: usize, } impl Identity { - /// Create new [`Identity`] codec. - pub fn new(payload_len: usize) -> Self { - assert!(payload_len != 0); - - Self { payload_len } - } - - /// Encode `payload` using identity codec. - pub fn encode>(payload: T) -> crate::Result> { - let payload: Bytes = payload.into(); - Ok(payload.into()) - } + /// Create new [`Identity`] codec. + pub fn new(payload_len: usize) -> Self { + assert!(payload_len != 0); + + Self { payload_len } + } + + /// Encode `payload` using identity codec. + pub fn encode>(payload: T) -> crate::Result> { + let payload: Bytes = payload.into(); + Ok(payload.into()) + } } impl Decoder for Identity { - type Item = BytesMut; - type Error = Error; + type Item = BytesMut; + type Error = Error; - fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { - if src.is_empty() { - return Ok(None); - } + fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { + if src.is_empty() { + return Ok(None); + } - Ok(Some(src.split_to(self.payload_len))) - } + Ok(Some(src.split_to(self.payload_len))) + } } impl Encoder for Identity { - type Error = Error; + type Error = Error; - fn encode(&mut self, item: Bytes, dst: &mut bytes::BytesMut) -> Result<(), Self::Error> { - if item.len() > self.payload_len || item.is_empty() { - return Err(Error::InvalidData); - } + fn encode(&mut self, item: Bytes, dst: &mut bytes::BytesMut) -> Result<(), Self::Error> { + if item.len() > self.payload_len || item.is_empty() { + return Err(Error::InvalidData); + } - dst.put_slice(item.as_ref()); - Ok(()) - } + dst.put_slice(item.as_ref()); + Ok(()) + } } #[cfg(test)] mod tests { - use super::*; - - #[test] - fn encoding_works() { - let mut codec = Identity::new(48); - let mut out_buf = BytesMut::with_capacity(32); - let bytes = Bytes::from(vec![0u8; 48]); - - assert!(codec.encode(bytes.clone(), &mut out_buf).is_ok()); - assert_eq!(out_buf.freeze(), bytes); - } - - #[test] - fn decoding_works() { - let mut codec = Identity::new(64); - let bytes = vec![3u8; 64]; - let copy = bytes.clone(); - let mut bytes = BytesMut::from(&bytes[..]); - - let decoded = codec.decode(&mut bytes).unwrap().unwrap(); - assert_eq!(decoded, copy); - } - - #[test] - fn empty_encode() { - let mut codec = Identity::new(32); - let mut out_buf = BytesMut::with_capacity(32); - assert!(codec.encode(Bytes::new(), &mut out_buf).is_err()); - } - - #[test] - fn decode_encode() { - let mut codec = Identity::new(32); - assert!(codec.decode(&mut BytesMut::new()).unwrap().is_none()); - } - - #[test] - fn direct_encoding_works() { - assert_eq!(Identity::encode(vec![1, 3, 3, 7]).unwrap(), vec![1, 3, 3, 7]); - } - - #[test] - #[should_panic] - #[cfg(debug_assertions)] - fn empty_identity_codec() { - let _codec = Identity::new(0usize); - } + use super::*; + + #[test] + fn encoding_works() { + let mut codec = Identity::new(48); + let mut out_buf = BytesMut::with_capacity(32); + let bytes = Bytes::from(vec![0u8; 48]); + + assert!(codec.encode(bytes.clone(), &mut out_buf).is_ok()); + assert_eq!(out_buf.freeze(), bytes); + } + + #[test] + fn decoding_works() { + let mut codec = Identity::new(64); + let bytes = vec![3u8; 64]; + let copy = bytes.clone(); + let mut bytes = BytesMut::from(&bytes[..]); + + let decoded = codec.decode(&mut bytes).unwrap().unwrap(); + assert_eq!(decoded, copy); + } + + #[test] + fn empty_encode() { + let mut codec = Identity::new(32); + let mut out_buf = BytesMut::with_capacity(32); + assert!(codec.encode(Bytes::new(), &mut out_buf).is_err()); + } + + #[test] + fn decode_encode() { + let mut codec = Identity::new(32); + assert!(codec.decode(&mut BytesMut::new()).unwrap().is_none()); + } + + #[test] + fn direct_encoding_works() { + assert_eq!( + Identity::encode(vec![1, 3, 3, 7]).unwrap(), + vec![1, 3, 3, 7] + ); + } + + #[test] + #[should_panic] + #[cfg(debug_assertions)] + fn empty_identity_codec() { + let _codec = Identity::new(0usize); + } } diff --git a/src/codec/mod.rs b/src/codec/mod.rs index d9e41129..3604c023 100644 --- a/src/codec/mod.rs +++ b/src/codec/mod.rs @@ -26,12 +26,12 @@ pub mod unsigned_varint; /// Supported protocol codecs. #[derive(Debug, Copy, Clone)] pub enum ProtocolCodec { - /// Identity codec where the argument denotes the payload size. - Identity(usize), + /// Identity codec where the argument denotes the payload size. + Identity(usize), - /// Unsigned varint where the argument denotes the maximum message size, if specified. - UnsignedVarint(Option), + /// Unsigned varint where the argument denotes the maximum message size, if specified. + UnsignedVarint(Option), - /// Protocol doens't need framing for its messages or is using a custom codec. - Unspecified, + /// Protocol doens't need framing for its messages or is using a custom codec. + Unspecified, } diff --git a/src/codec/unsigned_varint.rs b/src/codec/unsigned_varint.rs index 34be8b8d..0c62070c 100644 --- a/src/codec/unsigned_varint.rs +++ b/src/codec/unsigned_varint.rs @@ -28,114 +28,114 @@ use unsigned_varint::codec::UviBytes; /// Unsigned varint codec. pub struct UnsignedVarint { - codec: UviBytes, + codec: UviBytes, } impl UnsignedVarint { - /// Create new [`UnsignedVarint`] codec. - pub fn new(max_size: Option) -> Self { - let mut codec = UviBytes::::default(); + /// Create new [`UnsignedVarint`] codec. + pub fn new(max_size: Option) -> Self { + let mut codec = UviBytes::::default(); - if let Some(max_size) = max_size { - codec.set_max_len(max_size); - } + if let Some(max_size) = max_size { + codec.set_max_len(max_size); + } - Self { codec } - } + Self { codec } + } - /// Set maximum size for encoded/decodes values. - pub fn with_max_size(max_size: usize) -> Self { - let mut codec = UviBytes::::default(); - codec.set_max_len(max_size); + /// Set maximum size for encoded/decodes values. + pub fn with_max_size(max_size: usize) -> Self { + let mut codec = UviBytes::::default(); + codec.set_max_len(max_size); - Self { codec } - } + Self { codec } + } - /// Encode `payload` using `unsigned-varint`. - pub fn encode>(payload: T) -> crate::Result> { - let payload: Bytes = payload.into(); + /// Encode `payload` using `unsigned-varint`. + pub fn encode>(payload: T) -> crate::Result> { + let payload: Bytes = payload.into(); - assert!(payload.len() <= u32::MAX as usize); + assert!(payload.len() <= u32::MAX as usize); - let mut bytes = BytesMut::with_capacity(payload.len() + 4); - let mut codec = Self::new(None); - codec.encode(payload.into(), &mut bytes)?; + let mut bytes = BytesMut::with_capacity(payload.len() + 4); + let mut codec = Self::new(None); + codec.encode(payload.into(), &mut bytes)?; - Ok(bytes.into()) - } + Ok(bytes.into()) + } - /// Decode `payload` into `BytesMut`. - pub fn decode(payload: &mut BytesMut) -> crate::Result { - Ok(UviBytes::::default().decode(payload)?.ok_or(Error::InvalidData)?) - } + /// Decode `payload` into `BytesMut`. + pub fn decode(payload: &mut BytesMut) -> crate::Result { + Ok(UviBytes::::default().decode(payload)?.ok_or(Error::InvalidData)?) + } } impl Decoder for UnsignedVarint { - type Item = BytesMut; - type Error = Error; + type Item = BytesMut; + type Error = Error; - fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { - self.codec.decode(src).map_err(From::from) - } + fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { + self.codec.decode(src).map_err(From::from) + } } impl Encoder for UnsignedVarint { - type Error = Error; + type Error = Error; - fn encode(&mut self, item: Bytes, dst: &mut bytes::BytesMut) -> Result<(), Self::Error> { - self.codec.encode(item, dst).map_err(From::from) - } + fn encode(&mut self, item: Bytes, dst: &mut bytes::BytesMut) -> Result<(), Self::Error> { + self.codec.encode(item, dst).map_err(From::from) + } } #[cfg(test)] mod tests { - use super::{Bytes, BytesMut, UnsignedVarint}; - - #[test] - fn max_size_respected() { - let mut codec = UnsignedVarint::with_max_size(1024); - - { - use tokio_util::codec::Encoder; - - let bytes_to_encode: Bytes = vec![0u8; 1024].into(); - let mut out_bytes = BytesMut::with_capacity(2048); - assert!(codec.encode(bytes_to_encode, &mut out_bytes).is_ok()); - } - - { - use tokio_util::codec::Encoder; - - let bytes_to_encode: Bytes = vec![1u8; 1025].into(); - let mut out_bytes = BytesMut::with_capacity(2048); - assert!(codec.encode(bytes_to_encode, &mut out_bytes).is_err()); - } - } - - #[test] - fn encode_decode_works() { - let encoded1 = UnsignedVarint::encode(vec![0u8; 512]).unwrap(); - let mut encoded2 = { - use tokio_util::codec::Encoder; - - let mut codec = UnsignedVarint::with_max_size(512); - let bytes_to_encode: Bytes = vec![0u8; 512].into(); - let mut out_bytes = BytesMut::with_capacity(2048); - codec.encode(bytes_to_encode, &mut out_bytes).unwrap(); - out_bytes - }; - - assert_eq!(encoded1, encoded2); - - let decoded1 = UnsignedVarint::decode(&mut encoded2).unwrap(); - let decoded2 = { - use tokio_util::codec::Decoder; - - let mut codec = UnsignedVarint::with_max_size(512); - let mut encoded1 = BytesMut::from(&encoded1[..]); - codec.decode(&mut encoded1).unwrap().unwrap() - }; - - assert_eq!(decoded1, decoded2); - } + use super::{Bytes, BytesMut, UnsignedVarint}; + + #[test] + fn max_size_respected() { + let mut codec = UnsignedVarint::with_max_size(1024); + + { + use tokio_util::codec::Encoder; + + let bytes_to_encode: Bytes = vec![0u8; 1024].into(); + let mut out_bytes = BytesMut::with_capacity(2048); + assert!(codec.encode(bytes_to_encode, &mut out_bytes).is_ok()); + } + + { + use tokio_util::codec::Encoder; + + let bytes_to_encode: Bytes = vec![1u8; 1025].into(); + let mut out_bytes = BytesMut::with_capacity(2048); + assert!(codec.encode(bytes_to_encode, &mut out_bytes).is_err()); + } + } + + #[test] + fn encode_decode_works() { + let encoded1 = UnsignedVarint::encode(vec![0u8; 512]).unwrap(); + let mut encoded2 = { + use tokio_util::codec::Encoder; + + let mut codec = UnsignedVarint::with_max_size(512); + let bytes_to_encode: Bytes = vec![0u8; 512].into(); + let mut out_bytes = BytesMut::with_capacity(2048); + codec.encode(bytes_to_encode, &mut out_bytes).unwrap(); + out_bytes + }; + + assert_eq!(encoded1, encoded2); + + let decoded1 = UnsignedVarint::decode(&mut encoded2).unwrap(); + let decoded2 = { + use tokio_util::codec::Decoder; + + let mut codec = UnsignedVarint::with_max_size(512); + let mut encoded1 = BytesMut::from(&encoded1[..]); + codec.decode(&mut encoded1).unwrap().unwrap() + }; + + assert_eq!(decoded1, decoded2); + } } diff --git a/src/config.rs b/src/config.rs index 21f917be..4b5df79b 100644 --- a/src/config.rs +++ b/src/config.rs @@ -21,20 +21,20 @@ //! [`Litep2p`](`crate::Litep2p`) configuration. use crate::{ - crypto::ed25519::Keypair, - executor::{DefaultExecutor, Executor}, - protocol::{ - libp2p::{bitswap, identify, kademlia, ping}, - mdns::Config as MdnsConfig, - notification, request_response, UserProtocol, - }, - transport::{ - quic::config::Config as QuicConfig, tcp::config::Config as TcpConfig, - webrtc::config::Config as WebRtcConfig, websocket::config::Config as WebSocketConfig, - MAX_PARALLEL_DIALS, - }, - types::protocol::ProtocolName, - PeerId, + crypto::ed25519::Keypair, + executor::{DefaultExecutor, Executor}, + protocol::{ + libp2p::{bitswap, identify, kademlia, ping}, + mdns::Config as MdnsConfig, + notification, request_response, UserProtocol, + }, + transport::{ + quic::config::Config as QuicConfig, tcp::config::Config as TcpConfig, + webrtc::config::Config as WebRtcConfig, websocket::config::Config as WebSocketConfig, + MAX_PARALLEL_DIALS, + }, + types::protocol::ProtocolName, + PeerId, }; use multiaddr::Multiaddr; @@ -44,274 +44,274 @@ use std::{collections::HashMap, sync::Arc}; /// Connection role. #[derive(Debug, Copy, Clone)] pub enum Role { - /// Dialer. - Dialer, + /// Dialer. + Dialer, - /// Listener. - Listener, + /// Listener. + Listener, } impl From for crate::yamux::Mode { - fn from(value: Role) -> Self { - match value { - Role::Dialer => crate::yamux::Mode::Client, - Role::Listener => crate::yamux::Mode::Server, - } - } + fn from(value: Role) -> Self { + match value { + Role::Dialer => crate::yamux::Mode::Client, + Role::Listener => crate::yamux::Mode::Server, + } + } } /// Configuration builder for [`Litep2p`](`crate::Litep2p`). pub struct ConfigBuilder { - // TCP transport configuration. - tcp: Option, + // TCP transport configuration. + tcp: Option, - /// QUIC transport config. - quic: Option, + /// QUIC transport config. + quic: Option, - /// WebRTC transport config. - webrtc: Option, + /// WebRTC transport config. + webrtc: Option, - /// WebSocket transport config. - websocket: Option, + /// WebSocket transport config. + websocket: Option, - /// Keypair. - keypair: Option, + /// Keypair. + keypair: Option, - /// Ping protocol config. - ping: Option, + /// Ping protocol config. + ping: Option, - /// Identify protocol config. - identify: Option, + /// Identify protocol config. + identify: Option, - /// Kademlia protocol config. - kademlia: Option, + /// Kademlia protocol config. + kademlia: Option, - /// Bitswap protocol config. - bitswap: Option, + /// Bitswap protocol config. + bitswap: Option, - /// Notification protocols. - notification_protocols: HashMap, + /// Notification protocols. + notification_protocols: HashMap, - /// Request-response protocols. - request_response_protocols: HashMap, + /// Request-response protocols. + request_response_protocols: HashMap, - /// User protocols. - user_protocols: HashMap>, + /// User protocols. + user_protocols: HashMap>, - /// mDNS configuration. - mdns: Option, + /// mDNS configuration. + mdns: Option, - /// Known addresess. - known_addresses: Vec<(PeerId, Vec)>, + /// Known addresess. + known_addresses: Vec<(PeerId, Vec)>, - /// Executor for running futures. - executor: Option>, + /// Executor for running futures. + executor: Option>, - /// Maximum number of parallel dial attempts. - max_parallel_dials: usize, + /// Maximum number of parallel dial attempts. + max_parallel_dials: usize, } impl ConfigBuilder { - /// Create empty [`ConfigBuilder`]. - pub fn new() -> Self { - Self { - tcp: None, - quic: None, - webrtc: None, - websocket: None, - keypair: None, - ping: None, - identify: None, - kademlia: None, - bitswap: None, - mdns: None, - executor: None, - max_parallel_dials: MAX_PARALLEL_DIALS, - user_protocols: HashMap::new(), - notification_protocols: HashMap::new(), - request_response_protocols: HashMap::new(), - known_addresses: Vec::new(), - } - } - - /// Add TCP transport configuration, enabling the transport. - pub fn with_tcp(mut self, config: TcpConfig) -> Self { - self.tcp = Some(config); - self - } - - /// Add QUIC transport configuration, enabling the transport. - pub fn with_quic(mut self, config: QuicConfig) -> Self { - self.quic = Some(config); - self - } - - /// Add WebRTC transport configuration, enabling the transport. - pub fn with_webrtc(mut self, config: WebRtcConfig) -> Self { - self.webrtc = Some(config); - self - } - - /// Add WebSocket transport configuration, enabling the transport. - pub fn with_websocket(mut self, config: WebSocketConfig) -> Self { - self.websocket = Some(config); - self - } - - /// Add keypair. - /// - /// If no keypair is specified, litep2p creates a new keypair. - pub fn with_keypair(mut self, keypair: Keypair) -> Self { - self.keypair = Some(keypair); - self - } - - /// Enable notification protocol. - pub fn with_notification_protocol(mut self, config: notification::Config) -> Self { - self.notification_protocols.insert(config.protocol_name().clone(), config); - self - } - - /// Enable IPFS Ping protocol. - pub fn with_libp2p_ping(mut self, config: ping::Config) -> Self { - self.ping = Some(config); - self - } - - /// Enable IPFS Identify protocol. - pub fn with_libp2p_identify(mut self, config: identify::Config) -> Self { - self.identify = Some(config); - self - } - - /// Enable IPFS Kademlia protocol. - pub fn with_libp2p_kademlia(mut self, config: kademlia::Config) -> Self { - self.kademlia = Some(config); - self - } - - /// Enable IPFS Bitswap protocol. - pub fn with_libp2p_bitswap(mut self, config: bitswap::Config) -> Self { - self.bitswap = Some(config); - self - } - - /// Enable request-response protocol. - pub fn with_request_response_protocol(mut self, config: request_response::Config) -> Self { - self.request_response_protocols.insert(config.protocol_name().clone(), config); - self - } - - /// Enable user protocol. - pub fn with_user_protocol(mut self, protocol: Box) -> Self { - self.user_protocols.insert(protocol.protocol(), protocol); - self - } - - /// Enable mDNS for peer discoveries in the local network. - pub fn with_mdns(mut self, config: MdnsConfig) -> Self { - self.mdns = Some(config); - self - } - - /// Add known address(es) for one or more peers. - pub fn with_known_addresses( - mut self, - addresses: impl Iterator)>, - ) -> Self { - self.known_addresses = addresses.collect(); - self - } - - /// Add executor for running futures spawned by `litep2p`. - /// - /// If no executor is specified, `litep2p` defaults to calling `tokio::spawn()`. - pub fn with_executor(mut self, executor: Arc) -> Self { - self.executor = Some(executor); - self - } - - /// How many addresses should litep2p attempt to dial in parallel. - pub fn with_max_parallel_dials(mut self, max_parallel_dials: usize) -> Self { - self.max_parallel_dials = max_parallel_dials; - self - } - - /// Build [`Litep2pConfig`]. - pub fn build(mut self) -> Litep2pConfig { - let keypair = match self.keypair { - Some(keypair) => keypair, - None => Keypair::generate(), - }; - - Litep2pConfig { - keypair, - tcp: self.tcp.take(), - mdns: self.mdns.take(), - quic: self.quic.take(), - webrtc: self.webrtc.take(), - websocket: self.websocket.take(), - ping: self.ping.take(), - identify: self.identify.take(), - kademlia: self.kademlia.take(), - bitswap: self.bitswap.take(), - max_parallel_dials: self.max_parallel_dials, - executor: self.executor.map_or(Arc::new(DefaultExecutor {}), |executor| executor), - user_protocols: self.user_protocols, - notification_protocols: self.notification_protocols, - request_response_protocols: self.request_response_protocols, - known_addresses: self.known_addresses, - } - } + /// Create empty [`ConfigBuilder`]. + pub fn new() -> Self { + Self { + tcp: None, + quic: None, + webrtc: None, + websocket: None, + keypair: None, + ping: None, + identify: None, + kademlia: None, + bitswap: None, + mdns: None, + executor: None, + max_parallel_dials: MAX_PARALLEL_DIALS, + user_protocols: HashMap::new(), + notification_protocols: HashMap::new(), + request_response_protocols: HashMap::new(), + known_addresses: Vec::new(), + } + } + + /// Add TCP transport configuration, enabling the transport. + pub fn with_tcp(mut self, config: TcpConfig) -> Self { + self.tcp = Some(config); + self + } + + /// Add QUIC transport configuration, enabling the transport. + pub fn with_quic(mut self, config: QuicConfig) -> Self { + self.quic = Some(config); + self + } + + /// Add WebRTC transport configuration, enabling the transport. + pub fn with_webrtc(mut self, config: WebRtcConfig) -> Self { + self.webrtc = Some(config); + self + } + + /// Add WebSocket transport configuration, enabling the transport. + pub fn with_websocket(mut self, config: WebSocketConfig) -> Self { + self.websocket = Some(config); + self + } + + /// Add keypair. + /// + /// If no keypair is specified, litep2p creates a new keypair. + pub fn with_keypair(mut self, keypair: Keypair) -> Self { + self.keypair = Some(keypair); + self + } + + /// Enable notification protocol. + pub fn with_notification_protocol(mut self, config: notification::Config) -> Self { + self.notification_protocols.insert(config.protocol_name().clone(), config); + self + } + + /// Enable IPFS Ping protocol. + pub fn with_libp2p_ping(mut self, config: ping::Config) -> Self { + self.ping = Some(config); + self + } + + /// Enable IPFS Identify protocol. + pub fn with_libp2p_identify(mut self, config: identify::Config) -> Self { + self.identify = Some(config); + self + } + + /// Enable IPFS Kademlia protocol. + pub fn with_libp2p_kademlia(mut self, config: kademlia::Config) -> Self { + self.kademlia = Some(config); + self + } + + /// Enable IPFS Bitswap protocol. + pub fn with_libp2p_bitswap(mut self, config: bitswap::Config) -> Self { + self.bitswap = Some(config); + self + } + + /// Enable request-response protocol. + pub fn with_request_response_protocol(mut self, config: request_response::Config) -> Self { + self.request_response_protocols.insert(config.protocol_name().clone(), config); + self + } + + /// Enable user protocol. + pub fn with_user_protocol(mut self, protocol: Box) -> Self { + self.user_protocols.insert(protocol.protocol(), protocol); + self + } + + /// Enable mDNS for peer discoveries in the local network. + pub fn with_mdns(mut self, config: MdnsConfig) -> Self { + self.mdns = Some(config); + self + } + + /// Add known address(es) for one or more peers. + pub fn with_known_addresses( + mut self, + addresses: impl Iterator)>, + ) -> Self { + self.known_addresses = addresses.collect(); + self + } + + /// Add executor for running futures spawned by `litep2p`. + /// + /// If no executor is specified, `litep2p` defaults to calling `tokio::spawn()`. + pub fn with_executor(mut self, executor: Arc) -> Self { + self.executor = Some(executor); + self + } + + /// How many addresses should litep2p attempt to dial in parallel. + pub fn with_max_parallel_dials(mut self, max_parallel_dials: usize) -> Self { + self.max_parallel_dials = max_parallel_dials; + self + } + + /// Build [`Litep2pConfig`]. + pub fn build(mut self) -> Litep2pConfig { + let keypair = match self.keypair { + Some(keypair) => keypair, + None => Keypair::generate(), + }; + + Litep2pConfig { + keypair, + tcp: self.tcp.take(), + mdns: self.mdns.take(), + quic: self.quic.take(), + webrtc: self.webrtc.take(), + websocket: self.websocket.take(), + ping: self.ping.take(), + identify: self.identify.take(), + kademlia: self.kademlia.take(), + bitswap: self.bitswap.take(), + max_parallel_dials: self.max_parallel_dials, + executor: self.executor.map_or(Arc::new(DefaultExecutor {}), |executor| executor), + user_protocols: self.user_protocols, + notification_protocols: self.notification_protocols, + request_response_protocols: self.request_response_protocols, + known_addresses: self.known_addresses, + } + } } /// Configuration for [`Litep2p`](`crate::Litep2p`). pub struct Litep2pConfig { - // TCP transport configuration. - pub(crate) tcp: Option, + // TCP transport configuration. + pub(crate) tcp: Option, - /// QUIC transport config. - pub(crate) quic: Option, + /// QUIC transport config. + pub(crate) quic: Option, - /// WebRTC transport config. - pub(crate) webrtc: Option, + /// WebRTC transport config. + pub(crate) webrtc: Option, - /// WebSocket transport config. - pub(crate) websocket: Option, + /// WebSocket transport config. + pub(crate) websocket: Option, - /// Keypair. - pub(crate) keypair: Keypair, + /// Keypair. + pub(crate) keypair: Keypair, - /// Ping protocol configuration, if enabled. - pub(crate) ping: Option, + /// Ping protocol configuration, if enabled. + pub(crate) ping: Option, - /// Identify protocol configuration, if enabled. - pub(crate) identify: Option, + /// Identify protocol configuration, if enabled. + pub(crate) identify: Option, - /// Kademlia protocol configuration, if enabled. - pub(crate) kademlia: Option, + /// Kademlia protocol configuration, if enabled. + pub(crate) kademlia: Option, - /// Bitswap protocol configuration, if enabled. - pub(crate) bitswap: Option, + /// Bitswap protocol configuration, if enabled. + pub(crate) bitswap: Option, - /// Notification protocols. - pub(crate) notification_protocols: HashMap, + /// Notification protocols. + pub(crate) notification_protocols: HashMap, - /// Request-response protocols. - pub(crate) request_response_protocols: HashMap, + /// Request-response protocols. + pub(crate) request_response_protocols: HashMap, - /// User protocols. - pub(crate) user_protocols: HashMap>, + /// User protocols. + pub(crate) user_protocols: HashMap>, - /// mDNS configuration. - pub(crate) mdns: Option, + /// mDNS configuration. + pub(crate) mdns: Option, - /// Executor. - pub(crate) executor: Arc, + /// Executor. + pub(crate) executor: Arc, - /// Maximum number of parallel dial attempts. - pub(crate) max_parallel_dials: usize, + /// Maximum number of parallel dial attempts. + pub(crate) max_parallel_dials: usize, - /// Known addresses. - pub(crate) known_addresses: Vec<(PeerId, Vec)>, + /// Known addresses. + pub(crate) known_addresses: Vec<(PeerId, Vec)>, } diff --git a/src/crypto/ed25519.rs b/src/crypto/ed25519.rs index 39783ff6..852d4e0f 100644 --- a/src/crypto/ed25519.rs +++ b/src/crypto/ed25519.rs @@ -33,80 +33,83 @@ use std::{cmp, convert::TryFrom, fmt}; pub struct Keypair(ed25519::Keypair); impl Keypair { - /// Generate a new random Ed25519 keypair. - pub fn generate() -> Keypair { - Keypair::from(SecretKey::generate()) - } - - /// Encode the keypair into a byte array by concatenating the bytes - /// of the secret scalar and the compressed public point, - /// an informal standard for encoding Ed25519 keypairs. - pub fn encode(&self) -> [u8; 64] { - self.0.to_bytes() - } - - /// Decode a keypair from the [binary format](https://datatracker.ietf.org/doc/html/rfc8032#section-5.1.5) - /// produced by [`Keypair::encode`], zeroing the input on success. - /// - /// Note that this binary format is the same as `ed25519_dalek`'s and `ed25519_zebra`'s. - pub fn decode(kp: &mut [u8]) -> crate::Result { - ed25519::Keypair::from_bytes(kp) - .map(|k| { - kp.zeroize(); - Keypair(k) - }) - .map_err(|error| Error::Other(format!("Failed to parse keypair: {error:?}"))) - } - - /// Sign a message using the private key of this keypair. - pub fn sign(&self, msg: &[u8]) -> Vec { - self.0.sign(msg).to_bytes().to_vec() - } - - /// Get the public key of this keypair. - pub fn public(&self) -> PublicKey { - PublicKey(self.0.public) - } - - /// Get the secret key of this keypair. - pub fn secret(&self) -> SecretKey { - SecretKey::from_bytes(&mut self.0.secret.to_bytes()) - .expect("ed25519::SecretKey::from_bytes(to_bytes(k)) != k") - } + /// Generate a new random Ed25519 keypair. + pub fn generate() -> Keypair { + Keypair::from(SecretKey::generate()) + } + + /// Encode the keypair into a byte array by concatenating the bytes + /// of the secret scalar and the compressed public point, + /// an informal standard for encoding Ed25519 keypairs. + pub fn encode(&self) -> [u8; 64] { + self.0.to_bytes() + } + + /// Decode a keypair from the [binary format](https://datatracker.ietf.org/doc/html/rfc8032#section-5.1.5) + /// produced by [`Keypair::encode`], zeroing the input on success. + /// + /// Note that this binary format is the same as `ed25519_dalek`'s and `ed25519_zebra`'s. + pub fn decode(kp: &mut [u8]) -> crate::Result { + ed25519::Keypair::from_bytes(kp) + .map(|k| { + kp.zeroize(); + Keypair(k) + }) + .map_err(|error| Error::Other(format!("Failed to parse keypair: {error:?}"))) + } + + /// Sign a message using the private key of this keypair. + pub fn sign(&self, msg: &[u8]) -> Vec { + self.0.sign(msg).to_bytes().to_vec() + } + + /// Get the public key of this keypair. + pub fn public(&self) -> PublicKey { + PublicKey(self.0.public) + } + + /// Get the secret key of this keypair. + pub fn secret(&self) -> SecretKey { + SecretKey::from_bytes(&mut self.0.secret.to_bytes()) + .expect("ed25519::SecretKey::from_bytes(to_bytes(k)) != k") + } } impl fmt::Debug for Keypair { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("Keypair").field("public", &self.0.public).finish() - } + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Keypair").field("public", &self.0.public).finish() + } } impl Clone for Keypair { - fn clone(&self) -> Keypair { - let mut sk_bytes = self.0.secret.to_bytes(); - let secret = SecretKey::from_bytes(&mut sk_bytes) - .expect("ed25519::SecretKey::from_bytes(to_bytes(k)) != k") - .0; - let public = ed25519::PublicKey::from_bytes(&self.0.public.to_bytes()) - .expect("ed25519::PublicKey::from_bytes(to_bytes(k)) != k"); - Keypair(ed25519::Keypair { secret, public }) - } + fn clone(&self) -> Keypair { + let mut sk_bytes = self.0.secret.to_bytes(); + let secret = SecretKey::from_bytes(&mut sk_bytes) + .expect("ed25519::SecretKey::from_bytes(to_bytes(k)) != k") + .0; + let public = ed25519::PublicKey::from_bytes(&self.0.public.to_bytes()) + .expect("ed25519::PublicKey::from_bytes(to_bytes(k)) != k"); + Keypair(ed25519::Keypair { secret, public }) + } } /// Demote an Ed25519 keypair to a secret key. impl From for SecretKey { - fn from(kp: Keypair) -> SecretKey { - SecretKey(kp.0.secret) - } + fn from(kp: Keypair) -> SecretKey { + SecretKey(kp.0.secret) + } } /// Promote an Ed25519 secret key into a keypair. impl From for Keypair { - fn from(sk: SecretKey) -> Keypair { - let secret: ed25519::ExpandedSecretKey = (&sk.0).into(); - let public = ed25519::PublicKey::from(&secret); - Keypair(ed25519::Keypair { secret: sk.0, public }) - } + fn from(sk: SecretKey) -> Keypair { + let secret: ed25519::ExpandedSecretKey = (&sk.0).into(); + let public = ed25519::PublicKey::from(&secret); + Keypair(ed25519::Keypair { + secret: sk.0, + public, + }) + } } /// An Ed25519 public key. @@ -114,44 +117,44 @@ impl From for Keypair { pub struct PublicKey(ed25519::PublicKey); impl fmt::Debug for PublicKey { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str("PublicKey(compressed): ")?; - for byte in self.0.as_bytes() { - write!(f, "{byte:x}")?; - } - Ok(()) - } + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("PublicKey(compressed): ")?; + for byte in self.0.as_bytes() { + write!(f, "{byte:x}")?; + } + Ok(()) + } } impl cmp::PartialEq for PublicKey { - fn eq(&self, other: &Self) -> bool { - self.0.as_bytes().eq(other.0.as_bytes()) - } + fn eq(&self, other: &Self) -> bool { + self.0.as_bytes().eq(other.0.as_bytes()) + } } impl PublicKey { - /// Verify the Ed25519 signature on a message using the public key. - pub fn verify(&self, msg: &[u8], sig: &[u8]) -> bool { - ed25519::Signature::try_from(sig).and_then(|s| self.0.verify(msg, &s)).is_ok() - } - - /// Encode the public key into a byte array in compressed form, i.e. - /// where one coordinate is represented by a single bit. - pub fn encode(&self) -> [u8; 32] { - self.0.to_bytes() - } - - /// Decode a public key from a byte array as produced by `encode`. - pub fn decode(k: &[u8]) -> crate::Result { - ed25519::PublicKey::from_bytes(k) - .map_err(|error| Error::Other(format!("Failed to parse keypair: {error:?}"))) - .map(PublicKey) - } - - /// Convert public key to `PeerId`. - pub fn to_peer_id(&self) -> PeerId { - crate::crypto::PublicKey::Ed25519(self.clone()).into() - } + /// Verify the Ed25519 signature on a message using the public key. + pub fn verify(&self, msg: &[u8], sig: &[u8]) -> bool { + ed25519::Signature::try_from(sig).and_then(|s| self.0.verify(msg, &s)).is_ok() + } + + /// Encode the public key into a byte array in compressed form, i.e. + /// where one coordinate is represented by a single bit. + pub fn encode(&self) -> [u8; 32] { + self.0.to_bytes() + } + + /// Decode a public key from a byte array as produced by `encode`. + pub fn decode(k: &[u8]) -> crate::Result { + ed25519::PublicKey::from_bytes(k) + .map_err(|error| Error::Other(format!("Failed to parse keypair: {error:?}"))) + .map(PublicKey) + } + + /// Convert public key to `PeerId`. + pub fn to_peer_id(&self) -> PeerId { + crate::crypto::PublicKey::Ed25519(self.clone()).into() + } } /// An Ed25519 secret key. @@ -159,115 +162,115 @@ pub struct SecretKey(ed25519::SecretKey); /// View the bytes of the secret key. impl AsRef<[u8]> for SecretKey { - fn as_ref(&self) -> &[u8] { - self.0.as_bytes() - } + fn as_ref(&self) -> &[u8] { + self.0.as_bytes() + } } impl Clone for SecretKey { - fn clone(&self) -> SecretKey { - let mut sk_bytes = self.0.to_bytes(); - Self::from_bytes(&mut sk_bytes).expect("ed25519::SecretKey::from_bytes(to_bytes(k)) != k") - } + fn clone(&self) -> SecretKey { + let mut sk_bytes = self.0.to_bytes(); + Self::from_bytes(&mut sk_bytes).expect("ed25519::SecretKey::from_bytes(to_bytes(k)) != k") + } } impl fmt::Debug for SecretKey { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "SecretKey") - } + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "SecretKey") + } } impl SecretKey { - /// Generate a new Ed25519 secret key. - pub fn generate() -> SecretKey { - let mut bytes = [0u8; 32]; - rand::thread_rng().fill_bytes(&mut bytes); - SecretKey( - ed25519::SecretKey::from_bytes(&bytes).expect( - "this returns `Err` only if the length is wrong; the length is correct; qed", - ), - ) - } - - /// Create an Ed25519 secret key from a byte slice, zeroing the input on success. - /// If the bytes do not constitute a valid Ed25519 secret key, an error is - /// returned. - pub fn from_bytes(mut sk_bytes: impl AsMut<[u8]>) -> crate::Result { - let sk_bytes = sk_bytes.as_mut(); - let secret = ed25519::SecretKey::from_bytes(&*sk_bytes) - .map_err(|error| Error::Other(format!("Failed to parse keypair: {error:?}")))?; - sk_bytes.zeroize(); - Ok(SecretKey(secret)) - } + /// Generate a new Ed25519 secret key. + pub fn generate() -> SecretKey { + let mut bytes = [0u8; 32]; + rand::thread_rng().fill_bytes(&mut bytes); + SecretKey( + ed25519::SecretKey::from_bytes(&bytes).expect( + "this returns `Err` only if the length is wrong; the length is correct; qed", + ), + ) + } + + /// Create an Ed25519 secret key from a byte slice, zeroing the input on success. + /// If the bytes do not constitute a valid Ed25519 secret key, an error is + /// returned. + pub fn from_bytes(mut sk_bytes: impl AsMut<[u8]>) -> crate::Result { + let sk_bytes = sk_bytes.as_mut(); + let secret = ed25519::SecretKey::from_bytes(&*sk_bytes) + .map_err(|error| Error::Other(format!("Failed to parse keypair: {error:?}")))?; + sk_bytes.zeroize(); + Ok(SecretKey(secret)) + } } #[cfg(test)] mod tests { - use super::*; - use quickcheck::*; - - fn eq_keypairs(kp1: &Keypair, kp2: &Keypair) -> bool { - kp1.public() == kp2.public() && kp1.0.secret.as_bytes() == kp2.0.secret.as_bytes() - } - - #[test] - fn ed25519_keypair_encode_decode() { - fn prop() -> bool { - let kp1 = Keypair::generate(); - let mut kp1_enc = kp1.encode(); - let kp2 = Keypair::decode(&mut kp1_enc).unwrap(); - eq_keypairs(&kp1, &kp2) && kp1_enc.iter().all(|b| *b == 0) - } - QuickCheck::new().tests(10).quickcheck(prop as fn() -> _); - } - - #[test] - fn ed25519_keypair_from_secret() { - fn prop() -> bool { - let kp1 = Keypair::generate(); - let mut sk = kp1.0.secret.to_bytes(); - let kp2 = Keypair::from(SecretKey::from_bytes(&mut sk).unwrap()); - eq_keypairs(&kp1, &kp2) && sk == [0u8; 32] - } - QuickCheck::new().tests(10).quickcheck(prop as fn() -> _); - } - - #[test] - fn ed25519_signature() { - let kp = Keypair::generate(); - let pk = kp.public(); - - let msg = "hello world".as_bytes(); - let sig = kp.sign(msg); - assert!(pk.verify(msg, &sig)); - - let mut invalid_sig = sig.clone(); - invalid_sig[3..6].copy_from_slice(&[10, 23, 42]); - assert!(!pk.verify(msg, &invalid_sig)); - - let invalid_msg = "h3ll0 w0rld".as_bytes(); - assert!(!pk.verify(invalid_msg, &sig)); - } - - #[test] - fn secret_key() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let key = Keypair::generate(); - tracing::trace!("keypair: {:?}", key); - tracing::trace!("secret: {:?}", key.secret()); - tracing::trace!("public: {:?}", key.public()); - - let new_key = Keypair::from(key.secret()); - assert!(new_key.secret().as_ref() == key.secret().as_ref()); - assert!(new_key.public() == key.public()); - - let new_secret = SecretKey::from(new_key.clone()); - assert!(new_secret.as_ref() == new_key.secret().as_ref()); - - let cloned_secret = new_secret.clone(); - assert!(cloned_secret.as_ref() == new_secret.as_ref()); - } + use super::*; + use quickcheck::*; + + fn eq_keypairs(kp1: &Keypair, kp2: &Keypair) -> bool { + kp1.public() == kp2.public() && kp1.0.secret.as_bytes() == kp2.0.secret.as_bytes() + } + + #[test] + fn ed25519_keypair_encode_decode() { + fn prop() -> bool { + let kp1 = Keypair::generate(); + let mut kp1_enc = kp1.encode(); + let kp2 = Keypair::decode(&mut kp1_enc).unwrap(); + eq_keypairs(&kp1, &kp2) && kp1_enc.iter().all(|b| *b == 0) + } + QuickCheck::new().tests(10).quickcheck(prop as fn() -> _); + } + + #[test] + fn ed25519_keypair_from_secret() { + fn prop() -> bool { + let kp1 = Keypair::generate(); + let mut sk = kp1.0.secret.to_bytes(); + let kp2 = Keypair::from(SecretKey::from_bytes(&mut sk).unwrap()); + eq_keypairs(&kp1, &kp2) && sk == [0u8; 32] + } + QuickCheck::new().tests(10).quickcheck(prop as fn() -> _); + } + + #[test] + fn ed25519_signature() { + let kp = Keypair::generate(); + let pk = kp.public(); + + let msg = "hello world".as_bytes(); + let sig = kp.sign(msg); + assert!(pk.verify(msg, &sig)); + + let mut invalid_sig = sig.clone(); + invalid_sig[3..6].copy_from_slice(&[10, 23, 42]); + assert!(!pk.verify(msg, &invalid_sig)); + + let invalid_msg = "h3ll0 w0rld".as_bytes(); + assert!(!pk.verify(invalid_msg, &sig)); + } + + #[test] + fn secret_key() { + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let key = Keypair::generate(); + tracing::trace!("keypair: {:?}", key); + tracing::trace!("secret: {:?}", key.secret()); + tracing::trace!("public: {:?}", key.public()); + + let new_key = Keypair::from(key.secret()); + assert!(new_key.secret().as_ref() == key.secret().as_ref()); + assert!(new_key.public() == key.public()); + + let new_secret = SecretKey::from(new_key.clone()); + assert!(new_secret.as_ref() == new_key.secret().as_ref()); + + let cloned_secret = new_secret.clone(); + assert!(cloned_secret.as_ref() == new_secret.as_ref()); + } } diff --git a/src/crypto/mod.rs b/src/crypto/mod.rs index fe167519..2ad7cf18 100644 --- a/src/crypto/mod.rs +++ b/src/crypto/mod.rs @@ -27,86 +27,86 @@ pub mod ed25519; pub(crate) mod noise; pub(crate) mod tls; pub(crate) mod keys_proto { - include!(concat!(env!("OUT_DIR"), "/keys_proto.rs")); + include!(concat!(env!("OUT_DIR"), "/keys_proto.rs")); } /// The public key of a node's identity keypair. #[derive(Clone, Debug, PartialEq, Eq)] pub enum PublicKey { - /// A public Ed25519 key. - Ed25519(ed25519::PublicKey), + /// A public Ed25519 key. + Ed25519(ed25519::PublicKey), } impl PublicKey { - /// Verify a signature for a message using this public key, i.e. check - /// that the signature has been produced by the corresponding - /// private key (authenticity), and that the message has not been - /// tampered with (integrity). - #[must_use] - pub fn verify(&self, msg: &[u8], sig: &[u8]) -> bool { - use PublicKey::*; - match self { - Ed25519(pk) => pk.verify(msg, sig), - } - } - - /// Encode the public key into a protobuf structure for storage or - /// exchange with other nodes. - pub fn to_protobuf_encoding(&self) -> Vec { - use prost::Message; - - let public_key = keys_proto::PublicKey::from(self); - - let mut buf = Vec::with_capacity(public_key.encoded_len()); - public_key.encode(&mut buf).expect("Vec provides capacity as needed"); - buf - } - - /// Decode a public key from a protobuf structure, e.g. read from storage - /// or received from another node. - pub fn from_protobuf_encoding(bytes: &[u8]) -> crate::Result { - use prost::Message; - - let pubkey = keys_proto::PublicKey::decode(bytes) - .map_err(|error| Error::Other(format!("Invalid Protobuf: {error:?}")))?; - - pubkey.try_into() - } - - /// Convert the `PublicKey` into the corresponding `PeerId`. - pub fn to_peer_id(&self) -> PeerId { - self.into() - } + /// Verify a signature for a message using this public key, i.e. check + /// that the signature has been produced by the corresponding + /// private key (authenticity), and that the message has not been + /// tampered with (integrity). + #[must_use] + pub fn verify(&self, msg: &[u8], sig: &[u8]) -> bool { + use PublicKey::*; + match self { + Ed25519(pk) => pk.verify(msg, sig), + } + } + + /// Encode the public key into a protobuf structure for storage or + /// exchange with other nodes. + pub fn to_protobuf_encoding(&self) -> Vec { + use prost::Message; + + let public_key = keys_proto::PublicKey::from(self); + + let mut buf = Vec::with_capacity(public_key.encoded_len()); + public_key.encode(&mut buf).expect("Vec provides capacity as needed"); + buf + } + + /// Decode a public key from a protobuf structure, e.g. read from storage + /// or received from another node. + pub fn from_protobuf_encoding(bytes: &[u8]) -> crate::Result { + use prost::Message; + + let pubkey = keys_proto::PublicKey::decode(bytes) + .map_err(|error| Error::Other(format!("Invalid Protobuf: {error:?}")))?; + + pubkey.try_into() + } + + /// Convert the `PublicKey` into the corresponding `PeerId`. + pub fn to_peer_id(&self) -> PeerId { + self.into() + } } impl From<&PublicKey> for keys_proto::PublicKey { - fn from(key: &PublicKey) -> Self { - match key { - PublicKey::Ed25519(key) => keys_proto::PublicKey { - r#type: keys_proto::KeyType::Ed25519 as i32, - data: key.encode().to_vec(), - }, - } - } + fn from(key: &PublicKey) -> Self { + match key { + PublicKey::Ed25519(key) => keys_proto::PublicKey { + r#type: keys_proto::KeyType::Ed25519 as i32, + data: key.encode().to_vec(), + }, + } + } } impl TryFrom for PublicKey { - type Error = Error; - - fn try_from(pubkey: keys_proto::PublicKey) -> Result { - let key_type = keys_proto::KeyType::from_i32(pubkey.r#type) - .ok_or_else(|| Error::Other(format!("Unknown key type: {}", pubkey.r#type)))?; - - match key_type { - keys_proto::KeyType::Ed25519 => - Ok(ed25519::PublicKey::decode(&pubkey.data).map(PublicKey::Ed25519)?), - _ => unimplemented!("unsupported key type"), - } - } + type Error = Error; + + fn try_from(pubkey: keys_proto::PublicKey) -> Result { + let key_type = keys_proto::KeyType::from_i32(pubkey.r#type) + .ok_or_else(|| Error::Other(format!("Unknown key type: {}", pubkey.r#type)))?; + + match key_type { + keys_proto::KeyType::Ed25519 => + Ok(ed25519::PublicKey::decode(&pubkey.data).map(PublicKey::Ed25519)?), + _ => unimplemented!("unsupported key type"), + } + } } impl From for PublicKey { - fn from(public_key: ed25519::PublicKey) -> Self { - PublicKey::Ed25519(public_key) - } + fn from(public_key: ed25519::PublicKey) -> Self { + PublicKey::Ed25519(public_key) + } } diff --git a/src/crypto/noise/mod.rs b/src/crypto/noise/mod.rs index 05e5590a..558ef839 100644 --- a/src/crypto/noise/mod.rs +++ b/src/crypto/noise/mod.rs @@ -22,9 +22,9 @@ //! Noise handshake and transport implementations. use crate::{ - config::Role, - crypto::{ed25519::Keypair, PublicKey}, - error, PeerId, + config::Role, + crypto::{ed25519::Keypair, PublicKey}, + error, PeerId, }; use bytes::{Buf, Bytes, BytesMut}; @@ -33,16 +33,16 @@ use prost::Message; use snow::{Builder, HandshakeState, TransportState}; use std::{ - fmt, io, - pin::Pin, - task::{Context, Poll}, + fmt, io, + pin::Pin, + task::{Context, Poll}, }; mod protocol; mod x25519_spec; mod handshake_schema { - include!(concat!(env!("OUT_DIR"), "/noise.rs")); + include!(concat!(env!("OUT_DIR"), "/noise.rs")); } /// Noise parameters. @@ -74,673 +74,730 @@ const LOG_TARGET: &str = "litep2p::crypto::noise"; #[derive(Debug)] enum NoiseState { - Handshake(HandshakeState), - Transport(TransportState), + Handshake(HandshakeState), + Transport(TransportState), } pub struct NoiseContext { - keypair: snow::Keypair, - noise: NoiseState, - role: Role, - pub payload: Vec, + keypair: snow::Keypair, + noise: NoiseState, + role: Role, + pub payload: Vec, } impl fmt::Debug for NoiseContext { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("NoiseContext") - .field("public", &self.noise) - .field("payload", &self.payload) - .field("role", &self.role) - .finish() - } + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("NoiseContext") + .field("public", &self.noise) + .field("payload", &self.payload) + .field("role", &self.role) + .finish() + } } impl NoiseContext { - /// Assemble Noise payload and return [`NoiseContext`]. - fn assemble( - noise: snow::HandshakeState, - keypair: snow::Keypair, - id_keys: &Keypair, - role: Role, - ) -> Self { - let noise_payload = handshake_schema::NoiseHandshakePayload { - identity_key: Some(PublicKey::Ed25519(id_keys.public()).to_protobuf_encoding()), - identity_sig: Some( - id_keys.sign(&[STATIC_KEY_DOMAIN.as_bytes(), keypair.public.as_ref()].concat()), - ), - ..Default::default() - }; - - let mut payload = Vec::with_capacity(noise_payload.encoded_len()); - noise_payload.encode(&mut payload).expect("Vec to provide needed capacity"); - - Self { noise: NoiseState::Handshake(noise), keypair, payload, role } - } - - // fn new(role: Role) -> Self { - pub fn new(keypair: &Keypair, role: Role) -> Self { - tracing::trace!(target: LOG_TARGET, ?role, "create new noise configuration"); - - // let builder: Builder<'_> = - // Builder::new(NOISE_PARAMETERS.parse().expect("valid Noise pattern")); - let builder: Builder<'_> = Builder::with_resolver( - NOISE_PARAMETERS.parse().expect("valid Noise pattern"), - Box::new(protocol::Resolver), - ); - - let dh_keypair = builder.generate_keypair().expect("keypair generation to succeed"); - let static_key = &dh_keypair.private; - - let noise = match role { - Role::Dialer => builder - .local_private_key(static_key) - .build_initiator() - .expect("initialization to succeed"), - Role::Listener => builder - .local_private_key(static_key) - .build_responder() - .expect("initialization to succeed"), - }; - - Self::assemble(noise, dh_keypair, keypair, role) - } - - /// Create new [`NoiseContext`] with prologue. - pub fn with_prologue(id_keys: &Keypair, prologue: Vec) -> Self { - let noise: Builder<'_> = Builder::with_resolver( - NOISE_PARAMETERS.parse().expect("valid Noise pattern"), - Box::new(protocol::Resolver), - ); - - // let noise = snow::Builder::new(NOISE_PARAMETERS.parse().expect("valid Noise patterns")); - let keypair = noise.generate_keypair().unwrap(); - - let noise = noise - .local_private_key(&keypair.private) - .prologue(&prologue) - .build_initiator() - .expect("to succeed"); - - Self::assemble(noise, keypair, id_keys, Role::Dialer) - } - - /// Get remote public key from the received Noise payload. - // TODO: refactor - pub fn get_remote_public_key(&mut self, reply: &Vec) -> crate::Result { - if reply.len() <= 2 { - return Err(error::Error::InvalidData); - } - - // TODO: no unwraps - let size: Result<[u8; 2], _> = reply[0..2].try_into(); - let _size = u16::from_be_bytes(size.unwrap()); - - // TODO: buffer size - let mut inner = vec![0u8; 1024]; - - let NoiseState::Handshake(ref mut noise) = self.noise else { - panic!("invalid state to read the second handshake message"); - }; - - let res = noise.read_message(&reply[2..], &mut inner)?; - inner.truncate(res); - - let payload = handshake_schema::NoiseHandshakePayload::decode(inner.as_slice())?; - - Ok(PublicKey::from_protobuf_encoding( - &payload - .identity_key - .ok_or(error::Error::NegotiationError(error::NegotiationError::PeerIdMissing))?, - )?) - } - - /// Get first message. - /// - /// Listener only sends one message (the payload) - pub fn first_message(&mut self, role: Role) -> Vec { - match role { - Role::Dialer => { - tracing::trace!(target: LOG_TARGET, "get noise dialer first message"); - - let NoiseState::Handshake(ref mut noise) = self.noise else { - panic!("invalid state to read the second handshake message"); - }; - - let mut buffer = vec![0u8; 256]; - let nwritten = noise.write_message(&[], &mut buffer).expect("to succeed"); - buffer.truncate(nwritten); - - let size = nwritten as u16; - let mut size = size.to_be_bytes().to_vec(); - size.append(&mut buffer); - - size - }, - Role::Listener => self.second_message(), - } - } - - /// Get second message. - /// - /// Only the dialer sends the second message. - pub fn second_message(&mut self) -> Vec { - tracing::trace!(target: LOG_TARGET, "get noise paylod message"); - - let NoiseState::Handshake(ref mut noise) = self.noise else { - panic!("invalid state to read the second handshake message"); - }; - - let mut buffer = vec![0u8; 2048]; - let nwritten = noise.write_message(&self.payload, &mut buffer).expect("to succeed"); - buffer.truncate(nwritten); - - let size = nwritten as u16; - let mut size = size.to_be_bytes().to_vec(); - size.append(&mut buffer); - - size - } - - /// Read handshake message. - async fn read_handshake_message( - &mut self, - io: &mut T, - ) -> crate::Result { - let mut size = BytesMut::zeroed(2); - io.read_exact(&mut size).await?; - let size = size.get_u16(); - - let mut message = BytesMut::zeroed(size as usize); - io.read_exact(&mut message).await?; - - let mut out = BytesMut::new(); - out.resize(message.len() + 200, 0u8); // TODO: correct overhead - - let NoiseState::Handshake(ref mut noise) = self.noise else { - panic!("invalid state to read handshake message"); - }; - - let nread = noise.read_message(&message, &mut out)?; - out.truncate(nread); - - Ok(out.freeze()) - } - - fn read_message(&mut self, message: &[u8], out: &mut [u8]) -> Result { - match self.noise { - NoiseState::Handshake(ref mut noise) => noise.read_message(message, out), - NoiseState::Transport(ref mut noise) => noise.read_message(message, out), - } - } - - fn write_message(&mut self, message: &[u8], out: &mut [u8]) -> Result { - match self.noise { - NoiseState::Handshake(ref mut noise) => noise.write_message(message, out), - NoiseState::Transport(ref mut noise) => noise.write_message(message, out), - } - } - - /// Convert Noise into transport mode. - fn into_transport(self) -> NoiseContext { - let transport = match self.noise { - NoiseState::Handshake(noise) => noise.into_transport_mode().unwrap(), - NoiseState::Transport(_) => panic!("invalid state"), - }; - - NoiseContext { - keypair: self.keypair, - payload: self.payload, - role: self.role, - noise: NoiseState::Transport(transport), - } - } + /// Assemble Noise payload and return [`NoiseContext`]. + fn assemble( + noise: snow::HandshakeState, + keypair: snow::Keypair, + id_keys: &Keypair, + role: Role, + ) -> Self { + let noise_payload = handshake_schema::NoiseHandshakePayload { + identity_key: Some(PublicKey::Ed25519(id_keys.public()).to_protobuf_encoding()), + identity_sig: Some( + id_keys.sign(&[STATIC_KEY_DOMAIN.as_bytes(), keypair.public.as_ref()].concat()), + ), + ..Default::default() + }; + + let mut payload = Vec::with_capacity(noise_payload.encoded_len()); + noise_payload.encode(&mut payload).expect("Vec to provide needed capacity"); + + Self { + noise: NoiseState::Handshake(noise), + keypair, + payload, + role, + } + } + + // fn new(role: Role) -> Self { + pub fn new(keypair: &Keypair, role: Role) -> Self { + tracing::trace!(target: LOG_TARGET, ?role, "create new noise configuration"); + + // let builder: Builder<'_> = + // Builder::new(NOISE_PARAMETERS.parse().expect("valid Noise pattern")); + let builder: Builder<'_> = Builder::with_resolver( + NOISE_PARAMETERS.parse().expect("valid Noise pattern"), + Box::new(protocol::Resolver), + ); + + let dh_keypair = builder.generate_keypair().expect("keypair generation to succeed"); + let static_key = &dh_keypair.private; + + let noise = match role { + Role::Dialer => builder + .local_private_key(static_key) + .build_initiator() + .expect("initialization to succeed"), + Role::Listener => builder + .local_private_key(static_key) + .build_responder() + .expect("initialization to succeed"), + }; + + Self::assemble(noise, dh_keypair, keypair, role) + } + + /// Create new [`NoiseContext`] with prologue. + pub fn with_prologue(id_keys: &Keypair, prologue: Vec) -> Self { + let noise: Builder<'_> = Builder::with_resolver( + NOISE_PARAMETERS.parse().expect("valid Noise pattern"), + Box::new(protocol::Resolver), + ); + + // let noise = snow::Builder::new(NOISE_PARAMETERS.parse().expect("valid Noise patterns")); + let keypair = noise.generate_keypair().unwrap(); + + let noise = noise + .local_private_key(&keypair.private) + .prologue(&prologue) + .build_initiator() + .expect("to succeed"); + + Self::assemble(noise, keypair, id_keys, Role::Dialer) + } + + /// Get remote public key from the received Noise payload. + // TODO: refactor + pub fn get_remote_public_key(&mut self, reply: &Vec) -> crate::Result { + if reply.len() <= 2 { + return Err(error::Error::InvalidData); + } + + // TODO: no unwraps + let size: Result<[u8; 2], _> = reply[0..2].try_into(); + let _size = u16::from_be_bytes(size.unwrap()); + + // TODO: buffer size + let mut inner = vec![0u8; 1024]; + + let NoiseState::Handshake(ref mut noise) = self.noise else { + panic!("invalid state to read the second handshake message"); + }; + + let res = noise.read_message(&reply[2..], &mut inner)?; + inner.truncate(res); + + let payload = handshake_schema::NoiseHandshakePayload::decode(inner.as_slice())?; + + Ok(PublicKey::from_protobuf_encoding( + &payload.identity_key.ok_or(error::Error::NegotiationError( + error::NegotiationError::PeerIdMissing, + ))?, + )?) + } + + /// Get first message. + /// + /// Listener only sends one message (the payload) + pub fn first_message(&mut self, role: Role) -> Vec { + match role { + Role::Dialer => { + tracing::trace!(target: LOG_TARGET, "get noise dialer first message"); + + let NoiseState::Handshake(ref mut noise) = self.noise else { + panic!("invalid state to read the second handshake message"); + }; + + let mut buffer = vec![0u8; 256]; + let nwritten = noise.write_message(&[], &mut buffer).expect("to succeed"); + buffer.truncate(nwritten); + + let size = nwritten as u16; + let mut size = size.to_be_bytes().to_vec(); + size.append(&mut buffer); + + size + } + Role::Listener => self.second_message(), + } + } + + /// Get second message. + /// + /// Only the dialer sends the second message. + pub fn second_message(&mut self) -> Vec { + tracing::trace!(target: LOG_TARGET, "get noise paylod message"); + + let NoiseState::Handshake(ref mut noise) = self.noise else { + panic!("invalid state to read the second handshake message"); + }; + + let mut buffer = vec![0u8; 2048]; + let nwritten = noise.write_message(&self.payload, &mut buffer).expect("to succeed"); + buffer.truncate(nwritten); + + let size = nwritten as u16; + let mut size = size.to_be_bytes().to_vec(); + size.append(&mut buffer); + + size + } + + /// Read handshake message. + async fn read_handshake_message( + &mut self, + io: &mut T, + ) -> crate::Result { + let mut size = BytesMut::zeroed(2); + io.read_exact(&mut size).await?; + let size = size.get_u16(); + + let mut message = BytesMut::zeroed(size as usize); + io.read_exact(&mut message).await?; + + let mut out = BytesMut::new(); + out.resize(message.len() + 200, 0u8); // TODO: correct overhead + + let NoiseState::Handshake(ref mut noise) = self.noise else { + panic!("invalid state to read handshake message"); + }; + + let nread = noise.read_message(&message, &mut out)?; + out.truncate(nread); + + Ok(out.freeze()) + } + + fn read_message(&mut self, message: &[u8], out: &mut [u8]) -> Result { + match self.noise { + NoiseState::Handshake(ref mut noise) => noise.read_message(message, out), + NoiseState::Transport(ref mut noise) => noise.read_message(message, out), + } + } + + fn write_message(&mut self, message: &[u8], out: &mut [u8]) -> Result { + match self.noise { + NoiseState::Handshake(ref mut noise) => noise.write_message(message, out), + NoiseState::Transport(ref mut noise) => noise.write_message(message, out), + } + } + + /// Convert Noise into transport mode. + fn into_transport(self) -> NoiseContext { + let transport = match self.noise { + NoiseState::Handshake(noise) => noise.into_transport_mode().unwrap(), + NoiseState::Transport(_) => panic!("invalid state"), + }; + + NoiseContext { + keypair: self.keypair, + payload: self.payload, + role: self.role, + noise: NoiseState::Transport(transport), + } + } } enum ReadState { - ReadData { max_read: usize }, - ReadFrameLen, - ProcessNextFrame { pending: Option>, offset: usize, size: usize, frame_size: usize }, + ReadData { + max_read: usize, + }, + ReadFrameLen, + ProcessNextFrame { + pending: Option>, + offset: usize, + size: usize, + frame_size: usize, + }, } enum WriteState { - Ready { offset: usize, size: usize, encrypted_size: usize }, - WriteFrame { offset: usize, size: usize, encrypted_size: usize }, + Ready { + offset: usize, + size: usize, + encrypted_size: usize, + }, + WriteFrame { + offset: usize, + size: usize, + encrypted_size: usize, + }, } pub struct NoiseSocket { - io: S, - noise: NoiseContext, - current_frame_size: Option, - write_state: WriteState, - encrypt_buffer: Vec, - offset: usize, - nread: usize, - read_state: ReadState, - read_buffer: Vec, - canonical_max_read: usize, - decrypt_buffer: Option>, + io: S, + noise: NoiseContext, + current_frame_size: Option, + write_state: WriteState, + encrypt_buffer: Vec, + offset: usize, + nread: usize, + read_state: ReadState, + read_buffer: Vec, + canonical_max_read: usize, + decrypt_buffer: Option>, } impl NoiseSocket { - fn new( - io: S, - noise: NoiseContext, - max_read_ahead_factor: usize, - max_write_buffer_size: usize, - ) -> Self { - Self { - io, - noise, - read_buffer: vec![ - 0u8; - max_read_ahead_factor * MAX_NOISE_MSG_LEN + (2 + MAX_NOISE_MSG_LEN) - ], - nread: 0usize, - offset: 0usize, - current_frame_size: None, - write_state: WriteState::Ready { offset: 0usize, size: 0usize, encrypted_size: 0usize }, - encrypt_buffer: vec![0u8; max_write_buffer_size * (MAX_NOISE_MSG_LEN + 2)], - decrypt_buffer: Some(vec![0u8; MAX_FRAME_LEN]), - read_state: ReadState::ReadData { max_read: max_read_ahead_factor * MAX_NOISE_MSG_LEN }, - canonical_max_read: max_read_ahead_factor * MAX_NOISE_MSG_LEN, - } - } - - fn reset_read_state(&mut self, remaining: usize) { - match remaining { - 0 => { - self.nread = 0; - }, - 1 => { - self.read_buffer[0] = self.read_buffer[self.nread - 1]; - self.nread = 1; - }, - _ => panic!("invalid state"), - } - - self.offset = 0; - self.read_state = ReadState::ReadData { max_read: self.canonical_max_read }; - } + fn new( + io: S, + noise: NoiseContext, + max_read_ahead_factor: usize, + max_write_buffer_size: usize, + ) -> Self { + Self { + io, + noise, + read_buffer: vec![ + 0u8; + max_read_ahead_factor * MAX_NOISE_MSG_LEN + (2 + MAX_NOISE_MSG_LEN) + ], + nread: 0usize, + offset: 0usize, + current_frame_size: None, + write_state: WriteState::Ready { + offset: 0usize, + size: 0usize, + encrypted_size: 0usize, + }, + encrypt_buffer: vec![0u8; max_write_buffer_size * (MAX_NOISE_MSG_LEN + 2)], + decrypt_buffer: Some(vec![0u8; MAX_FRAME_LEN]), + read_state: ReadState::ReadData { + max_read: max_read_ahead_factor * MAX_NOISE_MSG_LEN, + }, + canonical_max_read: max_read_ahead_factor * MAX_NOISE_MSG_LEN, + } + } + + fn reset_read_state(&mut self, remaining: usize) { + match remaining { + 0 => { + self.nread = 0; + } + 1 => { + self.read_buffer[0] = self.read_buffer[self.nread - 1]; + self.nread = 1; + } + _ => panic!("invalid state"), + } + + self.offset = 0; + self.read_state = ReadState::ReadData { + max_read: self.canonical_max_read, + }; + } } impl AsyncRead for NoiseSocket { - fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { - let this = Pin::into_inner(self); - - loop { - match this.read_state { - ReadState::ReadData { max_read } => { - let nread = match Pin::new(&mut this.io) - .poll_read(cx, &mut this.read_buffer[this.nread..max_read]) - { - Poll::Pending => return Poll::Pending, - Poll::Ready(Err(error)) => return Poll::Ready(Err(error)), - Poll::Ready(Ok(nread)) => match nread == 0 { - true => return Poll::Ready(Err(io::ErrorKind::UnexpectedEof.into())), - false => nread, - }, - }; - - tracing::trace!(target: LOG_TARGET, ?nread, "read data from socket"); - - this.nread += nread; - this.read_state = ReadState::ReadFrameLen; - }, - ReadState::ReadFrameLen => { - let mut remaining = match this.nread.checked_sub(this.offset) { - Some(remaining) => remaining, - None => { - tracing::error!(target: LOG_TARGET, "offset is larger than the number of bytes read"); - return Poll::Ready(Err(io::ErrorKind::PermissionDenied.into())); - }, - }; - - if remaining < 2 { - tracing::trace!(target: LOG_TARGET, "reset read buffer"); - this.reset_read_state(remaining); - continue; - } - - // get frame size, either from current or previous iteration - let frame_size = match this.current_frame_size.take() { - Some(frame_size) => frame_size, - None => { - let frame_size = (this.read_buffer[this.offset] as u16) << 8 | - this.read_buffer[this.offset + 1] as u16; - this.offset += 2; - remaining -= 2; - frame_size as usize - }, - }; - - tracing::trace!(target: LOG_TARGET, "current frame size = {frame_size}"); - - if remaining < frame_size { - // `read_buffer` can fit the full frame size. - if this.nread + frame_size < this.canonical_max_read { - tracing::trace!( - target: LOG_TARGET, - max_size = ?this.canonical_max_read, - next_frame_size = ?(this.nread + frame_size), - "read buffer can fit the full frame", - ); - - this.current_frame_size = Some(frame_size); - this.read_state = - ReadState::ReadData { max_read: this.canonical_max_read }; - continue; - } - - tracing::trace!(target: LOG_TARGET, "use auxiliary buffer extension"); - - // use the auxiliary memory at the end of the read buffer for reading the - // frame - this.current_frame_size = Some(frame_size); - this.read_state = - ReadState::ReadData { max_read: this.nread + frame_size - remaining }; - continue; - } - - if frame_size <= NOISE_EXTRA_ENCRYPT_SPACE { - tracing::error!( - target: LOG_TARGET, - ?frame_size, - max_size = ?NOISE_EXTRA_ENCRYPT_SPACE, - "invalid frame size", - ); - println!("invalid frame size"); - return Poll::Ready(Err(io::ErrorKind::InvalidData.into())); - } - - this.current_frame_size = Some(frame_size); - this.read_state = ReadState::ProcessNextFrame { - pending: None, - offset: 0usize, - size: 0usize, - frame_size: 0usize, - }; - }, - ReadState::ProcessNextFrame { ref mut pending, offset, size, frame_size } => - match pending.take() { - Some(pending) => match buf.len() >= pending[offset..size].len() { - true => { - let copy_size = pending[offset..size].len(); - buf[..copy_size] - .copy_from_slice(&pending[offset..copy_size + offset]); - - this.read_state = ReadState::ReadFrameLen; - this.decrypt_buffer = Some(pending); - this.offset += frame_size; - return Poll::Ready(Ok(copy_size)); - }, - false => { - buf.copy_from_slice(&pending[offset..buf.len() + offset]); - - this.read_state = ReadState::ProcessNextFrame { - pending: Some(pending), - offset: offset + buf.len(), - size, - frame_size, - }; - return Poll::Ready(Ok(buf.len())); - }, - }, - None => { - let frame_size = - this.current_frame_size.take().expect("`frame_size` to exist"); - - match buf.len() >= frame_size - NOISE_EXTRA_ENCRYPT_SPACE { - true => match this.noise.read_message( - &this.read_buffer[this.offset..this.offset + frame_size], - buf, - ) { - Err(error) => { - tracing::error!(target: LOG_TARGET, ?error, "failed to decrypt message"); - return Poll::Ready(Err(io::ErrorKind::InvalidData.into())); - }, - Ok(nread) => { - this.offset += frame_size; - this.read_state = ReadState::ReadFrameLen; - return Poll::Ready(Ok(nread)); - }, - }, - false => { - let mut buffer = - this.decrypt_buffer.take().expect("buffer to exist"); - - match this.noise.read_message( - &this.read_buffer[this.offset..this.offset + frame_size], - &mut buffer, - ) { - Err(error) => { - tracing::error!(target: LOG_TARGET, ?error, "failed to decrypt message"); - return Poll::Ready(Err( - io::ErrorKind::InvalidData.into() - )); - }, - Ok(nread) => { - buf.copy_from_slice(&buffer[..buf.len()]); - this.read_state = ReadState::ProcessNextFrame { - pending: Some(buffer), - offset: buf.len(), - size: nread, - frame_size, - }; - return Poll::Ready(Ok(buf.len())); - }, - } - }, - } - }, - }, - } - } - } + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + let this = Pin::into_inner(self); + + loop { + match this.read_state { + ReadState::ReadData { max_read } => { + let nread = match Pin::new(&mut this.io) + .poll_read(cx, &mut this.read_buffer[this.nread..max_read]) + { + Poll::Pending => return Poll::Pending, + Poll::Ready(Err(error)) => return Poll::Ready(Err(error)), + Poll::Ready(Ok(nread)) => match nread == 0 { + true => return Poll::Ready(Err(io::ErrorKind::UnexpectedEof.into())), + false => nread, + }, + }; + + tracing::trace!(target: LOG_TARGET, ?nread, "read data from socket"); + + this.nread += nread; + this.read_state = ReadState::ReadFrameLen; + } + ReadState::ReadFrameLen => { + let mut remaining = match this.nread.checked_sub(this.offset) { + Some(remaining) => remaining, + None => { + tracing::error!(target: LOG_TARGET, "offset is larger than the number of bytes read"); + return Poll::Ready(Err(io::ErrorKind::PermissionDenied.into())); + } + }; + + if remaining < 2 { + tracing::trace!(target: LOG_TARGET, "reset read buffer"); + this.reset_read_state(remaining); + continue; + } + + // get frame size, either from current or previous iteration + let frame_size = match this.current_frame_size.take() { + Some(frame_size) => frame_size, + None => { + let frame_size = (this.read_buffer[this.offset] as u16) << 8 + | this.read_buffer[this.offset + 1] as u16; + this.offset += 2; + remaining -= 2; + frame_size as usize + } + }; + + tracing::trace!(target: LOG_TARGET, "current frame size = {frame_size}"); + + if remaining < frame_size { + // `read_buffer` can fit the full frame size. + if this.nread + frame_size < this.canonical_max_read { + tracing::trace!( + target: LOG_TARGET, + max_size = ?this.canonical_max_read, + next_frame_size = ?(this.nread + frame_size), + "read buffer can fit the full frame", + ); + + this.current_frame_size = Some(frame_size); + this.read_state = ReadState::ReadData { + max_read: this.canonical_max_read, + }; + continue; + } + + tracing::trace!(target: LOG_TARGET, "use auxiliary buffer extension"); + + // use the auxiliary memory at the end of the read buffer for reading the + // frame + this.current_frame_size = Some(frame_size); + this.read_state = ReadState::ReadData { + max_read: this.nread + frame_size - remaining, + }; + continue; + } + + if frame_size <= NOISE_EXTRA_ENCRYPT_SPACE { + tracing::error!( + target: LOG_TARGET, + ?frame_size, + max_size = ?NOISE_EXTRA_ENCRYPT_SPACE, + "invalid frame size", + ); + println!("invalid frame size"); + return Poll::Ready(Err(io::ErrorKind::InvalidData.into())); + } + + this.current_frame_size = Some(frame_size); + this.read_state = ReadState::ProcessNextFrame { + pending: None, + offset: 0usize, + size: 0usize, + frame_size: 0usize, + }; + } + ReadState::ProcessNextFrame { + ref mut pending, + offset, + size, + frame_size, + } => match pending.take() { + Some(pending) => match buf.len() >= pending[offset..size].len() { + true => { + let copy_size = pending[offset..size].len(); + buf[..copy_size].copy_from_slice(&pending[offset..copy_size + offset]); + + this.read_state = ReadState::ReadFrameLen; + this.decrypt_buffer = Some(pending); + this.offset += frame_size; + return Poll::Ready(Ok(copy_size)); + } + false => { + buf.copy_from_slice(&pending[offset..buf.len() + offset]); + + this.read_state = ReadState::ProcessNextFrame { + pending: Some(pending), + offset: offset + buf.len(), + size, + frame_size, + }; + return Poll::Ready(Ok(buf.len())); + } + }, + None => { + let frame_size = + this.current_frame_size.take().expect("`frame_size` to exist"); + + match buf.len() >= frame_size - NOISE_EXTRA_ENCRYPT_SPACE { + true => match this.noise.read_message( + &this.read_buffer[this.offset..this.offset + frame_size], + buf, + ) { + Err(error) => { + tracing::error!(target: LOG_TARGET, ?error, "failed to decrypt message"); + return Poll::Ready(Err(io::ErrorKind::InvalidData.into())); + } + Ok(nread) => { + this.offset += frame_size; + this.read_state = ReadState::ReadFrameLen; + return Poll::Ready(Ok(nread)); + } + }, + false => { + let mut buffer = + this.decrypt_buffer.take().expect("buffer to exist"); + + match this.noise.read_message( + &this.read_buffer[this.offset..this.offset + frame_size], + &mut buffer, + ) { + Err(error) => { + tracing::error!(target: LOG_TARGET, ?error, "failed to decrypt message"); + return Poll::Ready(Err(io::ErrorKind::InvalidData.into())); + } + Ok(nread) => { + buf.copy_from_slice(&buffer[..buf.len()]); + this.read_state = ReadState::ProcessNextFrame { + pending: Some(buffer), + offset: buf.len(), + size: nread, + frame_size, + }; + return Poll::Ready(Ok(buf.len())); + } + } + } + } + } + }, + } + } + } } impl AsyncWrite for NoiseSocket { - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - let this = Pin::into_inner(self); - let mut chunks = buf.chunks(MAX_FRAME_LEN).peekable(); - - loop { - match this.write_state { - WriteState::Ready { offset, size, encrypted_size } => { - let Some(chunk) = chunks.next() else { - println!("no chunk"); - break; - }; - - match this.noise.write_message(chunk, &mut this.encrypt_buffer[offset + 2..]) { - Err(error) => { - tracing::error!(target: LOG_TARGET, ?error, "failed to encrypt message"); - return Poll::Ready(Err(io::ErrorKind::InvalidData.into())); - }, - Ok(nwritten) => { - this.encrypt_buffer[offset + 0] = (nwritten >> 8) as u8; - this.encrypt_buffer[offset + 1] = (nwritten & 0xff) as u8; - - if let Some(next_chunk) = chunks.peek() { - if next_chunk.len() + NOISE_EXTRA_ENCRYPT_SPACE + 2 <= - this.encrypt_buffer[offset + nwritten + 2..].len() - { - this.write_state = WriteState::Ready { - offset: offset + nwritten + 2, - size: size + chunk.len(), - encrypted_size: encrypted_size + nwritten + 2, - }; - continue; - } - } - - this.write_state = WriteState::WriteFrame { - offset: 0usize, - size: size + chunk.len(), - encrypted_size: encrypted_size + nwritten + 2, - }; - }, - } - }, - WriteState::WriteFrame { ref mut offset, size, encrypted_size } => loop { - match futures::ready!(Pin::new(&mut this.io) - .poll_write(cx, &this.encrypt_buffer[*offset..encrypted_size])) - { - Ok(nwritten) => { - *offset += nwritten; - - if offset == &encrypted_size { - this.write_state = WriteState::Ready { - offset: 0usize, - size: 0usize, - encrypted_size: 0usize, - }; - return Poll::Ready(Ok(size)); - } - }, - Err(error) => return Poll::Ready(Err(error)), - } - }, - } - } - - Poll::Ready(Ok(0)) - } - - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.io).poll_flush(cx) - } - - fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.io).poll_close(cx) - } + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + let this = Pin::into_inner(self); + let mut chunks = buf.chunks(MAX_FRAME_LEN).peekable(); + + loop { + match this.write_state { + WriteState::Ready { + offset, + size, + encrypted_size, + } => { + let Some(chunk) = chunks.next() else { + println!("no chunk"); + break; + }; + + match this.noise.write_message(chunk, &mut this.encrypt_buffer[offset + 2..]) { + Err(error) => { + tracing::error!(target: LOG_TARGET, ?error, "failed to encrypt message"); + return Poll::Ready(Err(io::ErrorKind::InvalidData.into())); + } + Ok(nwritten) => { + this.encrypt_buffer[offset + 0] = (nwritten >> 8) as u8; + this.encrypt_buffer[offset + 1] = (nwritten & 0xff) as u8; + + if let Some(next_chunk) = chunks.peek() { + if next_chunk.len() + NOISE_EXTRA_ENCRYPT_SPACE + 2 + <= this.encrypt_buffer[offset + nwritten + 2..].len() + { + this.write_state = WriteState::Ready { + offset: offset + nwritten + 2, + size: size + chunk.len(), + encrypted_size: encrypted_size + nwritten + 2, + }; + continue; + } + } + + this.write_state = WriteState::WriteFrame { + offset: 0usize, + size: size + chunk.len(), + encrypted_size: encrypted_size + nwritten + 2, + }; + } + } + } + WriteState::WriteFrame { + ref mut offset, + size, + encrypted_size, + } => loop { + match futures::ready!(Pin::new(&mut this.io) + .poll_write(cx, &this.encrypt_buffer[*offset..encrypted_size])) + { + Ok(nwritten) => { + *offset += nwritten; + + if offset == &encrypted_size { + this.write_state = WriteState::Ready { + offset: 0usize, + size: 0usize, + encrypted_size: 0usize, + }; + return Poll::Ready(Ok(size)); + } + } + Err(error) => return Poll::Ready(Err(error)), + } + }, + } + } + + Poll::Ready(Ok(0)) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.io).poll_flush(cx) + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.io).poll_close(cx) + } } /// Try to parse `PeerId` from received `NoiseHandshakePayload` fn parse_peer_id(buf: &[u8]) -> crate::Result { - match handshake_schema::NoiseHandshakePayload::decode(buf) { - Ok(payload) => { - let public_key = - PublicKey::from_protobuf_encoding(&payload.identity_key.ok_or( - error::Error::NegotiationError(error::NegotiationError::PeerIdMissing), - )?)?; - Ok(PeerId::from_public_key(&public_key)) - }, - Err(err) => Err(From::from(err)), - } + match handshake_schema::NoiseHandshakePayload::decode(buf) { + Ok(payload) => { + let public_key = PublicKey::from_protobuf_encoding(&payload.identity_key.ok_or( + error::Error::NegotiationError(error::NegotiationError::PeerIdMissing), + )?)?; + Ok(PeerId::from_public_key(&public_key)) + } + Err(err) => Err(From::from(err)), + } } /// Perform Noise handshake. pub async fn handshake( - mut io: S, - keypair: &Keypair, - role: Role, - max_read_ahead_factor: usize, - max_write_buffer_size: usize, + mut io: S, + keypair: &Keypair, + role: Role, + max_read_ahead_factor: usize, + max_write_buffer_size: usize, ) -> crate::Result<(NoiseSocket, PeerId)> { - tracing::debug!(target: LOG_TARGET, ?role, "start noise handshake"); - - let mut noise = NoiseContext::new(keypair, role); - let peer = match role { - Role::Dialer => { - // write initial message - let first_message = noise.first_message(Role::Dialer); - let _ = io.write(&first_message).await?; - let _ = io.flush().await?; - - // read back response which contains the remote peer id - let message = noise.read_handshake_message(&mut io).await?; - - // send the final message which contains local peer id - let second_message = noise.second_message(); - let _ = io.write(&second_message).await?; - let _ = io.flush().await?; - - parse_peer_id(&message)? - }, - Role::Listener => { - // read remote's first message - let _ = noise.read_handshake_message(&mut io).await?; - - // send local peer id. - let second_message = noise.second_message(); - let _ = io.write(&second_message).await?; - let _ = io.flush().await?; - - // read remote's second message which contains their peer id - let message = noise.read_handshake_message(&mut io).await?; - parse_peer_id(&message)? - }, - }; - - Ok(( - NoiseSocket::new(io, noise.into_transport(), max_read_ahead_factor, max_write_buffer_size), - peer, - )) + tracing::debug!(target: LOG_TARGET, ?role, "start noise handshake"); + + let mut noise = NoiseContext::new(keypair, role); + let peer = match role { + Role::Dialer => { + // write initial message + let first_message = noise.first_message(Role::Dialer); + let _ = io.write(&first_message).await?; + let _ = io.flush().await?; + + // read back response which contains the remote peer id + let message = noise.read_handshake_message(&mut io).await?; + + // send the final message which contains local peer id + let second_message = noise.second_message(); + let _ = io.write(&second_message).await?; + let _ = io.flush().await?; + + parse_peer_id(&message)? + } + Role::Listener => { + // read remote's first message + let _ = noise.read_handshake_message(&mut io).await?; + + // send local peer id. + let second_message = noise.second_message(); + let _ = io.write(&second_message).await?; + let _ = io.flush().await?; + + // read remote's second message which contains their peer id + let message = noise.read_handshake_message(&mut io).await?; + parse_peer_id(&message)? + } + }; + + Ok(( + NoiseSocket::new( + io, + noise.into_transport(), + max_read_ahead_factor, + max_write_buffer_size, + ), + peer, + )) } // TODO: add more tests #[cfg(test)] mod tests { - use super::*; - use std::net::SocketAddr; - use tokio::net::{TcpListener, TcpStream}; - use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt}; - - #[tokio::test] - async fn noise_handshake() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let keypair1 = Keypair::generate(); - let keypair2 = Keypair::generate(); - - let peer1_id = PeerId::from_public_key(&keypair1.public().into()); - let peer2_id = PeerId::from_public_key(&keypair2.public().into()); - - let listener = TcpListener::bind("[::1]:0".parse::().unwrap()).await.unwrap(); - - let (stream1, stream2) = - tokio::join!(TcpStream::connect(listener.local_addr().unwrap()), listener.accept()); - let (io1, io2) = { - let io1 = TokioAsyncReadCompatExt::compat(stream1.unwrap()).into_inner(); - let io1 = Box::new(TokioAsyncWriteCompatExt::compat_write(io1)); - let io2 = TokioAsyncReadCompatExt::compat(stream2.unwrap().0).into_inner(); - let io2 = Box::new(TokioAsyncWriteCompatExt::compat_write(io2)); - - (io1, io2) - }; - - let (res1, res2) = tokio::join!( - handshake(io1, &keypair1, Role::Dialer, MAX_READ_AHEAD_FACTOR, MAX_WRITE_BUFFER_SIZE), - handshake(io2, &keypair2, Role::Listener, MAX_READ_AHEAD_FACTOR, MAX_WRITE_BUFFER_SIZE) - ); - let (mut res1, mut res2) = (res1.unwrap(), res2.unwrap()); - - assert_eq!(res1.1, peer2_id); - assert_eq!(res2.1, peer1_id); - - // verify the connection works by reading a string - let mut buf = vec![0u8; 512]; - let sent = res1.0.write(b"hello, world").await.unwrap(); - res2.0.read_exact(&mut buf[..sent]).await.unwrap(); - - assert_eq!(std::str::from_utf8(&buf[..sent]), Ok("hello, world")); - } - - #[test] - fn invalid_peer_id_schema() { - match parse_peer_id(&vec![1, 2, 3, 4]).unwrap_err() { - crate::Error::ParseError(_) => {}, - _ => panic!("invalid error"), - } - } + use super::*; + use std::net::SocketAddr; + use tokio::net::{TcpListener, TcpStream}; + use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt}; + + #[tokio::test] + async fn noise_handshake() { + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let keypair1 = Keypair::generate(); + let keypair2 = Keypair::generate(); + + let peer1_id = PeerId::from_public_key(&keypair1.public().into()); + let peer2_id = PeerId::from_public_key(&keypair2.public().into()); + + let listener = TcpListener::bind("[::1]:0".parse::().unwrap()).await.unwrap(); + + let (stream1, stream2) = tokio::join!( + TcpStream::connect(listener.local_addr().unwrap()), + listener.accept() + ); + let (io1, io2) = { + let io1 = TokioAsyncReadCompatExt::compat(stream1.unwrap()).into_inner(); + let io1 = Box::new(TokioAsyncWriteCompatExt::compat_write(io1)); + let io2 = TokioAsyncReadCompatExt::compat(stream2.unwrap().0).into_inner(); + let io2 = Box::new(TokioAsyncWriteCompatExt::compat_write(io2)); + + (io1, io2) + }; + + let (res1, res2) = tokio::join!( + handshake( + io1, + &keypair1, + Role::Dialer, + MAX_READ_AHEAD_FACTOR, + MAX_WRITE_BUFFER_SIZE + ), + handshake( + io2, + &keypair2, + Role::Listener, + MAX_READ_AHEAD_FACTOR, + MAX_WRITE_BUFFER_SIZE + ) + ); + let (mut res1, mut res2) = (res1.unwrap(), res2.unwrap()); + + assert_eq!(res1.1, peer2_id); + assert_eq!(res2.1, peer1_id); + + // verify the connection works by reading a string + let mut buf = vec![0u8; 512]; + let sent = res1.0.write(b"hello, world").await.unwrap(); + res2.0.read_exact(&mut buf[..sent]).await.unwrap(); + + assert_eq!(std::str::from_utf8(&buf[..sent]), Ok("hello, world")); + } + + #[test] + fn invalid_peer_id_schema() { + match parse_peer_id(&vec![1, 2, 3, 4]).unwrap_err() { + crate::Error::ParseError(_) => {} + _ => panic!("invalid error"), + } + } } diff --git a/src/crypto/noise/protocol.rs b/src/crypto/noise/protocol.rs index 61ef21d8..38924880 100644 --- a/src/crypto/noise/protocol.rs +++ b/src/crypto/noise/protocol.rs @@ -26,18 +26,18 @@ use zeroize::Zeroize; /// DH keypair. #[derive(Clone)] pub struct Keypair { - pub secret: SecretKey, - pub public: PublicKey, + pub secret: SecretKey, + pub public: PublicKey, } /// The associated public identity of a DH keypair. #[derive(Clone)] pub struct KeypairIdentity { - /// The public identity key. - pub public: crypto::PublicKey, + /// The public identity key. + pub public: crypto::PublicKey, - /// The signature over the public DH key. - pub signature: Option>, + /// The signature over the public DH key. + pub signature: Option>, } /// DH secret key. @@ -45,15 +45,15 @@ pub struct KeypairIdentity { pub struct SecretKey(pub T); impl Drop for SecretKey { - fn drop(&mut self) { - self.0.zeroize() - } + fn drop(&mut self) { + self.0.zeroize() + } } impl + Zeroize> AsRef<[u8]> for SecretKey { - fn as_ref(&self) -> &[u8] { - self.0.as_ref() - } + fn as_ref(&self) -> &[u8] { + self.0.as_ref() + } } /// DH public key. @@ -61,17 +61,17 @@ impl + Zeroize> AsRef<[u8]> for SecretKey { pub struct PublicKey(pub T); impl> PartialEq for PublicKey { - fn eq(&self, other: &PublicKey) -> bool { - self.as_ref() == other.as_ref() - } + fn eq(&self, other: &PublicKey) -> bool { + self.as_ref() == other.as_ref() + } } impl> Eq for PublicKey {} impl> AsRef<[u8]> for PublicKey { - fn as_ref(&self) -> &[u8] { - self.0.as_ref() - } + fn as_ref(&self) -> &[u8] { + self.0.as_ref() + } } /// Custom `snow::CryptoResolver` which delegates to either the @@ -81,48 +81,48 @@ impl> AsRef<[u8]> for PublicKey { pub struct Resolver; impl snow::resolvers::CryptoResolver for Resolver { - fn resolve_rng(&self) -> Option> { - Some(Box::new(Rng(rand::rngs::StdRng::from_entropy()))) - } - - fn resolve_dh(&self, _: &snow::params::DHChoice) -> Option> { - Some(Box::new(Keypair::::default())) - } - - fn resolve_hash( - &self, - choice: &snow::params::HashChoice, - ) -> Option> { - snow::resolvers::RingResolver.resolve_hash(choice) - } - - fn resolve_cipher( - &self, - choice: &snow::params::CipherChoice, - ) -> Option> { - snow::resolvers::RingResolver.resolve_cipher(choice) - } + fn resolve_rng(&self) -> Option> { + Some(Box::new(Rng(rand::rngs::StdRng::from_entropy()))) + } + + fn resolve_dh(&self, _: &snow::params::DHChoice) -> Option> { + Some(Box::new(Keypair::::default())) + } + + fn resolve_hash( + &self, + choice: &snow::params::HashChoice, + ) -> Option> { + snow::resolvers::RingResolver.resolve_hash(choice) + } + + fn resolve_cipher( + &self, + choice: &snow::params::CipherChoice, + ) -> Option> { + snow::resolvers::RingResolver.resolve_cipher(choice) + } } /// Wrapper around a CSPRNG to implement `snow::Random` trait for. struct Rng(rand::rngs::StdRng); impl rand::RngCore for Rng { - fn next_u32(&mut self) -> u32 { - self.0.next_u32() - } + fn next_u32(&mut self) -> u32 { + self.0.next_u32() + } - fn next_u64(&mut self) -> u64 { - self.0.next_u64() - } + fn next_u64(&mut self) -> u64 { + self.0.next_u64() + } - fn fill_bytes(&mut self, dest: &mut [u8]) { - self.0.fill_bytes(dest) - } + fn fill_bytes(&mut self, dest: &mut [u8]) { + self.0.fill_bytes(dest) + } - fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand::Error> { - self.0.try_fill_bytes(dest) - } + fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand::Error> { + self.0.try_fill_bytes(dest) + } } impl rand::CryptoRng for Rng {} diff --git a/src/crypto/noise/x25519_spec.rs b/src/crypto/noise/x25519_spec.rs index 15b96ac5..2c87864d 100644 --- a/src/crypto/noise/x25519_spec.rs +++ b/src/crypto/noise/x25519_spec.rs @@ -29,89 +29,89 @@ use crate::crypto::noise::protocol::{Keypair, PublicKey, SecretKey}; pub struct X25519Spec([u8; 32]); impl AsRef<[u8]> for X25519Spec { - fn as_ref(&self) -> &[u8] { - self.0.as_ref() - } + fn as_ref(&self) -> &[u8] { + self.0.as_ref() + } } impl Zeroize for X25519Spec { - fn zeroize(&mut self) { - self.0.zeroize() - } + fn zeroize(&mut self) { + self.0.zeroize() + } } impl Keypair { - /// An "empty" keypair as a starting state for DH computations in `snow`, - /// which get manipulated through the `snow::types::Dh` interface. - pub(super) fn default() -> Self { - Keypair { - secret: SecretKey(X25519Spec([0u8; 32])), - public: PublicKey(X25519Spec([0u8; 32])), - } - } + /// An "empty" keypair as a starting state for DH computations in `snow`, + /// which get manipulated through the `snow::types::Dh` interface. + pub(super) fn default() -> Self { + Keypair { + secret: SecretKey(X25519Spec([0u8; 32])), + public: PublicKey(X25519Spec([0u8; 32])), + } + } - /// Create a new X25519 keypair. - pub fn new() -> Keypair { - let mut sk_bytes = [0u8; 32]; - rand::thread_rng().fill(&mut sk_bytes); - let sk = SecretKey(X25519Spec(sk_bytes)); // Copy - sk_bytes.zeroize(); - Self::from(sk) - } + /// Create a new X25519 keypair. + pub fn new() -> Keypair { + let mut sk_bytes = [0u8; 32]; + rand::thread_rng().fill(&mut sk_bytes); + let sk = SecretKey(X25519Spec(sk_bytes)); // Copy + sk_bytes.zeroize(); + Self::from(sk) + } } impl Default for Keypair { - fn default() -> Self { - Self::new() - } + fn default() -> Self { + Self::new() + } } /// Promote a X25519 secret key into a keypair. impl From> for Keypair { - fn from(secret: SecretKey) -> Keypair { - let public = PublicKey(X25519Spec(x25519((secret.0).0, X25519_BASEPOINT_BYTES))); - Keypair { secret, public } - } + fn from(secret: SecretKey) -> Keypair { + let public = PublicKey(X25519Spec(x25519((secret.0).0, X25519_BASEPOINT_BYTES))); + Keypair { secret, public } + } } impl snow::types::Dh for Keypair { - fn name(&self) -> &'static str { - "25519" - } - fn pub_len(&self) -> usize { - 32 - } - fn priv_len(&self) -> usize { - 32 - } - fn pubkey(&self) -> &[u8] { - self.public.as_ref() - } - fn privkey(&self) -> &[u8] { - self.secret.as_ref() - } + fn name(&self) -> &'static str { + "25519" + } + fn pub_len(&self) -> usize { + 32 + } + fn priv_len(&self) -> usize { + 32 + } + fn pubkey(&self) -> &[u8] { + self.public.as_ref() + } + fn privkey(&self) -> &[u8] { + self.secret.as_ref() + } - fn set(&mut self, sk: &[u8]) { - let mut secret = [0u8; 32]; - secret.copy_from_slice(sk); - self.secret = SecretKey(X25519Spec(secret)); - self.public = PublicKey(X25519Spec(x25519(secret, X25519_BASEPOINT_BYTES))); - secret.zeroize(); - } + fn set(&mut self, sk: &[u8]) { + let mut secret = [0u8; 32]; + secret.copy_from_slice(sk); + self.secret = SecretKey(X25519Spec(secret)); + self.public = PublicKey(X25519Spec(x25519(secret, X25519_BASEPOINT_BYTES))); + secret.zeroize(); + } - fn generate(&mut self, rng: &mut dyn snow::types::Random) { - let mut secret = [0u8; 32]; - rng.fill_bytes(&mut secret); - self.secret = SecretKey(X25519Spec(secret)); - self.public = PublicKey(X25519Spec(x25519(secret, X25519_BASEPOINT_BYTES))); - secret.zeroize(); - } + fn generate(&mut self, rng: &mut dyn snow::types::Random) { + let mut secret = [0u8; 32]; + rng.fill_bytes(&mut secret); + self.secret = SecretKey(X25519Spec(secret)); + self.public = PublicKey(X25519Spec(x25519(secret, X25519_BASEPOINT_BYTES))); + secret.zeroize(); + } - fn dh(&self, pk: &[u8], shared_secret: &mut [u8]) -> Result<(), snow::Error> { - let mut p = [0; 32]; - p.copy_from_slice(&pk[..32]); - let ss = x25519((self.secret.0).0, p); - shared_secret[..32].copy_from_slice(&ss[..]); - Ok(()) - } + fn dh(&self, pk: &[u8], shared_secret: &mut [u8]) -> Result<(), snow::Error> { + let mut p = [0; 32]; + p.copy_from_slice(&pk[..32]); + let ss = x25519((self.secret.0).0, p); + shared_secret[..32].copy_from_slice(&ss[..]); + Ok(()) + } } diff --git a/src/crypto/tls/certificate.rs b/src/crypto/tls/certificate.rs index 53ceb015..9f5929d7 100644 --- a/src/crypto/tls/certificate.rs +++ b/src/crypto/tls/certificate.rs @@ -23,8 +23,8 @@ //! This module handles generation, signing, and verification of certificates. use crate::{ - crypto::{ed25519::Keypair, PublicKey}, - PeerId, + crypto::{ed25519::Keypair, PublicKey}, + PeerId, }; // use libp2p_identity as identity; @@ -50,30 +50,31 @@ static P2P_SIGNATURE_ALGORITHM: &rcgen::SignatureAlgorithm = &rcgen::PKCS_ECDSA_ /// Generates a self-signed TLS certificate that includes a libp2p-specific /// certificate extension containing the public key of the given keypair. pub fn generate( - identity_keypair: &Keypair, + identity_keypair: &Keypair, ) -> Result<(rustls::Certificate, rustls::PrivateKey), GenError> { - // Keypair used to sign the certificate. - // SHOULD NOT be related to the host's key. - // Endpoints MAY generate a new key and certificate - // for every connection attempt, or they MAY reuse the same key - // and certificate for multiple connections. - let certificate_keypair = rcgen::KeyPair::generate(P2P_SIGNATURE_ALGORITHM)?; - let rustls_key = rustls::PrivateKey(certificate_keypair.serialize_der()); - - let certificate = { - let mut params = rcgen::CertificateParams::new(vec![]); - params.distinguished_name = rcgen::DistinguishedName::new(); - params - .custom_extensions - .push(make_libp2p_extension(identity_keypair, &certificate_keypair)?); - params.alg = P2P_SIGNATURE_ALGORITHM; - params.key_pair = Some(certificate_keypair); - rcgen::Certificate::from_params(params)? - }; - - let rustls_certificate = rustls::Certificate(certificate.serialize_der()?); - - Ok((rustls_certificate, rustls_key)) + // Keypair used to sign the certificate. + // SHOULD NOT be related to the host's key. + // Endpoints MAY generate a new key and certificate + // for every connection attempt, or they MAY reuse the same key + // and certificate for multiple connections. + let certificate_keypair = rcgen::KeyPair::generate(P2P_SIGNATURE_ALGORITHM)?; + let rustls_key = rustls::PrivateKey(certificate_keypair.serialize_der()); + + let certificate = { + let mut params = rcgen::CertificateParams::new(vec![]); + params.distinguished_name = rcgen::DistinguishedName::new(); + params.custom_extensions.push(make_libp2p_extension( + identity_keypair, + &certificate_keypair, + )?); + params.alg = P2P_SIGNATURE_ALGORITHM; + params.key_pair = Some(certificate_keypair); + rcgen::Certificate::from_params(params)? + }; + + let rustls_certificate = rustls::Certificate(certificate.serialize_der()?); + + Ok((rustls_certificate, rustls_key)) } /// Attempts to parse the provided bytes as a [`P2pCertificate`]. @@ -81,30 +82,30 @@ pub fn generate( /// For this to succeed, the certificate must contain the specified extension and the signature must /// match the embedded public key. pub fn parse(certificate: &rustls::Certificate) -> Result, ParseError> { - let certificate = parse_unverified(certificate.as_ref())?; + let certificate = parse_unverified(certificate.as_ref())?; - certificate.verify()?; + certificate.verify()?; - Ok(certificate) + Ok(certificate) } /// An X.509 certificate with a libp2p-specific extension /// is used to secure libp2p connections. pub struct P2pCertificate<'a> { - certificate: X509Certificate<'a>, - /// This is a specific libp2p Public Key Extension with two values: - /// * the public host key - /// * a signature performed using the private host key - extension: P2pExtension, + certificate: X509Certificate<'a>, + /// This is a specific libp2p Public Key Extension with two values: + /// * the public host key + /// * a signature performed using the private host key + extension: P2pExtension, } /// The contents of the specific libp2p extension, containing the public host key /// and a signature performed using the private host key. pub struct P2pExtension { - public_key: PublicKey, - /// This signature provides cryptographic proof that the peer was - /// in possession of the private host key at the time the certificate was signed. - signature: Vec, + public_key: PublicKey, + /// This signature provides cryptographic proof that the peer was + /// in possession of the private host key at the time the certificate was signed. + signature: Vec, } #[derive(Debug, thiserror::Error)] @@ -123,400 +124,409 @@ pub struct VerificationError(#[from] pub(crate) webpki::Error); /// /// Useful for testing but unsuitable for production. fn parse_unverified(der_input: &[u8]) -> Result { - let x509 = X509Certificate::from_der(der_input) - .map(|(_rest_input, x509)| x509) - .map_err(|_| webpki::Error::BadDer)?; - - let p2p_ext_oid = der_parser::oid::Oid::from(&P2P_EXT_OID) - .expect("This is a valid OID of p2p extension; qed"); - - let mut libp2p_extension = None; - - for ext in x509.extensions() { - let oid = &ext.oid; - if oid == &p2p_ext_oid && libp2p_extension.is_some() { - // The extension was already parsed - return Err(webpki::Error::BadDer); - } - - if oid == &p2p_ext_oid { - // The public host key and the signature are ANS.1-encoded - // into the SignedKey data structure, which is carried - // in the libp2p Public Key Extension. - // SignedKey ::= SEQUENCE { - // publicKey OCTET STRING, - // signature OCTET STRING - // } - let (public_key, signature): (Vec, Vec) = - yasna::decode_der(ext.value).map_err(|_| webpki::Error::ExtensionValueInvalid)?; - // The publicKey field of SignedKey contains the public host key - // of the endpoint, encoded using the following protobuf: - // enum KeyType { - // RSA = 0; - // Ed25519 = 1; - // Secp256k1 = 2; - // ECDSA = 3; - // } - // message PublicKey { - // required KeyType Type = 1; - // required bytes Data = 2; - // } - let public_key = PublicKey::from_protobuf_encoding(&public_key) - .map_err(|_| webpki::Error::UnknownIssuer)?; - let ext = P2pExtension { public_key, signature }; - libp2p_extension = Some(ext); - continue; - } - - if ext.critical { - // Endpoints MUST abort the connection attempt if the certificate - // contains critical extensions that the endpoint does not understand. - return Err(webpki::Error::UnsupportedCriticalExtension); - } - - // Implementations MUST ignore non-critical extensions with unknown OIDs. - } - - // The certificate MUST contain the libp2p Public Key Extension. - // If this extension is missing, endpoints MUST abort the connection attempt. - let extension = libp2p_extension.ok_or(webpki::Error::BadDer)?; - - let certificate = P2pCertificate { certificate: x509, extension }; - - Ok(certificate) + let x509 = X509Certificate::from_der(der_input) + .map(|(_rest_input, x509)| x509) + .map_err(|_| webpki::Error::BadDer)?; + + let p2p_ext_oid = der_parser::oid::Oid::from(&P2P_EXT_OID) + .expect("This is a valid OID of p2p extension; qed"); + + let mut libp2p_extension = None; + + for ext in x509.extensions() { + let oid = &ext.oid; + if oid == &p2p_ext_oid && libp2p_extension.is_some() { + // The extension was already parsed + return Err(webpki::Error::BadDer); + } + + if oid == &p2p_ext_oid { + // The public host key and the signature are ANS.1-encoded + // into the SignedKey data structure, which is carried + // in the libp2p Public Key Extension. + // SignedKey ::= SEQUENCE { + // publicKey OCTET STRING, + // signature OCTET STRING + // } + let (public_key, signature): (Vec, Vec) = + yasna::decode_der(ext.value).map_err(|_| webpki::Error::ExtensionValueInvalid)?; + // The publicKey field of SignedKey contains the public host key + // of the endpoint, encoded using the following protobuf: + // enum KeyType { + // RSA = 0; + // Ed25519 = 1; + // Secp256k1 = 2; + // ECDSA = 3; + // } + // message PublicKey { + // required KeyType Type = 1; + // required bytes Data = 2; + // } + let public_key = PublicKey::from_protobuf_encoding(&public_key) + .map_err(|_| webpki::Error::UnknownIssuer)?; + let ext = P2pExtension { + public_key, + signature, + }; + libp2p_extension = Some(ext); + continue; + } + + if ext.critical { + // Endpoints MUST abort the connection attempt if the certificate + // contains critical extensions that the endpoint does not understand. + return Err(webpki::Error::UnsupportedCriticalExtension); + } + + // Implementations MUST ignore non-critical extensions with unknown OIDs. + } + + // The certificate MUST contain the libp2p Public Key Extension. + // If this extension is missing, endpoints MUST abort the connection attempt. + let extension = libp2p_extension.ok_or(webpki::Error::BadDer)?; + + let certificate = P2pCertificate { + certificate: x509, + extension, + }; + + Ok(certificate) } fn make_libp2p_extension( - identity_keypair: &Keypair, - certificate_keypair: &rcgen::KeyPair, + identity_keypair: &Keypair, + certificate_keypair: &rcgen::KeyPair, ) -> Result { - // The peer signs the concatenation of the string `libp2p-tls-handshake:` - // and the public key that it used to generate the certificate carrying - // the libp2p Public Key Extension, using its private host key. - let signature = { - let mut msg = vec![]; - msg.extend(P2P_SIGNING_PREFIX); - msg.extend(certificate_keypair.public_key_der()); - - identity_keypair.sign(&msg) - }; - - // The public host key and the signature are ANS.1-encoded - // into the SignedKey data structure, which is carried - // in the libp2p Public Key Extension. - // SignedKey ::= SEQUENCE { - // publicKey OCTET STRING, - // signature OCTET STRING - // } - let extension_content = { - // TODO: this is ridiculous - let serialized_pubkey = - crate::crypto::PublicKey::Ed25519(identity_keypair.public()).to_protobuf_encoding(); - yasna::encode_der(&(serialized_pubkey, signature)) - }; - - // This extension MAY be marked critical. - let mut ext = rcgen::CustomExtension::from_oid_content(&P2P_EXT_OID, extension_content); - ext.set_criticality(true); - - Ok(ext) + // The peer signs the concatenation of the string `libp2p-tls-handshake:` + // and the public key that it used to generate the certificate carrying + // the libp2p Public Key Extension, using its private host key. + let signature = { + let mut msg = vec![]; + msg.extend(P2P_SIGNING_PREFIX); + msg.extend(certificate_keypair.public_key_der()); + + identity_keypair.sign(&msg) + }; + + // The public host key and the signature are ANS.1-encoded + // into the SignedKey data structure, which is carried + // in the libp2p Public Key Extension. + // SignedKey ::= SEQUENCE { + // publicKey OCTET STRING, + // signature OCTET STRING + // } + let extension_content = { + // TODO: this is ridiculous + let serialized_pubkey = + crate::crypto::PublicKey::Ed25519(identity_keypair.public()).to_protobuf_encoding(); + yasna::encode_der(&(serialized_pubkey, signature)) + }; + + // This extension MAY be marked critical. + let mut ext = rcgen::CustomExtension::from_oid_content(&P2P_EXT_OID, extension_content); + ext.set_criticality(true); + + Ok(ext) } impl P2pCertificate<'_> { - /// The [`PeerId`] of the remote peer. - pub fn peer_id(&self) -> PeerId { - self.extension.public_key.to_peer_id() - } - - /// Verify the `signature` of the `message` signed by the private key corresponding to the - /// public key stored in the certificate. - pub fn verify_signature( - &self, - signature_scheme: rustls::SignatureScheme, - message: &[u8], - signature: &[u8], - ) -> Result<(), VerificationError> { - let pk = self.public_key(signature_scheme)?; - pk.verify(message, signature) - .map_err(|_| webpki::Error::InvalidSignatureForPublicKey)?; - - Ok(()) - } - - /// Get a [`ring::signature::UnparsedPublicKey`] for this `signature_scheme`. - /// Return `Error` if the `signature_scheme` does not match the public key signature - /// and hashing algorithm or if the `signature_scheme` is not supported. - fn public_key( - &self, - signature_scheme: rustls::SignatureScheme, - ) -> Result, webpki::Error> { - use ring::signature; - use rustls::SignatureScheme::*; - - let current_signature_scheme = self.signature_scheme()?; - if signature_scheme != current_signature_scheme { - // This certificate was signed with a different signature scheme - return Err(webpki::Error::UnsupportedSignatureAlgorithmForPublicKey); - } - - let verification_algorithm: &dyn signature::VerificationAlgorithm = match signature_scheme { - RSA_PKCS1_SHA256 => &signature::RSA_PKCS1_2048_8192_SHA256, - RSA_PKCS1_SHA384 => &signature::RSA_PKCS1_2048_8192_SHA384, - RSA_PKCS1_SHA512 => &signature::RSA_PKCS1_2048_8192_SHA512, - ECDSA_NISTP256_SHA256 => &signature::ECDSA_P256_SHA256_ASN1, - ECDSA_NISTP384_SHA384 => &signature::ECDSA_P384_SHA384_ASN1, - ECDSA_NISTP521_SHA512 => { - // See https://github.com/briansmith/ring/issues/824 - return Err(webpki::Error::UnsupportedSignatureAlgorithm); - }, - RSA_PSS_SHA256 => &signature::RSA_PSS_2048_8192_SHA256, - RSA_PSS_SHA384 => &signature::RSA_PSS_2048_8192_SHA384, - RSA_PSS_SHA512 => &signature::RSA_PSS_2048_8192_SHA512, - ED25519 => &signature::ED25519, - ED448 => { - // See https://github.com/briansmith/ring/issues/463 - return Err(webpki::Error::UnsupportedSignatureAlgorithm); - }, - // Similarly, hash functions with an output length less than 256 bits - // MUST NOT be used, due to the possibility of collision attacks. - // In particular, MD5 and SHA1 MUST NOT be used. - RSA_PKCS1_SHA1 => return Err(webpki::Error::UnsupportedSignatureAlgorithm), - ECDSA_SHA1_Legacy => return Err(webpki::Error::UnsupportedSignatureAlgorithm), - Unknown(_) => return Err(webpki::Error::UnsupportedSignatureAlgorithm), - }; - let spki = &self.certificate.tbs_certificate.subject_pki; - let key = signature::UnparsedPublicKey::new( - verification_algorithm, - spki.subject_public_key.as_ref(), - ); - - Ok(key) - } - - /// This method validates the certificate according to libp2p TLS 1.3 specs. - /// The certificate MUST: - /// 1. be valid at the time it is received by the peer; - /// 2. use the NamedCurve encoding; - /// 3. use hash functions with an output length not less than 256 bits; - /// 4. be self signed; - /// 5. contain a valid signature in the specific libp2p extension. - fn verify(&self) -> Result<(), webpki::Error> { - use webpki::Error; - // The certificate MUST have NotBefore and NotAfter fields set - // such that the certificate is valid at the time it is received by the peer. - if !self.certificate.validity().is_valid() { - return Err(Error::InvalidCertValidity); - } - - // Certificates MUST use the NamedCurve encoding for elliptic curve parameters. - // Similarly, hash functions with an output length less than 256 bits - // MUST NOT be used, due to the possibility of collision attacks. - // In particular, MD5 and SHA1 MUST NOT be used. - // Endpoints MUST abort the connection attempt if it is not used. - let signature_scheme = self.signature_scheme()?; - // Endpoints MUST abort the connection attempt if the certificate’s - // self-signature is not valid. - let raw_certificate = self.certificate.tbs_certificate.as_ref(); - let signature = self.certificate.signature_value.as_ref(); - // check if self signed - self.verify_signature(signature_scheme, raw_certificate, signature) - .map_err(|_| Error::SignatureAlgorithmMismatch)?; - - let subject_pki = self.certificate.public_key().raw; - - // The peer signs the concatenation of the string `libp2p-tls-handshake:` - // and the public key that it used to generate the certificate carrying - // the libp2p Public Key Extension, using its private host key. - let mut msg = vec![]; - msg.extend(P2P_SIGNING_PREFIX); - msg.extend(subject_pki); - - // This signature provides cryptographic proof that the peer was in possession - // of the private host key at the time the certificate was signed. - // Peers MUST verify the signature, and abort the connection attempt - // if signature verification fails. - let user_owns_sk = self.extension.public_key.verify(&msg, &self.extension.signature); - if !user_owns_sk { - return Err(Error::UnknownIssuer); - } - - Ok(()) - } - - /// Return the signature scheme corresponding to [`AlgorithmIdentifier`]s - /// of `subject_pki` and `signature_algorithm` - /// according to . - fn signature_scheme(&self) -> Result { - // Certificates MUST use the NamedCurve encoding for elliptic curve parameters. - // Endpoints MUST abort the connection attempt if it is not used. - use oid_registry::*; - use rustls::SignatureScheme::*; - - let signature_algorithm = &self.certificate.signature_algorithm; - let pki_algorithm = &self.certificate.tbs_certificate.subject_pki.algorithm; - - if pki_algorithm.algorithm == OID_PKCS1_RSAENCRYPTION { - if signature_algorithm.algorithm == OID_PKCS1_SHA256WITHRSA { - return Ok(RSA_PKCS1_SHA256); - } - if signature_algorithm.algorithm == OID_PKCS1_SHA384WITHRSA { - return Ok(RSA_PKCS1_SHA384); - } - if signature_algorithm.algorithm == OID_PKCS1_SHA512WITHRSA { - return Ok(RSA_PKCS1_SHA512); - } - if signature_algorithm.algorithm == OID_PKCS1_RSASSAPSS { - // According to https://datatracker.ietf.org/doc/html/rfc4055#section-3.1: - // Inside of params there shuld be a sequence of: - // - Hash Algorithm - // - Mask Algorithm - // - Salt Length - // - Trailer Field - - // We are interested in Hash Algorithm only - - if let Ok(SignatureAlgorithm::RSASSA_PSS(params)) = - SignatureAlgorithm::try_from(signature_algorithm) - { - let hash_oid = params.hash_algorithm_oid(); - if hash_oid == &OID_NIST_HASH_SHA256 { - return Ok(RSA_PSS_SHA256); - } - if hash_oid == &OID_NIST_HASH_SHA384 { - return Ok(RSA_PSS_SHA384); - } - if hash_oid == &OID_NIST_HASH_SHA512 { - return Ok(RSA_PSS_SHA512); - } - } - - // Default hash algo is SHA-1, however: - // In particular, MD5 and SHA1 MUST NOT be used. - return Err(webpki::Error::UnsupportedSignatureAlgorithm); - } - } - - if pki_algorithm.algorithm == OID_KEY_TYPE_EC_PUBLIC_KEY { - let signature_param = pki_algorithm - .parameters - .as_ref() - .ok_or(webpki::Error::BadDer)? - .as_oid() - .map_err(|_| webpki::Error::BadDer)?; - if signature_param == OID_EC_P256 && - signature_algorithm.algorithm == OID_SIG_ECDSA_WITH_SHA256 - { - return Ok(ECDSA_NISTP256_SHA256); - } - if signature_param == OID_NIST_EC_P384 && - signature_algorithm.algorithm == OID_SIG_ECDSA_WITH_SHA384 - { - return Ok(ECDSA_NISTP384_SHA384); - } - if signature_param == OID_NIST_EC_P521 && - signature_algorithm.algorithm == OID_SIG_ECDSA_WITH_SHA512 - { - return Ok(ECDSA_NISTP521_SHA512); - } - return Err(webpki::Error::UnsupportedSignatureAlgorithm); - } - - if signature_algorithm.algorithm == OID_SIG_ED25519 { - return Ok(ED25519); - } - if signature_algorithm.algorithm == OID_SIG_ED448 { - return Ok(ED448); - } - - Err(webpki::Error::UnsupportedSignatureAlgorithm) - } + /// The [`PeerId`] of the remote peer. + pub fn peer_id(&self) -> PeerId { + self.extension.public_key.to_peer_id() + } + + /// Verify the `signature` of the `message` signed by the private key corresponding to the + /// public key stored in the certificate. + pub fn verify_signature( + &self, + signature_scheme: rustls::SignatureScheme, + message: &[u8], + signature: &[u8], + ) -> Result<(), VerificationError> { + let pk = self.public_key(signature_scheme)?; + pk.verify(message, signature) + .map_err(|_| webpki::Error::InvalidSignatureForPublicKey)?; + + Ok(()) + } + + /// Get a [`ring::signature::UnparsedPublicKey`] for this `signature_scheme`. + /// Return `Error` if the `signature_scheme` does not match the public key signature + /// and hashing algorithm or if the `signature_scheme` is not supported. + fn public_key( + &self, + signature_scheme: rustls::SignatureScheme, + ) -> Result, webpki::Error> { + use ring::signature; + use rustls::SignatureScheme::*; + + let current_signature_scheme = self.signature_scheme()?; + if signature_scheme != current_signature_scheme { + // This certificate was signed with a different signature scheme + return Err(webpki::Error::UnsupportedSignatureAlgorithmForPublicKey); + } + + let verification_algorithm: &dyn signature::VerificationAlgorithm = match signature_scheme { + RSA_PKCS1_SHA256 => &signature::RSA_PKCS1_2048_8192_SHA256, + RSA_PKCS1_SHA384 => &signature::RSA_PKCS1_2048_8192_SHA384, + RSA_PKCS1_SHA512 => &signature::RSA_PKCS1_2048_8192_SHA512, + ECDSA_NISTP256_SHA256 => &signature::ECDSA_P256_SHA256_ASN1, + ECDSA_NISTP384_SHA384 => &signature::ECDSA_P384_SHA384_ASN1, + ECDSA_NISTP521_SHA512 => { + // See https://github.com/briansmith/ring/issues/824 + return Err(webpki::Error::UnsupportedSignatureAlgorithm); + } + RSA_PSS_SHA256 => &signature::RSA_PSS_2048_8192_SHA256, + RSA_PSS_SHA384 => &signature::RSA_PSS_2048_8192_SHA384, + RSA_PSS_SHA512 => &signature::RSA_PSS_2048_8192_SHA512, + ED25519 => &signature::ED25519, + ED448 => { + // See https://github.com/briansmith/ring/issues/463 + return Err(webpki::Error::UnsupportedSignatureAlgorithm); + } + // Similarly, hash functions with an output length less than 256 bits + // MUST NOT be used, due to the possibility of collision attacks. + // In particular, MD5 and SHA1 MUST NOT be used. + RSA_PKCS1_SHA1 => return Err(webpki::Error::UnsupportedSignatureAlgorithm), + ECDSA_SHA1_Legacy => return Err(webpki::Error::UnsupportedSignatureAlgorithm), + Unknown(_) => return Err(webpki::Error::UnsupportedSignatureAlgorithm), + }; + let spki = &self.certificate.tbs_certificate.subject_pki; + let key = signature::UnparsedPublicKey::new( + verification_algorithm, + spki.subject_public_key.as_ref(), + ); + + Ok(key) + } + + /// This method validates the certificate according to libp2p TLS 1.3 specs. + /// The certificate MUST: + /// 1. be valid at the time it is received by the peer; + /// 2. use the NamedCurve encoding; + /// 3. use hash functions with an output length not less than 256 bits; + /// 4. be self signed; + /// 5. contain a valid signature in the specific libp2p extension. + fn verify(&self) -> Result<(), webpki::Error> { + use webpki::Error; + // The certificate MUST have NotBefore and NotAfter fields set + // such that the certificate is valid at the time it is received by the peer. + if !self.certificate.validity().is_valid() { + return Err(Error::InvalidCertValidity); + } + + // Certificates MUST use the NamedCurve encoding for elliptic curve parameters. + // Similarly, hash functions with an output length less than 256 bits + // MUST NOT be used, due to the possibility of collision attacks. + // In particular, MD5 and SHA1 MUST NOT be used. + // Endpoints MUST abort the connection attempt if it is not used. + let signature_scheme = self.signature_scheme()?; + // Endpoints MUST abort the connection attempt if the certificate’s + // self-signature is not valid. + let raw_certificate = self.certificate.tbs_certificate.as_ref(); + let signature = self.certificate.signature_value.as_ref(); + // check if self signed + self.verify_signature(signature_scheme, raw_certificate, signature) + .map_err(|_| Error::SignatureAlgorithmMismatch)?; + + let subject_pki = self.certificate.public_key().raw; + + // The peer signs the concatenation of the string `libp2p-tls-handshake:` + // and the public key that it used to generate the certificate carrying + // the libp2p Public Key Extension, using its private host key. + let mut msg = vec![]; + msg.extend(P2P_SIGNING_PREFIX); + msg.extend(subject_pki); + + // This signature provides cryptographic proof that the peer was in possession + // of the private host key at the time the certificate was signed. + // Peers MUST verify the signature, and abort the connection attempt + // if signature verification fails. + let user_owns_sk = self.extension.public_key.verify(&msg, &self.extension.signature); + if !user_owns_sk { + return Err(Error::UnknownIssuer); + } + + Ok(()) + } + + /// Return the signature scheme corresponding to [`AlgorithmIdentifier`]s + /// of `subject_pki` and `signature_algorithm` + /// according to . + fn signature_scheme(&self) -> Result { + // Certificates MUST use the NamedCurve encoding for elliptic curve parameters. + // Endpoints MUST abort the connection attempt if it is not used. + use oid_registry::*; + use rustls::SignatureScheme::*; + + let signature_algorithm = &self.certificate.signature_algorithm; + let pki_algorithm = &self.certificate.tbs_certificate.subject_pki.algorithm; + + if pki_algorithm.algorithm == OID_PKCS1_RSAENCRYPTION { + if signature_algorithm.algorithm == OID_PKCS1_SHA256WITHRSA { + return Ok(RSA_PKCS1_SHA256); + } + if signature_algorithm.algorithm == OID_PKCS1_SHA384WITHRSA { + return Ok(RSA_PKCS1_SHA384); + } + if signature_algorithm.algorithm == OID_PKCS1_SHA512WITHRSA { + return Ok(RSA_PKCS1_SHA512); + } + if signature_algorithm.algorithm == OID_PKCS1_RSASSAPSS { + // According to https://datatracker.ietf.org/doc/html/rfc4055#section-3.1: + // Inside of params there shuld be a sequence of: + // - Hash Algorithm + // - Mask Algorithm + // - Salt Length + // - Trailer Field + + // We are interested in Hash Algorithm only + + if let Ok(SignatureAlgorithm::RSASSA_PSS(params)) = + SignatureAlgorithm::try_from(signature_algorithm) + { + let hash_oid = params.hash_algorithm_oid(); + if hash_oid == &OID_NIST_HASH_SHA256 { + return Ok(RSA_PSS_SHA256); + } + if hash_oid == &OID_NIST_HASH_SHA384 { + return Ok(RSA_PSS_SHA384); + } + if hash_oid == &OID_NIST_HASH_SHA512 { + return Ok(RSA_PSS_SHA512); + } + } + + // Default hash algo is SHA-1, however: + // In particular, MD5 and SHA1 MUST NOT be used. + return Err(webpki::Error::UnsupportedSignatureAlgorithm); + } + } + + if pki_algorithm.algorithm == OID_KEY_TYPE_EC_PUBLIC_KEY { + let signature_param = pki_algorithm + .parameters + .as_ref() + .ok_or(webpki::Error::BadDer)? + .as_oid() + .map_err(|_| webpki::Error::BadDer)?; + if signature_param == OID_EC_P256 + && signature_algorithm.algorithm == OID_SIG_ECDSA_WITH_SHA256 + { + return Ok(ECDSA_NISTP256_SHA256); + } + if signature_param == OID_NIST_EC_P384 + && signature_algorithm.algorithm == OID_SIG_ECDSA_WITH_SHA384 + { + return Ok(ECDSA_NISTP384_SHA384); + } + if signature_param == OID_NIST_EC_P521 + && signature_algorithm.algorithm == OID_SIG_ECDSA_WITH_SHA512 + { + return Ok(ECDSA_NISTP521_SHA512); + } + return Err(webpki::Error::UnsupportedSignatureAlgorithm); + } + + if signature_algorithm.algorithm == OID_SIG_ED25519 { + return Ok(ED25519); + } + if signature_algorithm.algorithm == OID_SIG_ED448 { + return Ok(ED448); + } + + Err(webpki::Error::UnsupportedSignatureAlgorithm) + } } #[cfg(test)] mod tests { - use super::*; - use hex_literal::hex; - - #[test] - fn sanity_check() { - // let keypair = identity::Keypair::generate_ed25519(); - let keypair = crate::crypto::ed25519::Keypair::generate(); - - let (cert, _) = generate(&keypair).unwrap(); - let parsed_cert = parse(&cert).unwrap(); - - assert!(parsed_cert.verify().is_ok()); - assert_eq!( - crate::crypto::PublicKey::Ed25519(keypair.public()), - parsed_cert.extension.public_key - ); - } - - macro_rules! check_cert { - ($name:ident, $path:literal, $scheme:path) => { - #[test] - fn $name() { - let cert: &[u8] = include_bytes!($path); - - let cert = parse_unverified(cert).unwrap(); - assert!(cert.verify().is_err()); // Because p2p extension - // was not signed with the private key - // of the certificate. - assert_eq!(cert.signature_scheme(), Ok($scheme)); - } - }; - } - - check_cert! {ed448, "./test_assets/ed448.der", rustls::SignatureScheme::ED448} - check_cert! {ed25519, "./test_assets/ed25519.der", rustls::SignatureScheme::ED25519} - check_cert! {rsa_pkcs1_sha256, "./test_assets/rsa_pkcs1_sha256.der", rustls::SignatureScheme::RSA_PKCS1_SHA256} - check_cert! {rsa_pkcs1_sha384, "./test_assets/rsa_pkcs1_sha384.der", rustls::SignatureScheme::RSA_PKCS1_SHA384} - check_cert! {rsa_pkcs1_sha512, "./test_assets/rsa_pkcs1_sha512.der", rustls::SignatureScheme::RSA_PKCS1_SHA512} - check_cert! {nistp256_sha256, "./test_assets/nistp256_sha256.der", rustls::SignatureScheme::ECDSA_NISTP256_SHA256} - check_cert! {nistp384_sha384, "./test_assets/nistp384_sha384.der", rustls::SignatureScheme::ECDSA_NISTP384_SHA384} - check_cert! {nistp521_sha512, "./test_assets/nistp521_sha512.der", rustls::SignatureScheme::ECDSA_NISTP521_SHA512} - - #[test] - fn rsa_pss_sha384() { - let cert = rustls::Certificate(include_bytes!("./test_assets/rsa_pss_sha384.der").to_vec()); - - let cert = parse(&cert).unwrap(); - - assert_eq!(cert.signature_scheme(), Ok(rustls::SignatureScheme::RSA_PSS_SHA384)); - } - - #[test] - fn nistp384_sha256() { - let cert: &[u8] = include_bytes!("./test_assets/nistp384_sha256.der"); - - let cert = parse_unverified(cert).unwrap(); - - assert!(cert.signature_scheme().is_err()); - } - - #[test] - fn can_parse_certificate_with_ed25519_keypair() { - let certificate = rustls::Certificate(hex!("308201773082011ea003020102020900f5bd0debaa597f52300a06082a8648ce3d04030230003020170d3735303130313030303030305a180f34303936303130313030303030305a30003059301306072a8648ce3d020106082a8648ce3d030107034200046bf9871220d71dcb3483ecdfcbfcc7c103f8509d0974b3c18ab1f1be1302d643103a08f7a7722c1b247ba3876fe2c59e26526f479d7718a85202ddbe47562358a37f307d307b060a2b0601040183a25a01010101ff046a30680424080112207fda21856709c5ae12fd6e8450623f15f11955d384212b89f56e7e136d2e17280440aaa6bffabe91b6f30c35e3aa4f94b1188fed96b0ffdd393f4c58c1c047854120e674ce64c788406d1c2c4b116581fd7411b309881c3c7f20b46e54c7e6fe7f0f300a06082a8648ce3d040302034700304402207d1a1dbd2bda235ff2ec87daf006f9b04ba076a5a5530180cd9c2e8f6399e09d0220458527178c7e77024601dbb1b256593e9b96d961b96349d1f560114f61a87595").to_vec()); - - let peer_id = parse(&certificate).unwrap().peer_id(); - - assert_eq!( - "12D3KooWJRSrypvnpHgc6ZAgyCni4KcSmbV7uGRaMw5LgMKT18fq" - .parse::() - .unwrap(), - peer_id - ); - } - - #[test] - fn fails_to_parse_bad_certificate_with_ed25519_keypair() { - let certificate = rustls::Certificate(hex!("308201773082011da003020102020830a73c5d896a1109300a06082a8648ce3d04030230003020170d3735303130313030303030305a180f34303936303130313030303030305a30003059301306072a8648ce3d020106082a8648ce3d03010703420004bbe62df9a7c1c46b7f1f21d556deec5382a36df146fb29c7f1240e60d7d5328570e3b71d99602b77a65c9b3655f62837f8d66b59f1763b8c9beba3be07778043a37f307d307b060a2b0601040183a25a01010101ff046a3068042408011220ec8094573afb9728088860864f7bcea2d4fd412fef09a8e2d24d482377c20db60440ecabae8354afa2f0af4b8d2ad871e865cb5a7c0c8d3dbdbf42de577f92461a0ebb0a28703e33581af7d2a4f2270fc37aec6261fcc95f8af08f3f4806581c730a300a06082a8648ce3d040302034800304502202dfb17a6fa0f94ee0e2e6a3b9fb6e986f311dee27392058016464bd130930a61022100ba4b937a11c8d3172b81e7cd04aedb79b978c4379c2b5b24d565dd5d67d3cb3c").to_vec()); - - match parse(&certificate) { - Ok(_) => assert!(false), - Err(error) => { - assert_eq!(format!("{error}"), "UnknownIssuer"); - }, - } - } + use super::*; + use hex_literal::hex; + + #[test] + fn sanity_check() { + // let keypair = identity::Keypair::generate_ed25519(); + let keypair = crate::crypto::ed25519::Keypair::generate(); + + let (cert, _) = generate(&keypair).unwrap(); + let parsed_cert = parse(&cert).unwrap(); + + assert!(parsed_cert.verify().is_ok()); + assert_eq!( + crate::crypto::PublicKey::Ed25519(keypair.public()), + parsed_cert.extension.public_key + ); + } + + macro_rules! check_cert { + ($name:ident, $path:literal, $scheme:path) => { + #[test] + fn $name() { + let cert: &[u8] = include_bytes!($path); + + let cert = parse_unverified(cert).unwrap(); + assert!(cert.verify().is_err()); // Because p2p extension + // was not signed with the private key + // of the certificate. + assert_eq!(cert.signature_scheme(), Ok($scheme)); + } + }; + } + + check_cert! {ed448, "./test_assets/ed448.der", rustls::SignatureScheme::ED448} + check_cert! {ed25519, "./test_assets/ed25519.der", rustls::SignatureScheme::ED25519} + check_cert! {rsa_pkcs1_sha256, "./test_assets/rsa_pkcs1_sha256.der", rustls::SignatureScheme::RSA_PKCS1_SHA256} + check_cert! {rsa_pkcs1_sha384, "./test_assets/rsa_pkcs1_sha384.der", rustls::SignatureScheme::RSA_PKCS1_SHA384} + check_cert! {rsa_pkcs1_sha512, "./test_assets/rsa_pkcs1_sha512.der", rustls::SignatureScheme::RSA_PKCS1_SHA512} + check_cert! {nistp256_sha256, "./test_assets/nistp256_sha256.der", rustls::SignatureScheme::ECDSA_NISTP256_SHA256} + check_cert! {nistp384_sha384, "./test_assets/nistp384_sha384.der", rustls::SignatureScheme::ECDSA_NISTP384_SHA384} + check_cert! {nistp521_sha512, "./test_assets/nistp521_sha512.der", rustls::SignatureScheme::ECDSA_NISTP521_SHA512} + + #[test] + fn rsa_pss_sha384() { + let cert = rustls::Certificate(include_bytes!("./test_assets/rsa_pss_sha384.der").to_vec()); + + let cert = parse(&cert).unwrap(); + + assert_eq!( + cert.signature_scheme(), + Ok(rustls::SignatureScheme::RSA_PSS_SHA384) + ); + } + + #[test] + fn nistp384_sha256() { + let cert: &[u8] = include_bytes!("./test_assets/nistp384_sha256.der"); + + let cert = parse_unverified(cert).unwrap(); + + assert!(cert.signature_scheme().is_err()); + } + + #[test] + fn can_parse_certificate_with_ed25519_keypair() { + let certificate = rustls::Certificate(hex!("308201773082011ea003020102020900f5bd0debaa597f52300a06082a8648ce3d04030230003020170d3735303130313030303030305a180f34303936303130313030303030305a30003059301306072a8648ce3d020106082a8648ce3d030107034200046bf9871220d71dcb3483ecdfcbfcc7c103f8509d0974b3c18ab1f1be1302d643103a08f7a7722c1b247ba3876fe2c59e26526f479d7718a85202ddbe47562358a37f307d307b060a2b0601040183a25a01010101ff046a30680424080112207fda21856709c5ae12fd6e8450623f15f11955d384212b89f56e7e136d2e17280440aaa6bffabe91b6f30c35e3aa4f94b1188fed96b0ffdd393f4c58c1c047854120e674ce64c788406d1c2c4b116581fd7411b309881c3c7f20b46e54c7e6fe7f0f300a06082a8648ce3d040302034700304402207d1a1dbd2bda235ff2ec87daf006f9b04ba076a5a5530180cd9c2e8f6399e09d0220458527178c7e77024601dbb1b256593e9b96d961b96349d1f560114f61a87595").to_vec()); + + let peer_id = parse(&certificate).unwrap().peer_id(); + + assert_eq!( + "12D3KooWJRSrypvnpHgc6ZAgyCni4KcSmbV7uGRaMw5LgMKT18fq" + .parse::() + .unwrap(), + peer_id + ); + } + + #[test] + fn fails_to_parse_bad_certificate_with_ed25519_keypair() { + let certificate = rustls::Certificate(hex!("308201773082011da003020102020830a73c5d896a1109300a06082a8648ce3d04030230003020170d3735303130313030303030305a180f34303936303130313030303030305a30003059301306072a8648ce3d020106082a8648ce3d03010703420004bbe62df9a7c1c46b7f1f21d556deec5382a36df146fb29c7f1240e60d7d5328570e3b71d99602b77a65c9b3655f62837f8d66b59f1763b8c9beba3be07778043a37f307d307b060a2b0601040183a25a01010101ff046a3068042408011220ec8094573afb9728088860864f7bcea2d4fd412fef09a8e2d24d482377c20db60440ecabae8354afa2f0af4b8d2ad871e865cb5a7c0c8d3dbdbf42de577f92461a0ebb0a28703e33581af7d2a4f2270fc37aec6261fcc95f8af08f3f4806581c730a300a06082a8648ce3d040302034800304502202dfb17a6fa0f94ee0e2e6a3b9fb6e986f311dee27392058016464bd130930a61022100ba4b937a11c8d3172b81e7cd04aedb79b978c4379c2b5b24d565dd5d67d3cb3c").to_vec()); + + match parse(&certificate) { + Ok(_) => assert!(false), + Err(error) => { + assert_eq!(format!("{error}"), "UnknownIssuer"); + } + } + } } diff --git a/src/crypto/tls/mod.rs b/src/crypto/tls/mod.rs index f304bbef..e19976ae 100644 --- a/src/crypto/tls/mod.rs +++ b/src/crypto/tls/mod.rs @@ -36,41 +36,41 @@ const P2P_ALPN: [u8; 6] = *b"libp2p"; /// Create a TLS server configuration for litep2p. pub fn make_server_config( - keypair: &Keypair, + keypair: &Keypair, ) -> Result { - let (certificate, private_key) = certificate::generate(keypair)?; + let (certificate, private_key) = certificate::generate(keypair)?; - let mut crypto = rustls::ServerConfig::builder() - .with_cipher_suites(verifier::CIPHERSUITES) - .with_safe_default_kx_groups() - .with_protocol_versions(verifier::PROTOCOL_VERSIONS) - .expect("Cipher suites and kx groups are configured; qed") - .with_client_cert_verifier(Arc::new(verifier::Libp2pCertificateVerifier::new())) - .with_single_cert(vec![certificate], private_key) - .expect("Server cert key DER is valid; qed"); - crypto.alpn_protocols = vec![P2P_ALPN.to_vec()]; + let mut crypto = rustls::ServerConfig::builder() + .with_cipher_suites(verifier::CIPHERSUITES) + .with_safe_default_kx_groups() + .with_protocol_versions(verifier::PROTOCOL_VERSIONS) + .expect("Cipher suites and kx groups are configured; qed") + .with_client_cert_verifier(Arc::new(verifier::Libp2pCertificateVerifier::new())) + .with_single_cert(vec![certificate], private_key) + .expect("Server cert key DER is valid; qed"); + crypto.alpn_protocols = vec![P2P_ALPN.to_vec()]; - Ok(crypto) + Ok(crypto) } /// Create a TLS client configuration for libp2p. pub fn make_client_config( - keypair: &Keypair, - remote_peer_id: Option, + keypair: &Keypair, + remote_peer_id: Option, ) -> Result { - let (certificate, private_key) = certificate::generate(keypair)?; + let (certificate, private_key) = certificate::generate(keypair)?; - let mut crypto = rustls::ClientConfig::builder() - .with_cipher_suites(verifier::CIPHERSUITES) - .with_safe_default_kx_groups() - .with_protocol_versions(verifier::PROTOCOL_VERSIONS) - .expect("Cipher suites and kx groups are configured; qed") - .with_custom_certificate_verifier(Arc::new( - verifier::Libp2pCertificateVerifier::with_remote_peer_id(remote_peer_id), - )) - .with_single_cert(vec![certificate], private_key) - .expect("Client cert key DER is valid; qed"); - crypto.alpn_protocols = vec![P2P_ALPN.to_vec()]; + let mut crypto = rustls::ClientConfig::builder() + .with_cipher_suites(verifier::CIPHERSUITES) + .with_safe_default_kx_groups() + .with_protocol_versions(verifier::PROTOCOL_VERSIONS) + .expect("Cipher suites and kx groups are configured; qed") + .with_custom_certificate_verifier(Arc::new( + verifier::Libp2pCertificateVerifier::with_remote_peer_id(remote_peer_id), + )) + .with_single_cert(vec![certificate], private_key) + .expect("Client cert key DER is valid; qed"); + crypto.alpn_protocols = vec![P2P_ALPN.to_vec()]; - Ok(crypto) + Ok(crypto) } diff --git a/src/crypto/tls/verifier.rs b/src/crypto/tls/verifier.rs index f809756c..470c43c2 100644 --- a/src/crypto/tls/verifier.rs +++ b/src/crypto/tls/verifier.rs @@ -26,14 +26,14 @@ use crate::{crypto::tls::certificate, PeerId}; use rustls::{ - cipher_suite::{ - TLS13_AES_128_GCM_SHA256, TLS13_AES_256_GCM_SHA384, TLS13_CHACHA20_POLY1305_SHA256, - }, - client::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier}, - internal::msgs::handshake::DigitallySignedStruct, - server::{ClientCertVerified, ClientCertVerifier}, - Certificate, DistinguishedNames, SignatureScheme, SupportedCipherSuite, - SupportedProtocolVersion, + cipher_suite::{ + TLS13_AES_128_GCM_SHA256, TLS13_AES_256_GCM_SHA384, TLS13_CHACHA20_POLY1305_SHA256, + }, + client::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier}, + internal::msgs::handshake::DigitallySignedStruct, + server::{ClientCertVerified, ClientCertVerifier}, + Certificate, DistinguishedNames, SignatureScheme, SupportedCipherSuite, + SupportedProtocolVersion, }; /// The protocol versions supported by this verifier. @@ -48,18 +48,18 @@ pub static PROTOCOL_VERSIONS: &[&SupportedProtocolVersion] = &[&rustls::version: // By default rustls creates client/server configs with both // TLS 1.3 __and__ 1.2 cipher suites. But we don't need 1.2. pub static CIPHERSUITES: &[SupportedCipherSuite] = &[ - // TLS1.3 suites - TLS13_CHACHA20_POLY1305_SHA256, - TLS13_AES_256_GCM_SHA384, - TLS13_AES_128_GCM_SHA256, + // TLS1.3 suites + TLS13_CHACHA20_POLY1305_SHA256, + TLS13_AES_256_GCM_SHA384, + TLS13_AES_128_GCM_SHA256, ]; /// Implementation of the `rustls` certificate verification traits for libp2p. /// /// Only TLS 1.3 is supported. TLS 1.2 should be disabled in the configuration of `rustls`. pub struct Libp2pCertificateVerifier { - /// The peer ID we intend to connect to - remote_peer_id: Option, + /// The peer ID we intend to connect to + remote_peer_id: Option, } /// libp2p requires the following of X.509 server certificate chains: @@ -69,85 +69,87 @@ pub struct Libp2pCertificateVerifier { /// - The certificate must have a valid libp2p extension that includes a signature of its public /// key. impl Libp2pCertificateVerifier { - pub fn new() -> Self { - Self { remote_peer_id: None } - } - - pub fn with_remote_peer_id(remote_peer_id: Option) -> Self { - Self { remote_peer_id } - } - - /// Return the list of SignatureSchemes that this verifier will handle, - /// in `verify_tls12_signature` and `verify_tls13_signature` calls. - /// - /// This should be in priority order, with the most preferred first. - fn verification_schemes() -> Vec { - vec![ - // TODO SignatureScheme::ECDSA_NISTP521_SHA512 is not supported by `ring` yet - SignatureScheme::ECDSA_NISTP384_SHA384, - SignatureScheme::ECDSA_NISTP256_SHA256, - // TODO SignatureScheme::ED448 is not supported by `ring` yet - SignatureScheme::ED25519, - // In particular, RSA SHOULD NOT be used unless - // no elliptic curve algorithms are supported. - SignatureScheme::RSA_PSS_SHA512, - SignatureScheme::RSA_PSS_SHA384, - SignatureScheme::RSA_PSS_SHA256, - SignatureScheme::RSA_PKCS1_SHA512, - SignatureScheme::RSA_PKCS1_SHA384, - SignatureScheme::RSA_PKCS1_SHA256, - ] - } + pub fn new() -> Self { + Self { + remote_peer_id: None, + } + } + + pub fn with_remote_peer_id(remote_peer_id: Option) -> Self { + Self { remote_peer_id } + } + + /// Return the list of SignatureSchemes that this verifier will handle, + /// in `verify_tls12_signature` and `verify_tls13_signature` calls. + /// + /// This should be in priority order, with the most preferred first. + fn verification_schemes() -> Vec { + vec![ + // TODO SignatureScheme::ECDSA_NISTP521_SHA512 is not supported by `ring` yet + SignatureScheme::ECDSA_NISTP384_SHA384, + SignatureScheme::ECDSA_NISTP256_SHA256, + // TODO SignatureScheme::ED448 is not supported by `ring` yet + SignatureScheme::ED25519, + // In particular, RSA SHOULD NOT be used unless + // no elliptic curve algorithms are supported. + SignatureScheme::RSA_PSS_SHA512, + SignatureScheme::RSA_PSS_SHA384, + SignatureScheme::RSA_PSS_SHA256, + SignatureScheme::RSA_PKCS1_SHA512, + SignatureScheme::RSA_PKCS1_SHA384, + SignatureScheme::RSA_PKCS1_SHA256, + ] + } } impl ServerCertVerifier for Libp2pCertificateVerifier { - fn verify_server_cert( - &self, - end_entity: &Certificate, - intermediates: &[Certificate], - _server_name: &rustls::ServerName, - _scts: &mut dyn Iterator, - _ocsp_response: &[u8], - _now: std::time::SystemTime, - ) -> Result { - let peer_id = verify_presented_certs(end_entity, intermediates)?; - - if let Some(remote_peer_id) = self.remote_peer_id { - // The public host key allows the peer to calculate the peer ID of the peer - // it is connecting to. Clients MUST verify that the peer ID derived from - // the certificate matches the peer ID they intended to connect to, - // and MUST abort the connection if there is a mismatch. - if remote_peer_id != peer_id { - return Err(rustls::Error::PeerMisbehavedError( - "Wrong peer ID in p2p extension".to_string(), - )); - } - } - - Ok(ServerCertVerified::assertion()) - } - - fn verify_tls12_signature( - &self, - _message: &[u8], - _cert: &Certificate, - _dss: &DigitallySignedStruct, - ) -> Result { - unreachable!("`PROTOCOL_VERSIONS` only allows TLS 1.3") - } - - fn verify_tls13_signature( - &self, - message: &[u8], - cert: &Certificate, - dss: &DigitallySignedStruct, - ) -> Result { - verify_tls13_signature(cert, dss.scheme, message, dss.signature()) - } - - fn supported_verify_schemes(&self) -> Vec { - Self::verification_schemes() - } + fn verify_server_cert( + &self, + end_entity: &Certificate, + intermediates: &[Certificate], + _server_name: &rustls::ServerName, + _scts: &mut dyn Iterator, + _ocsp_response: &[u8], + _now: std::time::SystemTime, + ) -> Result { + let peer_id = verify_presented_certs(end_entity, intermediates)?; + + if let Some(remote_peer_id) = self.remote_peer_id { + // The public host key allows the peer to calculate the peer ID of the peer + // it is connecting to. Clients MUST verify that the peer ID derived from + // the certificate matches the peer ID they intended to connect to, + // and MUST abort the connection if there is a mismatch. + if remote_peer_id != peer_id { + return Err(rustls::Error::PeerMisbehavedError( + "Wrong peer ID in p2p extension".to_string(), + )); + } + } + + Ok(ServerCertVerified::assertion()) + } + + fn verify_tls12_signature( + &self, + _message: &[u8], + _cert: &Certificate, + _dss: &DigitallySignedStruct, + ) -> Result { + unreachable!("`PROTOCOL_VERSIONS` only allows TLS 1.3") + } + + fn verify_tls13_signature( + &self, + message: &[u8], + cert: &Certificate, + dss: &DigitallySignedStruct, + ) -> Result { + verify_tls13_signature(cert, dss.scheme, message, dss.signature()) + } + + fn supported_verify_schemes(&self) -> Vec { + Self::verification_schemes() + } } /// libp2p requires the following of X.509 client certificate chains: @@ -158,46 +160,46 @@ impl ServerCertVerifier for Libp2pCertificateVerifier { /// - The certificate must have a valid libp2p extension that includes a signature of its public /// key. impl ClientCertVerifier for Libp2pCertificateVerifier { - fn offer_client_auth(&self) -> bool { - true - } - - fn client_auth_root_subjects(&self) -> Option { - Some(vec![]) - } - - fn verify_client_cert( - &self, - end_entity: &Certificate, - intermediates: &[Certificate], - _now: std::time::SystemTime, - ) -> Result { - let _: PeerId = verify_presented_certs(end_entity, intermediates)?; - - Ok(ClientCertVerified::assertion()) - } - - fn verify_tls12_signature( - &self, - _message: &[u8], - _cert: &Certificate, - _dss: &DigitallySignedStruct, - ) -> Result { - unreachable!("`PROTOCOL_VERSIONS` only allows TLS 1.3") - } - - fn verify_tls13_signature( - &self, - message: &[u8], - cert: &Certificate, - dss: &DigitallySignedStruct, - ) -> Result { - verify_tls13_signature(cert, dss.scheme, message, dss.signature()) - } - - fn supported_verify_schemes(&self) -> Vec { - Self::verification_schemes() - } + fn offer_client_auth(&self) -> bool { + true + } + + fn client_auth_root_subjects(&self) -> Option { + Some(vec![]) + } + + fn verify_client_cert( + &self, + end_entity: &Certificate, + intermediates: &[Certificate], + _now: std::time::SystemTime, + ) -> Result { + let _: PeerId = verify_presented_certs(end_entity, intermediates)?; + + Ok(ClientCertVerified::assertion()) + } + + fn verify_tls12_signature( + &self, + _message: &[u8], + _cert: &Certificate, + _dss: &DigitallySignedStruct, + ) -> Result { + unreachable!("`PROTOCOL_VERSIONS` only allows TLS 1.3") + } + + fn verify_tls13_signature( + &self, + message: &[u8], + cert: &Certificate, + dss: &DigitallySignedStruct, + ) -> Result { + verify_tls13_signature(cert, dss.scheme, message, dss.signature()) + } + + fn supported_verify_schemes(&self) -> Vec { + Self::verification_schemes() + } } /// When receiving the certificate chain, an endpoint @@ -207,46 +209,48 @@ impl ClientCertVerifier for Libp2pCertificateVerifier { /// Endpoints MUST abort the connection attempt if more than one certificate is received, /// or if the certificate’s self-signature is not valid. fn verify_presented_certs( - end_entity: &Certificate, - intermediates: &[Certificate], + end_entity: &Certificate, + intermediates: &[Certificate], ) -> Result { - if !intermediates.is_empty() { - return Err(rustls::Error::General("libp2p-tls requires exactly one certificate".into())); - } + if !intermediates.is_empty() { + return Err(rustls::Error::General( + "libp2p-tls requires exactly one certificate".into(), + )); + } - let cert = certificate::parse(end_entity)?; + let cert = certificate::parse(end_entity)?; - Ok(cert.peer_id()) + Ok(cert.peer_id()) } fn verify_tls13_signature( - cert: &Certificate, - signature_scheme: SignatureScheme, - message: &[u8], - signature: &[u8], + cert: &Certificate, + signature_scheme: SignatureScheme, + message: &[u8], + signature: &[u8], ) -> Result { - certificate::parse(cert)?.verify_signature(signature_scheme, message, signature)?; + certificate::parse(cert)?.verify_signature(signature_scheme, message, signature)?; - Ok(HandshakeSignatureValid::assertion()) + Ok(HandshakeSignatureValid::assertion()) } impl From for rustls::Error { - fn from(certificate::ParseError(e): certificate::ParseError) -> Self { - use webpki::Error::*; - match e { - BadDer => rustls::Error::InvalidCertificateEncoding, - e => rustls::Error::InvalidCertificateData(format!("invalid peer certificate: {e}")), - } - } + fn from(certificate::ParseError(e): certificate::ParseError) -> Self { + use webpki::Error::*; + match e { + BadDer => rustls::Error::InvalidCertificateEncoding, + e => rustls::Error::InvalidCertificateData(format!("invalid peer certificate: {e}")), + } + } } impl From for rustls::Error { - fn from(certificate::VerificationError(e): certificate::VerificationError) -> Self { - use webpki::Error::*; - match e { - InvalidSignatureForPublicKey => rustls::Error::InvalidCertificateSignature, - UnsupportedSignatureAlgorithm | UnsupportedSignatureAlgorithmForPublicKey => - rustls::Error::InvalidCertificateSignatureType, - e => rustls::Error::InvalidCertificateData(format!("invalid peer certificate: {e}")), - } - } + fn from(certificate::VerificationError(e): certificate::VerificationError) -> Self { + use webpki::Error::*; + match e { + InvalidSignatureForPublicKey => rustls::Error::InvalidCertificateSignature, + UnsupportedSignatureAlgorithm | UnsupportedSignatureAlgorithmForPublicKey => + rustls::Error::InvalidCertificateSignatureType, + e => rustls::Error::InvalidCertificateData(format!("invalid peer certificate: {e}")), + } + } } diff --git a/src/error.rs b/src/error.rs index 201a2d78..f0aae799 100644 --- a/src/error.rs +++ b/src/error.rs @@ -27,9 +27,9 @@ // TODO: move `NegotiationError` under `SubstreamError` use crate::{ - protocol::Direction, - types::{protocol::ProtocolName, ConnectionId, SubstreamId}, - PeerId, + protocol::Direction, + types::{protocol::ProtocolName, ConnectionId, SubstreamId}, + PeerId, }; use multiaddr::Multiaddr; @@ -39,224 +39,230 @@ use std::io::{self, ErrorKind}; #[derive(Debug, thiserror::Error)] pub enum Error { - #[error("Peer `{0}` does not exist")] - PeerDoesntExist(PeerId), - #[error("Peer `{0}` already exists")] - PeerAlreadyExists(PeerId), - #[error("Protocol `{0}` not supported")] - ProtocolNotSupported(String), - #[error("Address error: `{0}`")] - AddressError(AddressError), - #[error("Parse error: `{0}`")] - ParseError(ParseError), - #[error("I/O error: `{0}`")] - IoError(ErrorKind), - #[error("Negotiation error: `{0}`")] - NegotiationError(NegotiationError), - #[error("Substream error: `{0}`")] - SubstreamError(SubstreamError), - #[error("Substream error: `{0}`")] - NotificationError(NotificationError), - #[error("Essential task closed")] - EssentialTaskClosed, - #[error("Unknown error occurred")] - Unknown, - #[error("Cannot dial self: `{0}`")] - CannotDialSelf(Multiaddr), - #[error("Transport not supported")] - TransportNotSupported(Multiaddr), - #[error("Yamux error for substream `{0:?}`: `{1}`")] - YamuxError(Direction, crate::yamux::ConnectionError), - #[error("Operation not supported: `{0}`")] - NotSupported(String), - #[error("Other error occurred: `{0}`")] - Other(String), - #[error("Protocol already exists: `{0:?}`")] - ProtocolAlreadyExists(ProtocolName), - #[error("Operation timed out")] - Timeout, - #[error("Invalid state transition")] - InvalidState, - #[error("DNS address resolution failed")] - DnsAddressResolutionFailed, - #[error("Transport error: `{0}`")] - TransportError(String), - #[error("Failed to generate certificate: `{0}`")] - CertificateGeneration(#[from] crate::crypto::tls::certificate::GenError), - #[error("Invalid data")] - InvalidData, - #[error("Input rejected")] - InputRejected, - #[error("WebSocket error: `{0}`")] - WebSocket(#[from] tokio_tungstenite::tungstenite::error::Error), - #[error("Insufficient peers")] - InsufficientPeers, - #[error("Substream doens't exist")] - SubstreamDoesntExist, - #[error("`str0m` error: `{0}`")] - WebRtc(#[from] str0m::RtcError), - #[error("Remote peer disconnected")] - Disconnected, - #[error("Channel does not exist")] - ChannelDoesntExist, - #[error("Tried to dial self")] - TriedToDialSelf, - #[error("Litep2p is already connected to the peer")] - AlreadyConnected, - #[error("No addres available for `{0}`")] - NoAddressAvailable(PeerId), - #[error("Connection closed")] - ConnectionClosed, - #[error("Quinn error: `{0}`")] - Quinn(quinn::ConnectionError), - #[error("Invalid certificate")] - InvalidCertificate, - #[error("Peer ID mismatch: expected `{0}`, got `{1}`")] - PeerIdMismatch(PeerId, PeerId), - #[error("Channel is clogged")] - ChannelClogged, - #[error("Connection doesn't exist: `{0:?}`")] - ConnectionDoesntExist(ConnectionId), + #[error("Peer `{0}` does not exist")] + PeerDoesntExist(PeerId), + #[error("Peer `{0}` already exists")] + PeerAlreadyExists(PeerId), + #[error("Protocol `{0}` not supported")] + ProtocolNotSupported(String), + #[error("Address error: `{0}`")] + AddressError(AddressError), + #[error("Parse error: `{0}`")] + ParseError(ParseError), + #[error("I/O error: `{0}`")] + IoError(ErrorKind), + #[error("Negotiation error: `{0}`")] + NegotiationError(NegotiationError), + #[error("Substream error: `{0}`")] + SubstreamError(SubstreamError), + #[error("Substream error: `{0}`")] + NotificationError(NotificationError), + #[error("Essential task closed")] + EssentialTaskClosed, + #[error("Unknown error occurred")] + Unknown, + #[error("Cannot dial self: `{0}`")] + CannotDialSelf(Multiaddr), + #[error("Transport not supported")] + TransportNotSupported(Multiaddr), + #[error("Yamux error for substream `{0:?}`: `{1}`")] + YamuxError(Direction, crate::yamux::ConnectionError), + #[error("Operation not supported: `{0}`")] + NotSupported(String), + #[error("Other error occurred: `{0}`")] + Other(String), + #[error("Protocol already exists: `{0:?}`")] + ProtocolAlreadyExists(ProtocolName), + #[error("Operation timed out")] + Timeout, + #[error("Invalid state transition")] + InvalidState, + #[error("DNS address resolution failed")] + DnsAddressResolutionFailed, + #[error("Transport error: `{0}`")] + TransportError(String), + #[error("Failed to generate certificate: `{0}`")] + CertificateGeneration(#[from] crate::crypto::tls::certificate::GenError), + #[error("Invalid data")] + InvalidData, + #[error("Input rejected")] + InputRejected, + #[error("WebSocket error: `{0}`")] + WebSocket(#[from] tokio_tungstenite::tungstenite::error::Error), + #[error("Insufficient peers")] + InsufficientPeers, + #[error("Substream doens't exist")] + SubstreamDoesntExist, + #[error("`str0m` error: `{0}`")] + WebRtc(#[from] str0m::RtcError), + #[error("Remote peer disconnected")] + Disconnected, + #[error("Channel does not exist")] + ChannelDoesntExist, + #[error("Tried to dial self")] + TriedToDialSelf, + #[error("Litep2p is already connected to the peer")] + AlreadyConnected, + #[error("No addres available for `{0}`")] + NoAddressAvailable(PeerId), + #[error("Connection closed")] + ConnectionClosed, + #[error("Quinn error: `{0}`")] + Quinn(quinn::ConnectionError), + #[error("Invalid certificate")] + InvalidCertificate, + #[error("Peer ID mismatch: expected `{0}`, got `{1}`")] + PeerIdMismatch(PeerId, PeerId), + #[error("Channel is clogged")] + ChannelClogged, + #[error("Connection doesn't exist: `{0:?}`")] + ConnectionDoesntExist(ConnectionId), } #[derive(Debug, thiserror::Error)] pub enum AddressError { - #[error("Invalid protocol")] - InvalidProtocol, - #[error("`PeerId` missing from the address")] - PeerIdMissing, - #[error("Address not available")] - AddressNotAvailable, + #[error("Invalid protocol")] + InvalidProtocol, + #[error("`PeerId` missing from the address")] + PeerIdMissing, + #[error("Address not available")] + AddressNotAvailable, } #[derive(Debug, thiserror::Error)] pub enum ParseError { - #[error("Invalid multihash: `{0:?}`")] - InvalidMultihash(Multihash), - #[error("Failed to decode protobuf message: `{0:?}`")] - ProstDecodeError(prost::DecodeError), + #[error("Invalid multihash: `{0:?}`")] + InvalidMultihash(Multihash), + #[error("Failed to decode protobuf message: `{0:?}`")] + ProstDecodeError(prost::DecodeError), } #[derive(Debug, thiserror::Error)] pub enum SubstreamError { - #[error("Connection closed")] - ConnectionClosed, - #[error("yamux error: `{0}`")] - YamuxError(crate::yamux::ConnectionError), - #[error("Failed to read from substream, substream id `{0:?}`")] - ReadFailure(Option), - #[error("Failed to write to substream, substream id `{0:?}`")] - WriteFailure(Option), + #[error("Connection closed")] + ConnectionClosed, + #[error("yamux error: `{0}`")] + YamuxError(crate::yamux::ConnectionError), + #[error("Failed to read from substream, substream id `{0:?}`")] + ReadFailure(Option), + #[error("Failed to write to substream, substream id `{0:?}`")] + WriteFailure(Option), } #[derive(Debug, thiserror::Error)] pub enum NegotiationError { - #[error("multistream-select error: `{0:?}`")] - MultistreamSelectError(crate::multistream_select::NegotiationError), - #[error("multistream-select error: `{0:?}`")] - SnowError(snow::Error), - #[error("Connection closed while negotiating")] - ConnectionClosed, - #[error("`PeerId` missing from Noise handshake")] - PeerIdMissing, + #[error("multistream-select error: `{0:?}`")] + MultistreamSelectError(crate::multistream_select::NegotiationError), + #[error("multistream-select error: `{0:?}`")] + SnowError(snow::Error), + #[error("Connection closed while negotiating")] + ConnectionClosed, + #[error("`PeerId` missing from Noise handshake")] + PeerIdMissing, } #[derive(Debug, thiserror::Error)] pub enum NotificationError { - #[error("Peer already exists")] - PeerAlreadyExists, - #[error("Peer is in invalid state")] - InvalidState, - #[error("Notifications clogged")] - NotificationsClogged, - #[error("Notification stream closed")] - NotificationStreamClosed(PeerId), + #[error("Peer already exists")] + PeerAlreadyExists, + #[error("Peer is in invalid state")] + InvalidState, + #[error("Notifications clogged")] + NotificationsClogged, + #[error("Notification stream closed")] + NotificationStreamClosed(PeerId), } #[derive(Debug, thiserror::Error)] pub enum DialError { - #[error("Tried to dial self")] - TriedToDialSelf, - #[error("Already connected to peer")] - AlreadyConnected, - #[error("Peer doens't have any known addresses")] - NoAddressAvailable(PeerId), + #[error("Tried to dial self")] + TriedToDialSelf, + #[error("Already connected to peer")] + AlreadyConnected, + #[error("Peer doens't have any known addresses")] + NoAddressAvailable(PeerId), } impl From> for Error { - fn from(hash: MultihashGeneric<64>) -> Self { - Error::ParseError(ParseError::InvalidMultihash(hash)) - } + fn from(hash: MultihashGeneric<64>) -> Self { + Error::ParseError(ParseError::InvalidMultihash(hash)) + } } impl From for Error { - fn from(error: io::Error) -> Error { - Error::IoError(error.kind()) - } + fn from(error: io::Error) -> Error { + Error::IoError(error.kind()) + } } impl From for Error { - fn from(error: crate::multistream_select::NegotiationError) -> Error { - Error::NegotiationError(NegotiationError::MultistreamSelectError(error)) - } + fn from(error: crate::multistream_select::NegotiationError) -> Error { + Error::NegotiationError(NegotiationError::MultistreamSelectError(error)) + } } impl From for Error { - fn from(error: snow::Error) -> Self { - Error::NegotiationError(NegotiationError::SnowError(error)) - } + fn from(error: snow::Error) -> Self { + Error::NegotiationError(NegotiationError::SnowError(error)) + } } impl From> for Error { - fn from(_: tokio::sync::mpsc::error::SendError) -> Self { - Error::EssentialTaskClosed - } + fn from(_: tokio::sync::mpsc::error::SendError) -> Self { + Error::EssentialTaskClosed + } } impl From for Error { - fn from(_: tokio::sync::oneshot::error::RecvError) -> Self { - Error::EssentialTaskClosed - } + fn from(_: tokio::sync::oneshot::error::RecvError) -> Self { + Error::EssentialTaskClosed + } } impl From for Error { - fn from(error: prost::DecodeError) -> Self { - Error::ParseError(ParseError::ProstDecodeError(error)) - } + fn from(error: prost::DecodeError) -> Self { + Error::ParseError(ParseError::ProstDecodeError(error)) + } } impl From for Error { - fn from(error: quinn::ConnectionError) -> Self { - match error { - quinn::ConnectionError::TimedOut => Error::Timeout, - error => Error::Quinn(error), - } - } + fn from(error: quinn::ConnectionError) -> Self { + match error { + quinn::ConnectionError::TimedOut => Error::Timeout, + error => Error::Quinn(error), + } + } } #[cfg(test)] mod tests { - use super::*; - use tokio::sync::mpsc::{channel, Sender}; + use super::*; + use tokio::sync::mpsc::{channel, Sender}; - #[tokio::test] - async fn try_from_errors() { - tracing::trace!("{:?}", NotificationError::InvalidState); - tracing::trace!("{:?}", DialError::AlreadyConnected); - tracing::trace!("{:?}", SubstreamError::YamuxError(crate::yamux::ConnectionError::Closed)); - tracing::trace!("{:?}", AddressError::PeerIdMissing); - tracing::trace!("{:?}", ParseError::InvalidMultihash(Multihash::from(PeerId::random()))); + #[tokio::test] + async fn try_from_errors() { + tracing::trace!("{:?}", NotificationError::InvalidState); + tracing::trace!("{:?}", DialError::AlreadyConnected); + tracing::trace!( + "{:?}", + SubstreamError::YamuxError(crate::yamux::ConnectionError::Closed) + ); + tracing::trace!("{:?}", AddressError::PeerIdMissing); + tracing::trace!( + "{:?}", + ParseError::InvalidMultihash(Multihash::from(PeerId::random())) + ); - let (tx, rx) = channel(1); - drop(rx); + let (tx, rx) = channel(1); + drop(rx); - async fn test(tx: Sender<()>) -> crate::Result<()> { - tx.send(()).await.map_err(From::from) - } + async fn test(tx: Sender<()>) -> crate::Result<()> { + tx.send(()).await.map_err(From::from) + } - match test(tx).await.unwrap_err() { - Error::EssentialTaskClosed => {}, - _ => panic!("invalid error"), - } - } + match test(tx).await.unwrap_err() { + Error::EssentialTaskClosed => {} + _ => panic!("invalid error"), + } + } } diff --git a/src/executor.rs b/src/executor.rs index 56d0aa50..d5d7bdd8 100644 --- a/src/executor.rs +++ b/src/executor.rs @@ -24,49 +24,49 @@ use std::{future::Future, pin::Pin}; /// Trait which defines the interface the executor must implement. pub trait Executor: Send + Sync { - /// Start executing a future in the background. - fn run(&self, future: Pin + Send>>); + /// Start executing a future in the background. + fn run(&self, future: Pin + Send>>); - /// Start executing a future in the background and give the future a name; - fn run_with_name(&self, name: &'static str, future: Pin + Send>>); + /// Start executing a future in the background and give the future a name; + fn run_with_name(&self, name: &'static str, future: Pin + Send>>); } /// Default executor, defaults to calling `tokio::spawn()`. pub(crate) struct DefaultExecutor; impl Executor for DefaultExecutor { - fn run(&self, future: Pin + Send>>) { - let _ = tokio::spawn(future); - } + fn run(&self, future: Pin + Send>>) { + let _ = tokio::spawn(future); + } - fn run_with_name(&self, _: &'static str, future: Pin + Send>>) { - let _ = tokio::spawn(future); - } + fn run_with_name(&self, _: &'static str, future: Pin + Send>>) { + let _ = tokio::spawn(future); + } } #[cfg(test)] mod tests { - use super::*; - use tokio::sync::mpsc::channel; + use super::*; + use tokio::sync::mpsc::channel; - #[tokio::test] - async fn run_with_name() { - let executor = DefaultExecutor; - let (tx, mut rx) = channel(1); + #[tokio::test] + async fn run_with_name() { + let executor = DefaultExecutor; + let (tx, mut rx) = channel(1); - let sender = tx.clone(); - executor.run(Box::pin(async move { - sender.send(1337usize).await.unwrap(); - })); + let sender = tx.clone(); + executor.run(Box::pin(async move { + sender.send(1337usize).await.unwrap(); + })); - executor.run_with_name( - "test", - Box::pin(async move { - tx.send(1337usize).await.unwrap(); - }), - ); + executor.run_with_name( + "test", + Box::pin(async move { + tx.send(1337usize).await.unwrap(); + }), + ); - assert_eq!(rx.recv().await.unwrap(), 1337usize); - assert_eq!(rx.recv().await.unwrap(), 1337usize); - } + assert_eq!(rx.recv().await.unwrap(), 1337usize); + assert_eq!(rx.recv().await.unwrap(), 1337usize); + } } diff --git a/src/lib.rs b/src/lib.rs index 7106ece9..a5cc9e47 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -19,21 +19,21 @@ // DEALINGS IN THE SOFTWARE. use crate::{ - config::Litep2pConfig, - protocol::{ - libp2p::{bitswap::Bitswap, identify::Identify, kademlia::Kademlia, ping::Ping}, - mdns::Mdns, - notification::NotificationProtocol, - request_response::RequestResponseProtocol, - }, - transport::{ - manager::{SupportedTransport, TransportManager}, - quic::QuicTransport, - tcp::TcpTransport, - webrtc::WebRtcTransport, - websocket::WebSocketTransport, - TransportBuilder, TransportEvent, - }, + config::Litep2pConfig, + protocol::{ + libp2p::{bitswap::Bitswap, identify::Identify, kademlia::Kademlia, ping::Ping}, + mdns::Mdns, + notification::NotificationProtocol, + request_response::RequestResponseProtocol, + }, + transport::{ + manager::{SupportedTransport, TransportManager}, + quic::QuicTransport, + tcp::TcpTransport, + webrtc::WebRtcTransport, + websocket::WebSocketTransport, + TransportBuilder, TransportEvent, + }, }; use multiaddr::{Multiaddr, Protocol}; @@ -79,523 +79,536 @@ const DEFAULT_CHANNEL_SIZE: usize = 4096usize; /// Litep2p events. #[derive(Debug)] pub enum Litep2pEvent { - /// Connection established to peer. - ConnectionEstablished { - /// Remote peer ID. - peer: PeerId, - - /// Endpoint. - endpoint: Endpoint, - }, - - /// Connection closed to remote peer. - ConnectionClosed { - /// Peer ID. - peer: PeerId, - - /// Connection ID. - connection_id: ConnectionId, - }, - - /// Failed to dial peer. - DialFailure { - /// Address of the peer. - address: Multiaddr, - - /// Dial error. - error: Error, - }, + /// Connection established to peer. + ConnectionEstablished { + /// Remote peer ID. + peer: PeerId, + + /// Endpoint. + endpoint: Endpoint, + }, + + /// Connection closed to remote peer. + ConnectionClosed { + /// Peer ID. + peer: PeerId, + + /// Connection ID. + connection_id: ConnectionId, + }, + + /// Failed to dial peer. + DialFailure { + /// Address of the peer. + address: Multiaddr, + + /// Dial error. + error: Error, + }, } /// [`Litep2p`] object. pub struct Litep2p { - /// Local peer ID. - local_peer_id: PeerId, + /// Local peer ID. + local_peer_id: PeerId, - /// Listen addresses. - listen_addresses: Vec, + /// Listen addresses. + listen_addresses: Vec, - /// Transport manager. - transport_manager: TransportManager, + /// Transport manager. + transport_manager: TransportManager, - /// Bandwidth sink. - bandwidth_sink: BandwidthSink, + /// Bandwidth sink. + bandwidth_sink: BandwidthSink, } impl Litep2p { - /// Create new [`Litep2p`]. - pub fn new(mut litep2p_config: Litep2pConfig) -> crate::Result { - let local_peer_id = PeerId::from_public_key(&litep2p_config.keypair.public().into()); - let bandwidth_sink = BandwidthSink::new(); - let mut listen_addresses = vec![]; - - let supported_transports = Self::supported_transports(&litep2p_config); - let (mut transport_manager, transport_handle) = TransportManager::new( - litep2p_config.keypair.clone(), - supported_transports, - bandwidth_sink.clone(), - litep2p_config.max_parallel_dials, - ); - - // add known addresses to `TransportManager`, if any exist - if !litep2p_config.known_addresses.is_empty() { - for (peer, addresses) in litep2p_config.known_addresses { - transport_manager.add_known_address(peer, addresses.iter().cloned()); - } - } - - // start notification protocol event loops - for (protocol, config) in litep2p_config.notification_protocols.into_iter() { - tracing::debug!( - target: LOG_TARGET, - ?protocol, - "enable notification protocol", - ); - - let service = transport_manager.register_protocol( - protocol, - config.fallback_names.clone(), - config.codec, - ); - let executor = Arc::clone(&litep2p_config.executor); - litep2p_config.executor.run(Box::pin(async move { - NotificationProtocol::new(service, config, executor).run().await - })); - } - - // start request-response protocol event loops - for (protocol, config) in litep2p_config.request_response_protocols.into_iter() { - tracing::debug!( - target: LOG_TARGET, - ?protocol, - "enable request-response protocol", - ); - - let service = transport_manager.register_protocol( - protocol, - config.fallback_names.clone(), - config.codec, - ); - litep2p_config.executor.run(Box::pin(async move { - RequestResponseProtocol::new(service, config).run().await - })); - } - - // start user protocol event loops - for (protocol_name, protocol) in litep2p_config.user_protocols.into_iter() { - tracing::debug!(target: LOG_TARGET, protocol = ?protocol_name, "enable user protocol"); - - let service = - transport_manager.register_protocol(protocol_name, Vec::new(), protocol.codec()); - litep2p_config.executor.run(Box::pin(async move { - let _ = protocol.run(service).await; - })); - } - - // start ping protocol event loop if enabled - if let Some(ping_config) = litep2p_config.ping.take() { - tracing::debug!( - target: LOG_TARGET, - protocol = ?ping_config.protocol, - "enable ipfs ping protocol", - ); - - let service = transport_manager.register_protocol( - ping_config.protocol.clone(), - Vec::new(), - ping_config.codec, - ); - litep2p_config - .executor - .run(Box::pin(async move { Ping::new(service, ping_config).run().await })); - } - - // start kademlia protocol event loop if enabled - if let Some(kademlia_config) = litep2p_config.kademlia.take() { - tracing::debug!( - target: LOG_TARGET, - protocol_names = ?kademlia_config.protocol_names, - "enable ipfs kademlia protocol", - ); - - let main_protocol = - kademlia_config.protocol_names.get(0).expect("protocol name to exist"); - let fallback_names = kademlia_config.protocol_names.iter().skip(1).cloned().collect(); - - let service = transport_manager.register_protocol( - main_protocol.clone(), - fallback_names, - kademlia_config.codec, - ); - litep2p_config.executor.run(Box::pin(async move { - let _ = Kademlia::new(service, kademlia_config).run().await; - })); - } - - // start identify protocol event loop if enabled - let mut identify_info = match litep2p_config.identify.take() { - None => None, - Some(mut identify_config) => { - tracing::debug!( - target: LOG_TARGET, - protocol = ?identify_config.protocol, - "enable ipfs identify protocol", - ); - - let service = transport_manager.register_protocol( - identify_config.protocol.clone(), - Vec::new(), - identify_config.codec.clone(), - ); - identify_config.public = Some(litep2p_config.keypair.public().into()); - - Some((service, identify_config)) - }, - }; - - // start bitswap protocol event loop if enabled - if let Some(bitswap_config) = litep2p_config.bitswap.take() { - tracing::debug!( - target: LOG_TARGET, - protocol = ?bitswap_config.protocol, - "enable ipfs bitswap protocol", - ); - - let service = transport_manager.register_protocol( - bitswap_config.protocol.clone(), - Vec::new(), - bitswap_config.codec, - ); - litep2p_config - .executor - .run(Box::pin(async move { Bitswap::new(service, bitswap_config).run().await })); - } - - // enable tcp transport if the config exists - if let Some(config) = litep2p_config.tcp.take() { - let handle = transport_manager.transport_handle(Arc::clone(&litep2p_config.executor)); - let (transport, transport_listen_addresses) = - ::new(handle, config)?; - - for address in transport_listen_addresses { - transport_manager.register_listen_address(address.clone()); - listen_addresses.push(address.with(Protocol::P2p( - Multihash::from_bytes(&local_peer_id.to_bytes()).unwrap(), - ))); - } - - transport_manager.register_transport(SupportedTransport::Tcp, Box::new(transport)); - } - - // enable quic transport if the config exists - if let Some(config) = litep2p_config.quic.take() { - let handle = transport_manager.transport_handle(Arc::clone(&litep2p_config.executor)); - let (transport, transport_listen_addresses) = - ::new(handle, config)?; - - for address in transport_listen_addresses { - transport_manager.register_listen_address(address.clone()); - listen_addresses.push(address.with(Protocol::P2p( - Multihash::from_bytes(&local_peer_id.to_bytes()).unwrap(), - ))); - } - - transport_manager.register_transport(SupportedTransport::Quic, Box::new(transport)); - } - - // enable webrtc transport if the config exists - if let Some(config) = litep2p_config.webrtc.take() { - let handle = transport_manager.transport_handle(Arc::clone(&litep2p_config.executor)); - let (transport, transport_listen_addresses) = - ::new(handle, config)?; - - for address in transport_listen_addresses { - transport_manager.register_listen_address(address.clone()); - listen_addresses.push(address.with(Protocol::P2p( - Multihash::from_bytes(&local_peer_id.to_bytes()).unwrap(), - ))); - } - - transport_manager.register_transport(SupportedTransport::WebRtc, Box::new(transport)); - } - - // enable websocket transport if the config exists - if let Some(config) = litep2p_config.websocket.take() { - let handle = transport_manager.transport_handle(Arc::clone(&litep2p_config.executor)); - let (transport, transport_listen_addresses) = - ::new(handle, config)?; - - for address in transport_listen_addresses { - transport_manager.register_listen_address(address.clone()); - listen_addresses.push(address.with(Protocol::P2p( - Multihash::from_bytes(&local_peer_id.to_bytes()).unwrap(), - ))); - } - - transport_manager - .register_transport(SupportedTransport::WebSocket, Box::new(transport)); - } - - // enable mdns if the config exists - if let Some(config) = litep2p_config.mdns.take() { - let mdns = Mdns::new(transport_handle, config, listen_addresses.clone())?; - - litep2p_config.executor.run(Box::pin(async move { - let _ = mdns.start().await; - })); - } - - // if identify was enabled, give it the enabled protocols and listen addresses and start it - if let Some((service, mut identify_config)) = identify_info.take() { - identify_config.protocols = transport_manager.protocols().cloned().collect(); - let identify = Identify::new(service, identify_config, listen_addresses.clone()); - - litep2p_config.executor.run(Box::pin(async move { - let _ = identify.run().await; - })); - } - - if transport_manager.installed_transports().count() == 0 { - return Err(Error::Other("No transport specified".to_string())); - } - - // verify that at least one transport is specified - if listen_addresses.is_empty() { - tracing::warn!( - target: LOG_TARGET, - "litep2p started with no listen addresses, cannot accept inbound connections", - ); - } - - Ok(Self { local_peer_id, bandwidth_sink, listen_addresses, transport_manager }) - } - - /// Collect supported transports before initializing the transports themselves. - /// - /// Information of the supported transports is needed to initialize protocols but - /// information about protocols must be known to initialize transports so the initialization - /// has to be split. - fn supported_transports(config: &Litep2pConfig) -> HashSet { - let mut supported_transports = HashSet::new(); - - config - .tcp - .is_some() - .then(|| supported_transports.insert(SupportedTransport::Tcp)); - config - .quic - .is_some() - .then(|| supported_transports.insert(SupportedTransport::Quic)); - config - .websocket - .is_some() - .then(|| supported_transports.insert(SupportedTransport::WebSocket)); - config - .webrtc - .is_some() - .then(|| supported_transports.insert(SupportedTransport::WebRtc)); - - supported_transports - } - - /// Get local peer ID. - pub fn local_peer_id(&self) -> &PeerId { - &self.local_peer_id - } - - /// Get listen address of litep2p. - pub fn listen_addresses(&self) -> impl Iterator { - self.listen_addresses.iter() - } - - /// Get handle to bandwidth sink. - pub fn bandwidth_sink(&self) -> BandwidthSink { - self.bandwidth_sink.clone() - } - - /// Dial peer. - pub async fn dial(&mut self, peer: &PeerId) -> crate::Result<()> { - self.transport_manager.dial(*peer).await - } - - /// Dial address. - pub async fn dial_address(&mut self, address: Multiaddr) -> crate::Result<()> { - self.transport_manager.dial_address(address).await - } - - /// Add one ore more known addresses for peer. - /// - /// Return value denotes how many addresses were added for the peer. - // Addresses belonging to disabled/unsupported transports will be ignored. - pub fn add_known_address( - &mut self, - peer: PeerId, - address: impl Iterator, - ) -> usize { - self.transport_manager.add_known_address(peer, address) - } - - /// Poll next event. - /// - /// This function must be called in order for litep2p to make progress. - pub async fn next_event(&mut self) -> Option { - loop { - match self.transport_manager.next().await? { - TransportEvent::ConnectionEstablished { peer, endpoint, .. } => - return Some(Litep2pEvent::ConnectionEstablished { peer, endpoint }), - TransportEvent::ConnectionClosed { peer, connection_id } => - return Some(Litep2pEvent::ConnectionClosed { peer, connection_id }), - TransportEvent::DialFailure { address, error, .. } => - return Some(Litep2pEvent::DialFailure { address, error }), - _ => {}, - } - } - } + /// Create new [`Litep2p`]. + pub fn new(mut litep2p_config: Litep2pConfig) -> crate::Result { + let local_peer_id = PeerId::from_public_key(&litep2p_config.keypair.public().into()); + let bandwidth_sink = BandwidthSink::new(); + let mut listen_addresses = vec![]; + + let supported_transports = Self::supported_transports(&litep2p_config); + let (mut transport_manager, transport_handle) = TransportManager::new( + litep2p_config.keypair.clone(), + supported_transports, + bandwidth_sink.clone(), + litep2p_config.max_parallel_dials, + ); + + // add known addresses to `TransportManager`, if any exist + if !litep2p_config.known_addresses.is_empty() { + for (peer, addresses) in litep2p_config.known_addresses { + transport_manager.add_known_address(peer, addresses.iter().cloned()); + } + } + + // start notification protocol event loops + for (protocol, config) in litep2p_config.notification_protocols.into_iter() { + tracing::debug!( + target: LOG_TARGET, + ?protocol, + "enable notification protocol", + ); + + let service = transport_manager.register_protocol( + protocol, + config.fallback_names.clone(), + config.codec, + ); + let executor = Arc::clone(&litep2p_config.executor); + litep2p_config.executor.run(Box::pin(async move { + NotificationProtocol::new(service, config, executor).run().await + })); + } + + // start request-response protocol event loops + for (protocol, config) in litep2p_config.request_response_protocols.into_iter() { + tracing::debug!( + target: LOG_TARGET, + ?protocol, + "enable request-response protocol", + ); + + let service = transport_manager.register_protocol( + protocol, + config.fallback_names.clone(), + config.codec, + ); + litep2p_config.executor.run(Box::pin(async move { + RequestResponseProtocol::new(service, config).run().await + })); + } + + // start user protocol event loops + for (protocol_name, protocol) in litep2p_config.user_protocols.into_iter() { + tracing::debug!(target: LOG_TARGET, protocol = ?protocol_name, "enable user protocol"); + + let service = + transport_manager.register_protocol(protocol_name, Vec::new(), protocol.codec()); + litep2p_config.executor.run(Box::pin(async move { + let _ = protocol.run(service).await; + })); + } + + // start ping protocol event loop if enabled + if let Some(ping_config) = litep2p_config.ping.take() { + tracing::debug!( + target: LOG_TARGET, + protocol = ?ping_config.protocol, + "enable ipfs ping protocol", + ); + + let service = transport_manager.register_protocol( + ping_config.protocol.clone(), + Vec::new(), + ping_config.codec, + ); + litep2p_config.executor.run(Box::pin(async move { + Ping::new(service, ping_config).run().await + })); + } + + // start kademlia protocol event loop if enabled + if let Some(kademlia_config) = litep2p_config.kademlia.take() { + tracing::debug!( + target: LOG_TARGET, + protocol_names = ?kademlia_config.protocol_names, + "enable ipfs kademlia protocol", + ); + + let main_protocol = + kademlia_config.protocol_names.get(0).expect("protocol name to exist"); + let fallback_names = kademlia_config.protocol_names.iter().skip(1).cloned().collect(); + + let service = transport_manager.register_protocol( + main_protocol.clone(), + fallback_names, + kademlia_config.codec, + ); + litep2p_config.executor.run(Box::pin(async move { + let _ = Kademlia::new(service, kademlia_config).run().await; + })); + } + + // start identify protocol event loop if enabled + let mut identify_info = match litep2p_config.identify.take() { + None => None, + Some(mut identify_config) => { + tracing::debug!( + target: LOG_TARGET, + protocol = ?identify_config.protocol, + "enable ipfs identify protocol", + ); + + let service = transport_manager.register_protocol( + identify_config.protocol.clone(), + Vec::new(), + identify_config.codec.clone(), + ); + identify_config.public = Some(litep2p_config.keypair.public().into()); + + Some((service, identify_config)) + } + }; + + // start bitswap protocol event loop if enabled + if let Some(bitswap_config) = litep2p_config.bitswap.take() { + tracing::debug!( + target: LOG_TARGET, + protocol = ?bitswap_config.protocol, + "enable ipfs bitswap protocol", + ); + + let service = transport_manager.register_protocol( + bitswap_config.protocol.clone(), + Vec::new(), + bitswap_config.codec, + ); + litep2p_config.executor.run(Box::pin(async move { + Bitswap::new(service, bitswap_config).run().await + })); + } + + // enable tcp transport if the config exists + if let Some(config) = litep2p_config.tcp.take() { + let handle = transport_manager.transport_handle(Arc::clone(&litep2p_config.executor)); + let (transport, transport_listen_addresses) = + ::new(handle, config)?; + + for address in transport_listen_addresses { + transport_manager.register_listen_address(address.clone()); + listen_addresses.push(address.with(Protocol::P2p( + Multihash::from_bytes(&local_peer_id.to_bytes()).unwrap(), + ))); + } + + transport_manager.register_transport(SupportedTransport::Tcp, Box::new(transport)); + } + + // enable quic transport if the config exists + if let Some(config) = litep2p_config.quic.take() { + let handle = transport_manager.transport_handle(Arc::clone(&litep2p_config.executor)); + let (transport, transport_listen_addresses) = + ::new(handle, config)?; + + for address in transport_listen_addresses { + transport_manager.register_listen_address(address.clone()); + listen_addresses.push(address.with(Protocol::P2p( + Multihash::from_bytes(&local_peer_id.to_bytes()).unwrap(), + ))); + } + + transport_manager.register_transport(SupportedTransport::Quic, Box::new(transport)); + } + + // enable webrtc transport if the config exists + if let Some(config) = litep2p_config.webrtc.take() { + let handle = transport_manager.transport_handle(Arc::clone(&litep2p_config.executor)); + let (transport, transport_listen_addresses) = + ::new(handle, config)?; + + for address in transport_listen_addresses { + transport_manager.register_listen_address(address.clone()); + listen_addresses.push(address.with(Protocol::P2p( + Multihash::from_bytes(&local_peer_id.to_bytes()).unwrap(), + ))); + } + + transport_manager.register_transport(SupportedTransport::WebRtc, Box::new(transport)); + } + + // enable websocket transport if the config exists + if let Some(config) = litep2p_config.websocket.take() { + let handle = transport_manager.transport_handle(Arc::clone(&litep2p_config.executor)); + let (transport, transport_listen_addresses) = + ::new(handle, config)?; + + for address in transport_listen_addresses { + transport_manager.register_listen_address(address.clone()); + listen_addresses.push(address.with(Protocol::P2p( + Multihash::from_bytes(&local_peer_id.to_bytes()).unwrap(), + ))); + } + + transport_manager + .register_transport(SupportedTransport::WebSocket, Box::new(transport)); + } + + // enable mdns if the config exists + if let Some(config) = litep2p_config.mdns.take() { + let mdns = Mdns::new(transport_handle, config, listen_addresses.clone())?; + + litep2p_config.executor.run(Box::pin(async move { + let _ = mdns.start().await; + })); + } + + // if identify was enabled, give it the enabled protocols and listen addresses and start it + if let Some((service, mut identify_config)) = identify_info.take() { + identify_config.protocols = transport_manager.protocols().cloned().collect(); + let identify = Identify::new(service, identify_config, listen_addresses.clone()); + + litep2p_config.executor.run(Box::pin(async move { + let _ = identify.run().await; + })); + } + + if transport_manager.installed_transports().count() == 0 { + return Err(Error::Other("No transport specified".to_string())); + } + + // verify that at least one transport is specified + if listen_addresses.is_empty() { + tracing::warn!( + target: LOG_TARGET, + "litep2p started with no listen addresses, cannot accept inbound connections", + ); + } + + Ok(Self { + local_peer_id, + bandwidth_sink, + listen_addresses, + transport_manager, + }) + } + + /// Collect supported transports before initializing the transports themselves. + /// + /// Information of the supported transports is needed to initialize protocols but + /// information about protocols must be known to initialize transports so the initialization + /// has to be split. + fn supported_transports(config: &Litep2pConfig) -> HashSet { + let mut supported_transports = HashSet::new(); + + config + .tcp + .is_some() + .then(|| supported_transports.insert(SupportedTransport::Tcp)); + config + .quic + .is_some() + .then(|| supported_transports.insert(SupportedTransport::Quic)); + config + .websocket + .is_some() + .then(|| supported_transports.insert(SupportedTransport::WebSocket)); + config + .webrtc + .is_some() + .then(|| supported_transports.insert(SupportedTransport::WebRtc)); + + supported_transports + } + + /// Get local peer ID. + pub fn local_peer_id(&self) -> &PeerId { + &self.local_peer_id + } + + /// Get listen address of litep2p. + pub fn listen_addresses(&self) -> impl Iterator { + self.listen_addresses.iter() + } + + /// Get handle to bandwidth sink. + pub fn bandwidth_sink(&self) -> BandwidthSink { + self.bandwidth_sink.clone() + } + + /// Dial peer. + pub async fn dial(&mut self, peer: &PeerId) -> crate::Result<()> { + self.transport_manager.dial(*peer).await + } + + /// Dial address. + pub async fn dial_address(&mut self, address: Multiaddr) -> crate::Result<()> { + self.transport_manager.dial_address(address).await + } + + /// Add one ore more known addresses for peer. + /// + /// Return value denotes how many addresses were added for the peer. + // Addresses belonging to disabled/unsupported transports will be ignored. + pub fn add_known_address( + &mut self, + peer: PeerId, + address: impl Iterator, + ) -> usize { + self.transport_manager.add_known_address(peer, address) + } + + /// Poll next event. + /// + /// This function must be called in order for litep2p to make progress. + pub async fn next_event(&mut self) -> Option { + loop { + match self.transport_manager.next().await? { + TransportEvent::ConnectionEstablished { peer, endpoint, .. } => + return Some(Litep2pEvent::ConnectionEstablished { peer, endpoint }), + TransportEvent::ConnectionClosed { + peer, + connection_id, + } => + return Some(Litep2pEvent::ConnectionClosed { + peer, + connection_id, + }), + TransportEvent::DialFailure { address, error, .. } => + return Some(Litep2pEvent::DialFailure { address, error }), + _ => {} + } + } + } } #[cfg(test)] mod tests { - use crate::{ - config::ConfigBuilder, - protocol::{libp2p::ping, notification::Config as NotificationConfig}, - types::protocol::ProtocolName, - Litep2p, Litep2pEvent, PeerId, - }; - use multiaddr::{Multiaddr, Protocol}; - use multihash::Multihash; - use std::net::Ipv4Addr; - - #[tokio::test] - async fn initialize_litep2p() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (config1, _service1) = NotificationConfig::new( - ProtocolName::from("/notificaton/1"), - 1337usize, - vec![1, 2, 3, 4], - Vec::new(), - false, - 64, - 64, - true, - ); - let (config2, _service2) = NotificationConfig::new( - ProtocolName::from("/notificaton/2"), - 1337usize, - vec![1, 2, 3, 4], - Vec::new(), - false, - 64, - 64, - true, - ); - let (ping_config, _ping_event_stream) = ping::Config::default(); - - let config = ConfigBuilder::new() - .with_tcp(Default::default()) - .with_quic(Default::default()) - .with_notification_protocol(config1) - .with_notification_protocol(config2) - .with_libp2p_ping(ping_config) - .build(); - - let _litep2p = Litep2p::new(config).unwrap(); - } - - #[tokio::test] - async fn no_transport_given() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (config1, _service1) = NotificationConfig::new( - ProtocolName::from("/notificaton/1"), - 1337usize, - vec![1, 2, 3, 4], - Vec::new(), - false, - 64, - 64, - true, - ); - let (config2, _service2) = NotificationConfig::new( - ProtocolName::from("/notificaton/2"), - 1337usize, - vec![1, 2, 3, 4], - Vec::new(), - false, - 64, - 64, - true, - ); - let (ping_config, _ping_event_stream) = ping::Config::default(); - - let config = ConfigBuilder::new() - .with_notification_protocol(config1) - .with_notification_protocol(config2) - .with_libp2p_ping(ping_config) - .build(); - - assert!(Litep2p::new(config).is_err()); - } - - #[tokio::test] - async fn dial_same_address_twice() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (config1, _service1) = NotificationConfig::new( - ProtocolName::from("/notificaton/1"), - 1337usize, - vec![1, 2, 3, 4], - Vec::new(), - false, - 64, - 64, - true, - ); - let (config2, _service2) = NotificationConfig::new( - ProtocolName::from("/notificaton/2"), - 1337usize, - vec![1, 2, 3, 4], - Vec::new(), - false, - 64, - 64, - true, - ); - let (ping_config, _ping_event_stream) = ping::Config::default(); - - let config = ConfigBuilder::new() - .with_tcp(Default::default()) - .with_quic(Default::default()) - .with_notification_protocol(config1) - .with_notification_protocol(config2) - .with_libp2p_ping(ping_config) - .build(); - - let peer = PeerId::random(); - let address = Multiaddr::empty() - .with(Protocol::Ip4(Ipv4Addr::new(255, 254, 253, 252))) - .with(Protocol::Tcp(8888)) - .with(Protocol::P2p(Multihash::from_bytes(&peer.to_bytes()).unwrap())); - - let mut litep2p = Litep2p::new(config).unwrap(); - litep2p.dial_address(address.clone()).await.unwrap(); - litep2p.dial_address(address.clone()).await.unwrap(); - - match litep2p.next_event().await { - Some(Litep2pEvent::DialFailure { .. }) => {}, - _ => panic!("invalid event received"), - } - - // verify that the second same dial was ignored and the dial failure is reported only once - match tokio::time::timeout(std::time::Duration::from_secs(20), litep2p.next_event()).await { - Err(_) => {}, - _ => panic!("invalid event received"), - } - } + use crate::{ + config::ConfigBuilder, + protocol::{libp2p::ping, notification::Config as NotificationConfig}, + types::protocol::ProtocolName, + Litep2p, Litep2pEvent, PeerId, + }; + use multiaddr::{Multiaddr, Protocol}; + use multihash::Multihash; + use std::net::Ipv4Addr; + + #[tokio::test] + async fn initialize_litep2p() { + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (config1, _service1) = NotificationConfig::new( + ProtocolName::from("/notificaton/1"), + 1337usize, + vec![1, 2, 3, 4], + Vec::new(), + false, + 64, + 64, + true, + ); + let (config2, _service2) = NotificationConfig::new( + ProtocolName::from("/notificaton/2"), + 1337usize, + vec![1, 2, 3, 4], + Vec::new(), + false, + 64, + 64, + true, + ); + let (ping_config, _ping_event_stream) = ping::Config::default(); + + let config = ConfigBuilder::new() + .with_tcp(Default::default()) + .with_quic(Default::default()) + .with_notification_protocol(config1) + .with_notification_protocol(config2) + .with_libp2p_ping(ping_config) + .build(); + + let _litep2p = Litep2p::new(config).unwrap(); + } + + #[tokio::test] + async fn no_transport_given() { + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (config1, _service1) = NotificationConfig::new( + ProtocolName::from("/notificaton/1"), + 1337usize, + vec![1, 2, 3, 4], + Vec::new(), + false, + 64, + 64, + true, + ); + let (config2, _service2) = NotificationConfig::new( + ProtocolName::from("/notificaton/2"), + 1337usize, + vec![1, 2, 3, 4], + Vec::new(), + false, + 64, + 64, + true, + ); + let (ping_config, _ping_event_stream) = ping::Config::default(); + + let config = ConfigBuilder::new() + .with_notification_protocol(config1) + .with_notification_protocol(config2) + .with_libp2p_ping(ping_config) + .build(); + + assert!(Litep2p::new(config).is_err()); + } + + #[tokio::test] + async fn dial_same_address_twice() { + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (config1, _service1) = NotificationConfig::new( + ProtocolName::from("/notificaton/1"), + 1337usize, + vec![1, 2, 3, 4], + Vec::new(), + false, + 64, + 64, + true, + ); + let (config2, _service2) = NotificationConfig::new( + ProtocolName::from("/notificaton/2"), + 1337usize, + vec![1, 2, 3, 4], + Vec::new(), + false, + 64, + 64, + true, + ); + let (ping_config, _ping_event_stream) = ping::Config::default(); + + let config = ConfigBuilder::new() + .with_tcp(Default::default()) + .with_quic(Default::default()) + .with_notification_protocol(config1) + .with_notification_protocol(config2) + .with_libp2p_ping(ping_config) + .build(); + + let peer = PeerId::random(); + let address = Multiaddr::empty() + .with(Protocol::Ip4(Ipv4Addr::new(255, 254, 253, 252))) + .with(Protocol::Tcp(8888)) + .with(Protocol::P2p( + Multihash::from_bytes(&peer.to_bytes()).unwrap(), + )); + + let mut litep2p = Litep2p::new(config).unwrap(); + litep2p.dial_address(address.clone()).await.unwrap(); + litep2p.dial_address(address.clone()).await.unwrap(); + + match litep2p.next_event().await { + Some(Litep2pEvent::DialFailure { .. }) => {} + _ => panic!("invalid event received"), + } + + // verify that the second same dial was ignored and the dial failure is reported only once + match tokio::time::timeout(std::time::Duration::from_secs(20), litep2p.next_event()).await { + Err(_) => {} + _ => panic!("invalid event received"), + } + } } diff --git a/src/mock/substream.rs b/src/mock/substream.rs index 6263b19e..d206c434 100644 --- a/src/mock/substream.rs +++ b/src/mock/substream.rs @@ -24,62 +24,62 @@ use bytes::{Bytes, BytesMut}; use futures::{Sink, Stream}; use std::{ - fmt::Debug, - pin::Pin, - task::{Context, Poll}, + fmt::Debug, + pin::Pin, + task::{Context, Poll}, }; /// Trait which describes the behavior of a mock substream. pub trait Substream: - Debug + Stream> + Sink + Send + Unpin + 'static + Debug + Stream> + Sink + Send + Unpin + 'static { } /// Blanket implementation for [`Substream`]. impl< - T: Debug - + Stream> - + Sink - + Send - + Unpin - + 'static, - > Substream for T + T: Debug + + Stream> + + Sink + + Send + + Unpin + + 'static, + > Substream for T { } mockall::mock! { - #[derive(Debug)] - pub Substream {} - - impl Sink for Substream { - type Error = Error; - - fn poll_ready<'a>( - self: Pin<&mut Self>, - cx: &mut Context<'a> - ) -> Poll>; - - fn start_send(self: Pin<&mut Self>, item: bytes::Bytes) -> Result<(), Error>; - - fn poll_flush<'a>( - self: Pin<&mut Self>, - cx: &mut Context<'a> - ) -> Poll>; - - fn poll_close<'a>( - self: Pin<&mut Self>, - cx: &mut Context<'a> - ) -> Poll>; - } - - impl Stream for Substream { - type Item = crate::Result; - - fn poll_next<'a>( - self: Pin<&mut Self>, - cx: &mut Context<'a> - ) -> Poll>>; - } + #[derive(Debug)] + pub Substream {} + + impl Sink for Substream { + type Error = Error; + + fn poll_ready<'a>( + self: Pin<&mut Self>, + cx: &mut Context<'a> + ) -> Poll>; + + fn start_send(self: Pin<&mut Self>, item: bytes::Bytes) -> Result<(), Error>; + + fn poll_flush<'a>( + self: Pin<&mut Self>, + cx: &mut Context<'a> + ) -> Poll>; + + fn poll_close<'a>( + self: Pin<&mut Self>, + cx: &mut Context<'a> + ) -> Poll>; + } + + impl Stream for Substream { + type Item = crate::Result; + + fn poll_next<'a>( + self: Pin<&mut Self>, + cx: &mut Context<'a> + ) -> Poll>>; + } } /// Dummy substream which just implements `Stream + Sink` and returns `Poll::Pending`/`Ok(())` @@ -87,71 +87,71 @@ mockall::mock! { pub struct DummySubstream {} impl DummySubstream { - /// Create new [`DummySubstream`]. - #[cfg(test)] - pub fn new() -> Self { - Self {} - } + /// Create new [`DummySubstream`]. + #[cfg(test)] + pub fn new() -> Self { + Self {} + } } impl Sink for DummySubstream { - type Error = Error; + type Error = Error; - fn poll_ready<'a>(self: Pin<&mut Self>, _cx: &mut Context<'a>) -> Poll> { - Poll::Pending - } + fn poll_ready<'a>(self: Pin<&mut Self>, _cx: &mut Context<'a>) -> Poll> { + Poll::Pending + } - fn start_send(self: Pin<&mut Self>, _item: bytes::Bytes) -> Result<(), Error> { - Ok(()) - } + fn start_send(self: Pin<&mut Self>, _item: bytes::Bytes) -> Result<(), Error> { + Ok(()) + } - fn poll_flush<'a>(self: Pin<&mut Self>, _cx: &mut Context<'a>) -> Poll> { - Poll::Pending - } + fn poll_flush<'a>(self: Pin<&mut Self>, _cx: &mut Context<'a>) -> Poll> { + Poll::Pending + } - fn poll_close<'a>(self: Pin<&mut Self>, _cx: &mut Context<'a>) -> Poll> { - Poll::Ready(Ok(())) - } + fn poll_close<'a>(self: Pin<&mut Self>, _cx: &mut Context<'a>) -> Poll> { + Poll::Ready(Ok(())) + } } impl Stream for DummySubstream { - type Item = crate::Result; - - fn poll_next<'a>( - self: Pin<&mut Self>, - _cx: &mut Context<'a>, - ) -> Poll>> { - Poll::Pending - } + type Item = crate::Result; + + fn poll_next<'a>( + self: Pin<&mut Self>, + _cx: &mut Context<'a>, + ) -> Poll>> { + Poll::Pending + } } #[cfg(test)] mod tests { - use super::*; - use futures::SinkExt; - - #[tokio::test] - async fn dummy_substream_sink() { - let mut substream = DummySubstream::new(); - - futures::future::poll_fn(|cx| match substream.poll_ready_unpin(cx) { - Poll::Pending => Poll::Ready(()), - _ => panic!("invalid event"), - }) - .await; - - assert!(Pin::new(&mut substream).start_send(bytes::Bytes::new()).is_ok()); - - futures::future::poll_fn(|cx| match substream.poll_flush_unpin(cx) { - Poll::Pending => Poll::Ready(()), - _ => panic!("invalid event"), - }) - .await; - - futures::future::poll_fn(|cx| match substream.poll_close_unpin(cx) { - Poll::Ready(Ok(())) => Poll::Ready(()), - _ => panic!("invalid event"), - }) - .await; - } + use super::*; + use futures::SinkExt; + + #[tokio::test] + async fn dummy_substream_sink() { + let mut substream = DummySubstream::new(); + + futures::future::poll_fn(|cx| match substream.poll_ready_unpin(cx) { + Poll::Pending => Poll::Ready(()), + _ => panic!("invalid event"), + }) + .await; + + assert!(Pin::new(&mut substream).start_send(bytes::Bytes::new()).is_ok()); + + futures::future::poll_fn(|cx| match substream.poll_flush_unpin(cx) { + Poll::Pending => Poll::Ready(()), + _ => panic!("invalid event"), + }) + .await; + + futures::future::poll_fn(|cx| match substream.poll_close_unpin(cx) { + Poll::Ready(Ok(())) => Poll::Ready(()), + _ => panic!("invalid event"), + }) + .await; + } } diff --git a/src/multistream_select/dialer_select.rs b/src/multistream_select/dialer_select.rs index f0a4baf7..2a8a025e 100644 --- a/src/multistream_select/dialer_select.rs +++ b/src/multistream_select/dialer_select.rs @@ -21,23 +21,23 @@ //! Protocol negotiation strategies for the peer acting as the dialer. use crate::{ - codec::unsigned_varint::UnsignedVarint, - error::{self, Error}, - multistream_select::{ - protocol::{HeaderLine, Message, MessageIO, Protocol, ProtocolError}, - Negotiated, NegotiationError, Version, - }, - types::protocol::ProtocolName, + codec::unsigned_varint::UnsignedVarint, + error::{self, Error}, + multistream_select::{ + protocol::{HeaderLine, Message, MessageIO, Protocol, ProtocolError}, + Negotiated, NegotiationError, Version, + }, + types::protocol::ProtocolName, }; use bytes::BytesMut; use futures::prelude::*; use rustls::internal::msgs::hsjoiner::HandshakeJoiner; use std::{ - convert::TryFrom as _, - iter, mem, - pin::Pin, - task::{Context, Poll}, + convert::TryFrom as _, + iter, mem, + pin::Pin, + task::{Context, Poll}, }; const LOG_TARGET: &str = "litep2p::multistream-select"; @@ -56,284 +56,293 @@ const LOG_TARGET: &str = "litep2p::multistream-select"; /// protocol upgrades may thus proceed by deployments with updated listeners, /// eventually followed by deployments of dialers choosing the newer protocol. pub fn dialer_select_proto( - inner: R, - protocols: I, - version: Version, + inner: R, + protocols: I, + version: Version, ) -> DialerSelectFuture where - R: AsyncRead + AsyncWrite, - I: IntoIterator, - I::Item: AsRef<[u8]>, + R: AsyncRead + AsyncWrite, + I: IntoIterator, + I::Item: AsRef<[u8]>, { - let protocols = protocols.into_iter().peekable(); - DialerSelectFuture { - version, - protocols, - state: State::SendHeader { io: MessageIO::new(inner) }, - } + let protocols = protocols.into_iter().peekable(); + DialerSelectFuture { + version, + protocols, + state: State::SendHeader { + io: MessageIO::new(inner), + }, + } } /// A `Future` returned by [`dialer_select_proto`] which negotiates /// a protocol iteratively by considering one protocol after the other. #[pin_project::pin_project] pub struct DialerSelectFuture { - // TODO: It would be nice if eventually N = I::Item = Protocol. - protocols: iter::Peekable, - state: State, - version: Version, + // TODO: It would be nice if eventually N = I::Item = Protocol. + protocols: iter::Peekable, + state: State, + version: Version, } enum State { - SendHeader { io: MessageIO }, - SendProtocol { io: MessageIO, protocol: N }, - FlushProtocol { io: MessageIO, protocol: N }, - AwaitProtocol { io: MessageIO, protocol: N }, - Done, + SendHeader { io: MessageIO }, + SendProtocol { io: MessageIO, protocol: N }, + FlushProtocol { io: MessageIO, protocol: N }, + AwaitProtocol { io: MessageIO, protocol: N }, + Done, } impl Future for DialerSelectFuture where - // The Unpin bound here is required because we produce - // a `Negotiated` as the output. It also makes - // the implementation considerably easier to write. - R: AsyncRead + AsyncWrite + Unpin, - I: Iterator, - I::Item: AsRef<[u8]>, + // The Unpin bound here is required because we produce + // a `Negotiated` as the output. It also makes + // the implementation considerably easier to write. + R: AsyncRead + AsyncWrite + Unpin, + I: Iterator, + I::Item: AsRef<[u8]>, { - type Output = Result<(I::Item, Negotiated), NegotiationError>; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.project(); - - loop { - match mem::replace(this.state, State::Done) { - State::SendHeader { mut io } => { - match Pin::new(&mut io).poll_ready(cx)? { - Poll::Ready(()) => {}, - Poll::Pending => { - *this.state = State::SendHeader { io }; - return Poll::Pending; - }, - } - - let h = HeaderLine::from(*this.version); - if let Err(err) = Pin::new(&mut io).start_send(Message::Header(h)) { - return Poll::Ready(Err(From::from(err))); - } - - let protocol = this.protocols.next().ok_or(NegotiationError::Failed)?; - - // The dialer always sends the header and the first protocol - // proposal in one go for efficiency. - *this.state = State::SendProtocol { io, protocol }; - }, - - State::SendProtocol { mut io, protocol } => { - match Pin::new(&mut io).poll_ready(cx)? { - Poll::Ready(()) => {}, - Poll::Pending => { - *this.state = State::SendProtocol { io, protocol }; - return Poll::Pending; - }, - } - - let p = Protocol::try_from(protocol.as_ref())?; - if let Err(err) = Pin::new(&mut io).start_send(Message::Protocol(p.clone())) { - return Poll::Ready(Err(From::from(err))); - } - tracing::debug!(target: LOG_TARGET, "Dialer: Proposed protocol: {}", p); - - if this.protocols.peek().is_some() { - *this.state = State::FlushProtocol { io, protocol } - } else { - match this.version { - Version::V1 => *this.state = State::FlushProtocol { io, protocol }, - // This is the only effect that `V1Lazy` has compared to `V1`: - // Optimistically settling on the only protocol that - // the dialer supports for this negotiation. Notably, - // the dialer expects a regular `V1` response. - Version::V1Lazy => { - tracing::debug!( - target: LOG_TARGET, - "Dialer: Expecting proposed protocol: {}", - p - ); - let hl = HeaderLine::from(Version::V1Lazy); - let io = Negotiated::expecting(io.into_reader(), p, Some(hl)); - return Poll::Ready(Ok((protocol, io))); - }, - } - } - }, - - State::FlushProtocol { mut io, protocol } => { - match Pin::new(&mut io).poll_flush(cx)? { - Poll::Ready(()) => *this.state = State::AwaitProtocol { io, protocol }, - Poll::Pending => { - *this.state = State::FlushProtocol { io, protocol }; - return Poll::Pending; - }, - } - }, - - State::AwaitProtocol { mut io, protocol } => { - let msg = match Pin::new(&mut io).poll_next(cx)? { - Poll::Ready(Some(msg)) => msg, - Poll::Pending => { - *this.state = State::AwaitProtocol { io, protocol }; - return Poll::Pending; - }, - // Treat EOF error as [`NegotiationError::Failed`], not as - // [`NegotiationError::ProtocolError`], allowing dropping or closing an I/O - // stream as a permissible way to "gracefully" fail a negotiation. - Poll::Ready(None) => return Poll::Ready(Err(NegotiationError::Failed)), - }; - - match msg { - Message::Header(v) if v == HeaderLine::from(*this.version) => { - *this.state = State::AwaitProtocol { io, protocol }; - }, - Message::Protocol(ref p) if p.as_ref() == protocol.as_ref() => { - tracing::debug!( - target: LOG_TARGET, - "Dialer: Received confirmation for protocol: {}", - p - ); - let io = Negotiated::completed(io.into_inner()); - return Poll::Ready(Ok((protocol, io))); - }, - Message::NotAvailable => { - tracing::debug!( - target: LOG_TARGET, - "Dialer: Received rejection of protocol: {}", - String::from_utf8_lossy(protocol.as_ref()) - ); - let protocol = this.protocols.next().ok_or(NegotiationError::Failed)?; - *this.state = State::SendProtocol { io, protocol } - }, - _ => return Poll::Ready(Err(ProtocolError::InvalidMessage.into())), - } - }, - - State::Done => panic!("State::poll called after completion"), - } - } - } + type Output = Result<(I::Item, Negotiated), NegotiationError>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + + loop { + match mem::replace(this.state, State::Done) { + State::SendHeader { mut io } => { + match Pin::new(&mut io).poll_ready(cx)? { + Poll::Ready(()) => {} + Poll::Pending => { + *this.state = State::SendHeader { io }; + return Poll::Pending; + } + } + + let h = HeaderLine::from(*this.version); + if let Err(err) = Pin::new(&mut io).start_send(Message::Header(h)) { + return Poll::Ready(Err(From::from(err))); + } + + let protocol = this.protocols.next().ok_or(NegotiationError::Failed)?; + + // The dialer always sends the header and the first protocol + // proposal in one go for efficiency. + *this.state = State::SendProtocol { io, protocol }; + } + + State::SendProtocol { mut io, protocol } => { + match Pin::new(&mut io).poll_ready(cx)? { + Poll::Ready(()) => {} + Poll::Pending => { + *this.state = State::SendProtocol { io, protocol }; + return Poll::Pending; + } + } + + let p = Protocol::try_from(protocol.as_ref())?; + if let Err(err) = Pin::new(&mut io).start_send(Message::Protocol(p.clone())) { + return Poll::Ready(Err(From::from(err))); + } + tracing::debug!(target: LOG_TARGET, "Dialer: Proposed protocol: {}", p); + + if this.protocols.peek().is_some() { + *this.state = State::FlushProtocol { io, protocol } + } else { + match this.version { + Version::V1 => *this.state = State::FlushProtocol { io, protocol }, + // This is the only effect that `V1Lazy` has compared to `V1`: + // Optimistically settling on the only protocol that + // the dialer supports for this negotiation. Notably, + // the dialer expects a regular `V1` response. + Version::V1Lazy => { + tracing::debug!( + target: LOG_TARGET, + "Dialer: Expecting proposed protocol: {}", + p + ); + let hl = HeaderLine::from(Version::V1Lazy); + let io = Negotiated::expecting(io.into_reader(), p, Some(hl)); + return Poll::Ready(Ok((protocol, io))); + } + } + } + } + + State::FlushProtocol { mut io, protocol } => { + match Pin::new(&mut io).poll_flush(cx)? { + Poll::Ready(()) => *this.state = State::AwaitProtocol { io, protocol }, + Poll::Pending => { + *this.state = State::FlushProtocol { io, protocol }; + return Poll::Pending; + } + } + } + + State::AwaitProtocol { mut io, protocol } => { + let msg = match Pin::new(&mut io).poll_next(cx)? { + Poll::Ready(Some(msg)) => msg, + Poll::Pending => { + *this.state = State::AwaitProtocol { io, protocol }; + return Poll::Pending; + } + // Treat EOF error as [`NegotiationError::Failed`], not as + // [`NegotiationError::ProtocolError`], allowing dropping or closing an I/O + // stream as a permissible way to "gracefully" fail a negotiation. + Poll::Ready(None) => return Poll::Ready(Err(NegotiationError::Failed)), + }; + + match msg { + Message::Header(v) if v == HeaderLine::from(*this.version) => { + *this.state = State::AwaitProtocol { io, protocol }; + } + Message::Protocol(ref p) if p.as_ref() == protocol.as_ref() => { + tracing::debug!( + target: LOG_TARGET, + "Dialer: Received confirmation for protocol: {}", + p + ); + let io = Negotiated::completed(io.into_inner()); + return Poll::Ready(Ok((protocol, io))); + } + Message::NotAvailable => { + tracing::debug!( + target: LOG_TARGET, + "Dialer: Received rejection of protocol: {}", + String::from_utf8_lossy(protocol.as_ref()) + ); + let protocol = this.protocols.next().ok_or(NegotiationError::Failed)?; + *this.state = State::SendProtocol { io, protocol } + } + _ => return Poll::Ready(Err(ProtocolError::InvalidMessage.into())), + } + } + + State::Done => panic!("State::poll called after completion"), + } + } + } } /// `multistream-select` handshake result for dialer. #[derive(Debug)] pub enum HandshakeResult { - /// Handshake is not complete, data missing. - NotReady, - - /// Handshake has succeeded. - /// - /// The returned tuple contains the negotiated protocol and response - /// that must be sent to remote peer. - Succeeded(ProtocolName), + /// Handshake is not complete, data missing. + NotReady, + + /// Handshake has succeeded. + /// + /// The returned tuple contains the negotiated protocol and response + /// that must be sent to remote peer. + Succeeded(ProtocolName), } /// Handshake state. #[derive(Debug)] enum HandshakeState { - /// Wainting to receive any response from remote peer. - WaitingResponse, + /// Wainting to receive any response from remote peer. + WaitingResponse, - /// Waiting to receive the actual application protocol from remote peer. - WaitingProtocol, + /// Waiting to receive the actual application protocol from remote peer. + WaitingProtocol, } /// `multistream-select` dialer handshake state. #[derive(Debug)] pub struct DialerState { - /// Proposed main protocol. - protocol: ProtocolName, + /// Proposed main protocol. + protocol: ProtocolName, - /// Fallback names of the main protocol. - fallback_names: Vec, + /// Fallback names of the main protocol. + fallback_names: Vec, - /// Dialer handshake state. - state: HandshakeState, + /// Dialer handshake state. + state: HandshakeState, } // TODO: tests impl DialerState { - /// Propose protocol to remote peer. - /// - /// Return [`DialerState`] which is used to drive forward the negotiation and an encoded - /// `multistream-select` message that contains the protocol proposal for the substream. - pub fn propose( - protocol: ProtocolName, - fallback_names: Vec, - ) -> crate::Result<(Self, Vec)> { - // encode `/multistream-select/1.0.0` header - let mut bytes = BytesMut::with_capacity(64); - let message = Message::Header(HeaderLine::V1); - let _ = message.encode(&mut bytes).map_err(|_| Error::InvalidData)?; - let mut header = UnsignedVarint::encode(bytes)?; - - // encode proposed protocol - let mut proto_bytes = BytesMut::with_capacity(512); - let message = Message::Protocol(Protocol::try_from(protocol.as_bytes()).unwrap()); - let _ = message.encode(&mut proto_bytes).map_err(|_| Error::InvalidData)?; - let proto_bytes = UnsignedVarint::encode(proto_bytes)?; - - // TODO: add fallback names - - header.append(&mut proto_bytes.into()); - - Ok((Self { protocol, fallback_names, state: HandshakeState::WaitingResponse }, header)) - } - - /// Register response to [`DialerState`]. - pub fn register_response(&mut self, payload: Vec) -> crate::Result { - let Message::Protocols(protocols) = - Message::decode(payload.into()).map_err(|_| Error::InvalidData)? - else { - return Err(Error::NegotiationError(error::NegotiationError::MultistreamSelectError( - NegotiationError::Failed, - ))); - }; - - let mut protocol_iter = protocols.into_iter(); - loop { - match (&self.state, protocol_iter.next()) { - (HandshakeState::WaitingResponse, None) => return Err(Error::InvalidState), - (HandshakeState::WaitingResponse, Some(protocol)) => { - let header = Protocol::try_from(&b"/multistream/1.0.0"[..]) - .expect("valid multitstream-select header"); - - if protocol == header { - self.state = HandshakeState::WaitingProtocol; - } else { - return Err(Error::NegotiationError( - error::NegotiationError::MultistreamSelectError( - NegotiationError::Failed, - ), - )); - } - }, - (HandshakeState::WaitingProtocol, Some(protocol)) => { - if self.protocol.as_bytes() == protocol.as_ref() { - return Ok(HandshakeResult::Succeeded(self.protocol.clone())); - } - - // TODO: zzz - for fallback in &self.fallback_names { - if fallback.as_bytes() == protocol.as_ref() { - return Ok(HandshakeResult::Succeeded(self.protocol.clone())); - } - } - - return Err(Error::NegotiationError( - error::NegotiationError::MultistreamSelectError(NegotiationError::Failed), - )); - }, - (HandshakeState::WaitingProtocol, None) => { - return Ok(HandshakeResult::NotReady); - }, - } - } - } + /// Propose protocol to remote peer. + /// + /// Return [`DialerState`] which is used to drive forward the negotiation and an encoded + /// `multistream-select` message that contains the protocol proposal for the substream. + pub fn propose( + protocol: ProtocolName, + fallback_names: Vec, + ) -> crate::Result<(Self, Vec)> { + // encode `/multistream-select/1.0.0` header + let mut bytes = BytesMut::with_capacity(64); + let message = Message::Header(HeaderLine::V1); + let _ = message.encode(&mut bytes).map_err(|_| Error::InvalidData)?; + let mut header = UnsignedVarint::encode(bytes)?; + + // encode proposed protocol + let mut proto_bytes = BytesMut::with_capacity(512); + let message = Message::Protocol(Protocol::try_from(protocol.as_bytes()).unwrap()); + let _ = message.encode(&mut proto_bytes).map_err(|_| Error::InvalidData)?; + let proto_bytes = UnsignedVarint::encode(proto_bytes)?; + + // TODO: add fallback names + + header.append(&mut proto_bytes.into()); + + Ok(( + Self { + protocol, + fallback_names, + state: HandshakeState::WaitingResponse, + }, + header, + )) + } + + /// Register response to [`DialerState`]. + pub fn register_response(&mut self, payload: Vec) -> crate::Result { + let Message::Protocols(protocols) = + Message::decode(payload.into()).map_err(|_| Error::InvalidData)? + else { + return Err(Error::NegotiationError( + error::NegotiationError::MultistreamSelectError(NegotiationError::Failed), + )); + }; + + let mut protocol_iter = protocols.into_iter(); + loop { + match (&self.state, protocol_iter.next()) { + (HandshakeState::WaitingResponse, None) => return Err(Error::InvalidState), + (HandshakeState::WaitingResponse, Some(protocol)) => { + let header = Protocol::try_from(&b"/multistream/1.0.0"[..]) + .expect("valid multitstream-select header"); + + if protocol == header { + self.state = HandshakeState::WaitingProtocol; + } else { + return Err(Error::NegotiationError( + error::NegotiationError::MultistreamSelectError( + NegotiationError::Failed, + ), + )); + } + } + (HandshakeState::WaitingProtocol, Some(protocol)) => { + if self.protocol.as_bytes() == protocol.as_ref() { + return Ok(HandshakeResult::Succeeded(self.protocol.clone())); + } + + // TODO: zzz + for fallback in &self.fallback_names { + if fallback.as_bytes() == protocol.as_ref() { + return Ok(HandshakeResult::Succeeded(self.protocol.clone())); + } + } + + return Err(Error::NegotiationError( + error::NegotiationError::MultistreamSelectError(NegotiationError::Failed), + )); + } + (HandshakeState::WaitingProtocol, None) => { + return Ok(HandshakeResult::NotReady); + } + } + } + } } diff --git a/src/multistream_select/length_delimited.rs b/src/multistream_select/length_delimited.rs index 15218098..9d5d3ce3 100644 --- a/src/multistream_select/length_delimited.rs +++ b/src/multistream_select/length_delimited.rs @@ -21,11 +21,11 @@ use bytes::{Buf as _, BufMut as _, Bytes, BytesMut}; use futures::{io::IoSlice, prelude::*}; use std::{ - convert::TryFrom as _, - io, - pin::Pin, - task::{Context, Poll}, - u16, + convert::TryFrom as _, + io, + pin::Pin, + task::{Context, Poll}, + u16, }; const MAX_LEN_BYTES: u16 = 2; @@ -42,245 +42,251 @@ const LOG_TARGET: &str = "litep2p::multistream-select"; #[pin_project::pin_project] #[derive(Debug)] pub struct LengthDelimited { - /// The inner I/O resource. - #[pin] - inner: R, - /// Read buffer for a single incoming unsigned-varint length-delimited frame. - read_buffer: BytesMut, - /// Write buffer for outgoing unsigned-varint length-delimited frames. - write_buffer: BytesMut, - /// The current read state, alternating between reading a frame - /// length and reading a frame payload. - read_state: ReadState, + /// The inner I/O resource. + #[pin] + inner: R, + /// Read buffer for a single incoming unsigned-varint length-delimited frame. + read_buffer: BytesMut, + /// Write buffer for outgoing unsigned-varint length-delimited frames. + write_buffer: BytesMut, + /// The current read state, alternating between reading a frame + /// length and reading a frame payload. + read_state: ReadState, } #[derive(Debug, Copy, Clone, PartialEq, Eq)] enum ReadState { - /// We are currently reading the length of the next frame of data. - ReadLength { buf: [u8; MAX_LEN_BYTES as usize], pos: usize }, - /// We are currently reading the frame of data itself. - ReadData { len: u16, pos: usize }, + /// We are currently reading the length of the next frame of data. + ReadLength { + buf: [u8; MAX_LEN_BYTES as usize], + pos: usize, + }, + /// We are currently reading the frame of data itself. + ReadData { len: u16, pos: usize }, } impl Default for ReadState { - fn default() -> Self { - ReadState::ReadLength { buf: [0; MAX_LEN_BYTES as usize], pos: 0 } - } + fn default() -> Self { + ReadState::ReadLength { + buf: [0; MAX_LEN_BYTES as usize], + pos: 0, + } + } } impl LengthDelimited { - /// Creates a new I/O resource for reading and writing unsigned-varint - /// length delimited frames. - pub fn new(inner: R) -> LengthDelimited { - LengthDelimited { - inner, - read_state: ReadState::default(), - read_buffer: BytesMut::with_capacity(DEFAULT_BUFFER_SIZE), - write_buffer: BytesMut::with_capacity(DEFAULT_BUFFER_SIZE + MAX_LEN_BYTES as usize), - } - } - - /// Drops the [`LengthDelimited`] resource, yielding the underlying I/O stream. - /// - /// # Panic - /// - /// Will panic if called while there is data in the read or write buffer. - /// The read buffer is guaranteed to be empty whenever `Stream::poll` yields - /// a new `Bytes` frame. The write buffer is guaranteed to be empty after - /// flushing. - pub fn into_inner(self) -> R { - assert!(self.read_buffer.is_empty()); - assert!(self.write_buffer.is_empty()); - self.inner - } - - /// Converts the [`LengthDelimited`] into a [`LengthDelimitedReader`], dropping the - /// uvi-framed `Sink` in favour of direct `AsyncWrite` access to the underlying - /// I/O stream. - /// - /// This is typically done if further uvi-framed messages are expected to be - /// received but no more such messages are written, allowing the writing of - /// follow-up protocol data to commence. - pub fn into_reader(self) -> LengthDelimitedReader { - LengthDelimitedReader { inner: self } - } - - /// Writes all buffered frame data to the underlying I/O stream, - /// _without flushing it_. - /// - /// After this method returns `Poll::Ready`, the write buffer of frames - /// submitted to the `Sink` is guaranteed to be empty. - pub fn poll_write_buffer( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> - where - R: AsyncWrite, - { - let mut this = self.project(); - - while !this.write_buffer.is_empty() { - match this.inner.as_mut().poll_write(cx, this.write_buffer) { - Poll::Pending => return Poll::Pending, - Poll::Ready(Ok(0)) => - return Poll::Ready(Err(io::Error::new( - io::ErrorKind::WriteZero, - "Failed to write buffered frame.", - ))), - Poll::Ready(Ok(n)) => this.write_buffer.advance(n), - Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), - } - } - - Poll::Ready(Ok(())) - } + /// Creates a new I/O resource for reading and writing unsigned-varint + /// length delimited frames. + pub fn new(inner: R) -> LengthDelimited { + LengthDelimited { + inner, + read_state: ReadState::default(), + read_buffer: BytesMut::with_capacity(DEFAULT_BUFFER_SIZE), + write_buffer: BytesMut::with_capacity(DEFAULT_BUFFER_SIZE + MAX_LEN_BYTES as usize), + } + } + + /// Drops the [`LengthDelimited`] resource, yielding the underlying I/O stream. + /// + /// # Panic + /// + /// Will panic if called while there is data in the read or write buffer. + /// The read buffer is guaranteed to be empty whenever `Stream::poll` yields + /// a new `Bytes` frame. The write buffer is guaranteed to be empty after + /// flushing. + pub fn into_inner(self) -> R { + assert!(self.read_buffer.is_empty()); + assert!(self.write_buffer.is_empty()); + self.inner + } + + /// Converts the [`LengthDelimited`] into a [`LengthDelimitedReader`], dropping the + /// uvi-framed `Sink` in favour of direct `AsyncWrite` access to the underlying + /// I/O stream. + /// + /// This is typically done if further uvi-framed messages are expected to be + /// received but no more such messages are written, allowing the writing of + /// follow-up protocol data to commence. + pub fn into_reader(self) -> LengthDelimitedReader { + LengthDelimitedReader { inner: self } + } + + /// Writes all buffered frame data to the underlying I/O stream, + /// _without flushing it_. + /// + /// After this method returns `Poll::Ready`, the write buffer of frames + /// submitted to the `Sink` is guaranteed to be empty. + pub fn poll_write_buffer( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> + where + R: AsyncWrite, + { + let mut this = self.project(); + + while !this.write_buffer.is_empty() { + match this.inner.as_mut().poll_write(cx, this.write_buffer) { + Poll::Pending => return Poll::Pending, + Poll::Ready(Ok(0)) => + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::WriteZero, + "Failed to write buffered frame.", + ))), + Poll::Ready(Ok(n)) => this.write_buffer.advance(n), + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), + } + } + + Poll::Ready(Ok(())) + } } impl Stream for LengthDelimited where - R: AsyncRead, + R: AsyncRead, { - type Item = Result; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let mut this = self.project(); - - loop { - match this.read_state { - ReadState::ReadLength { buf, pos } => { - match this.inner.as_mut().poll_read(cx, &mut buf[*pos..*pos + 1]) { - Poll::Ready(Ok(0)) => - if *pos == 0 { - return Poll::Ready(None); - } else { - return Poll::Ready(Some(Err(io::ErrorKind::UnexpectedEof.into()))); - }, - Poll::Ready(Ok(n)) => { - debug_assert_eq!(n, 1); - *pos += n; - }, - Poll::Ready(Err(err)) => return Poll::Ready(Some(Err(err))), - Poll::Pending => return Poll::Pending, - }; - - if (buf[*pos - 1] & 0x80) == 0 { - // MSB is not set, indicating the end of the length prefix. - let (len, _) = unsigned_varint::decode::u16(buf).map_err(|e| { - tracing::debug!(target: LOG_TARGET, "invalid length prefix: {}", e); - io::Error::new(io::ErrorKind::InvalidData, "invalid length prefix") - })?; - - if len >= 1 { - *this.read_state = ReadState::ReadData { len, pos: 0 }; - this.read_buffer.resize(len as usize, 0); - } else { - debug_assert_eq!(len, 0); - *this.read_state = ReadState::default(); - return Poll::Ready(Some(Ok(Bytes::new()))); - } - } else if *pos == MAX_LEN_BYTES as usize { - // MSB signals more length bytes but we have already read the maximum. - // See the module documentation about the max frame len. - return Poll::Ready(Some(Err(io::Error::new( - io::ErrorKind::InvalidData, - "Maximum frame length exceeded", - )))); - } - }, - ReadState::ReadData { len, pos } => { - match this.inner.as_mut().poll_read(cx, &mut this.read_buffer[*pos..]) { - Poll::Ready(Ok(0)) => - return Poll::Ready(Some(Err(io::ErrorKind::UnexpectedEof.into()))), - Poll::Ready(Ok(n)) => *pos += n, - Poll::Pending => return Poll::Pending, - Poll::Ready(Err(err)) => return Poll::Ready(Some(Err(err))), - }; - - if *pos == *len as usize { - // Finished reading the frame. - let frame = this.read_buffer.split_off(0).freeze(); - *this.read_state = ReadState::default(); - return Poll::Ready(Some(Ok(frame))); - } - }, - } - } - } + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); + + loop { + match this.read_state { + ReadState::ReadLength { buf, pos } => { + match this.inner.as_mut().poll_read(cx, &mut buf[*pos..*pos + 1]) { + Poll::Ready(Ok(0)) => + if *pos == 0 { + return Poll::Ready(None); + } else { + return Poll::Ready(Some(Err(io::ErrorKind::UnexpectedEof.into()))); + }, + Poll::Ready(Ok(n)) => { + debug_assert_eq!(n, 1); + *pos += n; + } + Poll::Ready(Err(err)) => return Poll::Ready(Some(Err(err))), + Poll::Pending => return Poll::Pending, + }; + + if (buf[*pos - 1] & 0x80) == 0 { + // MSB is not set, indicating the end of the length prefix. + let (len, _) = unsigned_varint::decode::u16(buf).map_err(|e| { + tracing::debug!(target: LOG_TARGET, "invalid length prefix: {}", e); + io::Error::new(io::ErrorKind::InvalidData, "invalid length prefix") + })?; + + if len >= 1 { + *this.read_state = ReadState::ReadData { len, pos: 0 }; + this.read_buffer.resize(len as usize, 0); + } else { + debug_assert_eq!(len, 0); + *this.read_state = ReadState::default(); + return Poll::Ready(Some(Ok(Bytes::new()))); + } + } else if *pos == MAX_LEN_BYTES as usize { + // MSB signals more length bytes but we have already read the maximum. + // See the module documentation about the max frame len. + return Poll::Ready(Some(Err(io::Error::new( + io::ErrorKind::InvalidData, + "Maximum frame length exceeded", + )))); + } + } + ReadState::ReadData { len, pos } => { + match this.inner.as_mut().poll_read(cx, &mut this.read_buffer[*pos..]) { + Poll::Ready(Ok(0)) => + return Poll::Ready(Some(Err(io::ErrorKind::UnexpectedEof.into()))), + Poll::Ready(Ok(n)) => *pos += n, + Poll::Pending => return Poll::Pending, + Poll::Ready(Err(err)) => return Poll::Ready(Some(Err(err))), + }; + + if *pos == *len as usize { + // Finished reading the frame. + let frame = this.read_buffer.split_off(0).freeze(); + *this.read_state = ReadState::default(); + return Poll::Ready(Some(Ok(frame))); + } + } + } + } + } } impl Sink for LengthDelimited where - R: AsyncWrite, + R: AsyncWrite, { - type Error = io::Error; - - fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - // Use the maximum frame length also as a (soft) upper limit - // for the entire write buffer. The actual (hard) limit is thus - // implied to be roughly 2 * MAX_FRAME_SIZE. - if self.as_mut().project().write_buffer.len() >= MAX_FRAME_SIZE as usize { - match self.as_mut().poll_write_buffer(cx) { - Poll::Ready(Ok(())) => {}, - Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), - Poll::Pending => return Poll::Pending, - } - - debug_assert!(self.as_mut().project().write_buffer.is_empty()); - } - - Poll::Ready(Ok(())) - } - - fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> { - let this = self.project(); - - let len = match u16::try_from(item.len()) { - Ok(len) if len <= MAX_FRAME_SIZE => len, - _ => - return Err(io::Error::new( - io::ErrorKind::InvalidData, - "Maximum frame size exceeded.", - )), - }; - - let mut uvi_buf = unsigned_varint::encode::u16_buffer(); - let uvi_len = unsigned_varint::encode::u16(len, &mut uvi_buf); - this.write_buffer.reserve(len as usize + uvi_len.len()); - this.write_buffer.put(uvi_len); - this.write_buffer.put(item); - - Ok(()) - } - - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - // Write all buffered frame data to the underlying I/O stream. - match LengthDelimited::poll_write_buffer(self.as_mut(), cx) { - Poll::Ready(Ok(())) => {}, - Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), - Poll::Pending => return Poll::Pending, - } - - let this = self.project(); - debug_assert!(this.write_buffer.is_empty()); - - // Flush the underlying I/O stream. - this.inner.poll_flush(cx) - } - - fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - // Write all buffered frame data to the underlying I/O stream. - match LengthDelimited::poll_write_buffer(self.as_mut(), cx) { - Poll::Ready(Ok(())) => {}, - Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), - Poll::Pending => return Poll::Pending, - } - - let this = self.project(); - debug_assert!(this.write_buffer.is_empty()); - - // Close the underlying I/O stream. - this.inner.poll_close(cx) - } + type Error = io::Error; + + fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + // Use the maximum frame length also as a (soft) upper limit + // for the entire write buffer. The actual (hard) limit is thus + // implied to be roughly 2 * MAX_FRAME_SIZE. + if self.as_mut().project().write_buffer.len() >= MAX_FRAME_SIZE as usize { + match self.as_mut().poll_write_buffer(cx) { + Poll::Ready(Ok(())) => {} + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), + Poll::Pending => return Poll::Pending, + } + + debug_assert!(self.as_mut().project().write_buffer.is_empty()); + } + + Poll::Ready(Ok(())) + } + + fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> { + let this = self.project(); + + let len = match u16::try_from(item.len()) { + Ok(len) if len <= MAX_FRAME_SIZE => len, + _ => + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "Maximum frame size exceeded.", + )), + }; + + let mut uvi_buf = unsigned_varint::encode::u16_buffer(); + let uvi_len = unsigned_varint::encode::u16(len, &mut uvi_buf); + this.write_buffer.reserve(len as usize + uvi_len.len()); + this.write_buffer.put(uvi_len); + this.write_buffer.put(item); + + Ok(()) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + // Write all buffered frame data to the underlying I/O stream. + match LengthDelimited::poll_write_buffer(self.as_mut(), cx) { + Poll::Ready(Ok(())) => {} + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), + Poll::Pending => return Poll::Pending, + } + + let this = self.project(); + debug_assert!(this.write_buffer.is_empty()); + + // Flush the underlying I/O stream. + this.inner.poll_flush(cx) + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + // Write all buffered frame data to the underlying I/O stream. + match LengthDelimited::poll_write_buffer(self.as_mut(), cx) { + Poll::Ready(Ok(())) => {} + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), + Poll::Pending => return Poll::Pending, + } + + let this = self.project(); + debug_assert!(this.write_buffer.is_empty()); + + // Close the underlying I/O stream. + this.inner.poll_close(cx) + } } /// A `LengthDelimitedReader` implements a `Stream` of uvi-length-delimited @@ -288,86 +294,86 @@ where #[pin_project::pin_project] #[derive(Debug)] pub struct LengthDelimitedReader { - #[pin] - inner: LengthDelimited, + #[pin] + inner: LengthDelimited, } impl LengthDelimitedReader { - /// Destroys the `LengthDelimitedReader` and returns the underlying I/O stream. - /// - /// This method is guaranteed not to drop any data read from or not yet - /// submitted to the underlying I/O stream. - /// - /// # Panic - /// - /// Will panic if called while there is data in the read or write buffer. - /// The read buffer is guaranteed to be empty whenever [`Stream::poll_next`] - /// yield a new `Message`. The write buffer is guaranteed to be empty whenever - /// [`LengthDelimited::poll_write_buffer`] yields [`Poll::Ready`] or after - /// the [`Sink`] has been completely flushed via [`Sink::poll_flush`]. - pub fn into_inner(self) -> R { - self.inner.into_inner() - } + /// Destroys the `LengthDelimitedReader` and returns the underlying I/O stream. + /// + /// This method is guaranteed not to drop any data read from or not yet + /// submitted to the underlying I/O stream. + /// + /// # Panic + /// + /// Will panic if called while there is data in the read or write buffer. + /// The read buffer is guaranteed to be empty whenever [`Stream::poll_next`] + /// yield a new `Message`. The write buffer is guaranteed to be empty whenever + /// [`LengthDelimited::poll_write_buffer`] yields [`Poll::Ready`] or after + /// the [`Sink`] has been completely flushed via [`Sink::poll_flush`]. + pub fn into_inner(self) -> R { + self.inner.into_inner() + } } impl Stream for LengthDelimitedReader where - R: AsyncRead, + R: AsyncRead, { - type Item = Result; + type Item = Result; - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().inner.poll_next(cx) - } + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().inner.poll_next(cx) + } } impl AsyncWrite for LengthDelimitedReader where - R: AsyncWrite, + R: AsyncWrite, { - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - // `this` here designates the `LengthDelimited`. - let mut this = self.project().inner; - - // We need to flush any data previously written with the `LengthDelimited`. - match LengthDelimited::poll_write_buffer(this.as_mut(), cx) { - Poll::Ready(Ok(())) => {}, - Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), - Poll::Pending => return Poll::Pending, - } - debug_assert!(this.write_buffer.is_empty()); - - this.project().inner.poll_write(cx, buf) - } - - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().inner.poll_flush(cx) - } - - fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().inner.poll_close(cx) - } - - fn poll_write_vectored( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - bufs: &[IoSlice<'_>], - ) -> Poll> { - // `this` here designates the `LengthDelimited`. - let mut this = self.project().inner; - - // We need to flush any data previously written with the `LengthDelimited`. - match LengthDelimited::poll_write_buffer(this.as_mut(), cx) { - Poll::Ready(Ok(())) => {}, - Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), - Poll::Pending => return Poll::Pending, - } - debug_assert!(this.write_buffer.is_empty()); - - this.project().inner.poll_write_vectored(cx, bufs) - } + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + // `this` here designates the `LengthDelimited`. + let mut this = self.project().inner; + + // We need to flush any data previously written with the `LengthDelimited`. + match LengthDelimited::poll_write_buffer(this.as_mut(), cx) { + Poll::Ready(Ok(())) => {} + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), + Poll::Pending => return Poll::Pending, + } + debug_assert!(this.write_buffer.is_empty()); + + this.project().inner.poll_write(cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().inner.poll_flush(cx) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().inner.poll_close(cx) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { + // `this` here designates the `LengthDelimited`. + let mut this = self.project().inner; + + // We need to flush any data previously written with the `LengthDelimited`. + match LengthDelimited::poll_write_buffer(this.as_mut(), cx) { + Poll::Ready(Ok(())) => {} + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), + Poll::Pending => return Poll::Pending, + } + debug_assert!(this.write_buffer.is_empty()); + + this.project().inner.poll_write_vectored(cx, bufs) + } } diff --git a/src/multistream_select/listener_select.rs b/src/multistream_select/listener_select.rs index db8b859b..4217a332 100644 --- a/src/multistream_select/listener_select.rs +++ b/src/multistream_select/listener_select.rs @@ -22,24 +22,24 @@ //! in a multistream-select protocol negotiation. use crate::{ - codec::unsigned_varint::UnsignedVarint, - error::{self, Error}, - multistream_select::{ - protocol::{HeaderLine, Message, MessageIO, Protocol, ProtocolError}, - Negotiated, NegotiationError, - }, - types::protocol::ProtocolName, + codec::unsigned_varint::UnsignedVarint, + error::{self, Error}, + multistream_select::{ + protocol::{HeaderLine, Message, MessageIO, Protocol, ProtocolError}, + Negotiated, NegotiationError, + }, + types::protocol::ProtocolName, }; use bytes::{Bytes, BytesMut}; use futures::prelude::*; use smallvec::SmallVec; use std::{ - convert::TryFrom as _, - iter::FromIterator, - mem, - pin::Pin, - task::{Context, Poll}, + convert::TryFrom as _, + iter::FromIterator, + mem, + pin::Pin, + task::{Context, Poll}, }; const LOG_TARGET: &str = "litep2p::multistream-select"; @@ -53,242 +53,273 @@ const LOG_TARGET: &str = "litep2p::multistream-select"; /// a [`Negotiated`] I/O stream. pub fn listener_select_proto(inner: R, protocols: I) -> ListenerSelectFuture where - R: AsyncRead + AsyncWrite, - I: IntoIterator, - I::Item: AsRef<[u8]>, + R: AsyncRead + AsyncWrite, + I: IntoIterator, + I::Item: AsRef<[u8]>, { - let protocols = protocols.into_iter().filter_map(|n| match Protocol::try_from(n.as_ref()) { - Ok(p) => Some((n, p)), - Err(e) => { - tracing::warn!( - target: LOG_TARGET, - "Listener: Ignoring invalid protocol: {} due to {}", - String::from_utf8_lossy(n.as_ref()), - e - ); - None - }, - }); - ListenerSelectFuture { - protocols: SmallVec::from_iter(protocols), - state: State::RecvHeader { io: MessageIO::new(inner) }, - last_sent_na: false, - } + let protocols = protocols.into_iter().filter_map(|n| match Protocol::try_from(n.as_ref()) { + Ok(p) => Some((n, p)), + Err(e) => { + tracing::warn!( + target: LOG_TARGET, + "Listener: Ignoring invalid protocol: {} due to {}", + String::from_utf8_lossy(n.as_ref()), + e + ); + None + } + }); + ListenerSelectFuture { + protocols: SmallVec::from_iter(protocols), + state: State::RecvHeader { + io: MessageIO::new(inner), + }, + last_sent_na: false, + } } /// The `Future` returned by [`listener_select_proto`] that performs a /// multistream-select protocol negotiation on an underlying I/O stream. #[pin_project::pin_project] pub struct ListenerSelectFuture { - // TODO: It would be nice if eventually N = Protocol, which has a - // few more implications on the API. - protocols: SmallVec<[(N, Protocol); 8]>, - state: State, - /// Whether the last message sent was a protocol rejection (i.e. `na\n`). - /// - /// If the listener reads garbage or EOF after such a rejection, - /// the dialer is likely using `V1Lazy` and negotiation must be - /// considered failed, but not with a protocol violation or I/O - /// error. - last_sent_na: bool, + // TODO: It would be nice if eventually N = Protocol, which has a + // few more implications on the API. + protocols: SmallVec<[(N, Protocol); 8]>, + state: State, + /// Whether the last message sent was a protocol rejection (i.e. `na\n`). + /// + /// If the listener reads garbage or EOF after such a rejection, + /// the dialer is likely using `V1Lazy` and negotiation must be + /// considered failed, but not with a protocol violation or I/O + /// error. + last_sent_na: bool, } enum State { - RecvHeader { io: MessageIO }, - SendHeader { io: MessageIO }, - RecvMessage { io: MessageIO }, - SendMessage { io: MessageIO, message: Message, protocol: Option }, - Flush { io: MessageIO, protocol: Option }, - Done, + RecvHeader { + io: MessageIO, + }, + SendHeader { + io: MessageIO, + }, + RecvMessage { + io: MessageIO, + }, + SendMessage { + io: MessageIO, + message: Message, + protocol: Option, + }, + Flush { + io: MessageIO, + protocol: Option, + }, + Done, } impl Future for ListenerSelectFuture where - // The Unpin bound here is required because we - // produce a `Negotiated` as the output. - // It also makes the implementation considerably - // easier to write. - R: AsyncRead + AsyncWrite + Unpin, - N: AsRef<[u8]> + Clone, + // The Unpin bound here is required because we + // produce a `Negotiated` as the output. + // It also makes the implementation considerably + // easier to write. + R: AsyncRead + AsyncWrite + Unpin, + N: AsRef<[u8]> + Clone, { - type Output = Result<(N, Negotiated), NegotiationError>; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.project(); - - loop { - match mem::replace(this.state, State::Done) { - State::RecvHeader { mut io } => { - match io.poll_next_unpin(cx) { - Poll::Ready(Some(Ok(Message::Header(h)))) => match h { - HeaderLine::V1 => *this.state = State::SendHeader { io }, - }, - Poll::Ready(Some(Ok(_))) => - return Poll::Ready(Err(ProtocolError::InvalidMessage.into())), - Poll::Ready(Some(Err(err))) => return Poll::Ready(Err(From::from(err))), - // Treat EOF error as [`NegotiationError::Failed`], not as - // [`NegotiationError::ProtocolError`], allowing dropping or closing an I/O - // stream as a permissible way to "gracefully" fail a negotiation. - Poll::Ready(None) => return Poll::Ready(Err(NegotiationError::Failed)), - Poll::Pending => { - *this.state = State::RecvHeader { io }; - return Poll::Pending; - }, - } - }, - - State::SendHeader { mut io } => { - match Pin::new(&mut io).poll_ready(cx) { - Poll::Pending => { - *this.state = State::SendHeader { io }; - return Poll::Pending; - }, - Poll::Ready(Ok(())) => {}, - Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))), - } - - let msg = Message::Header(HeaderLine::V1); - if let Err(err) = Pin::new(&mut io).start_send(msg) { - return Poll::Ready(Err(From::from(err))); - } - - *this.state = State::Flush { io, protocol: None }; - }, - - State::RecvMessage { mut io } => { - let msg = match Pin::new(&mut io).poll_next(cx) { - Poll::Ready(Some(Ok(msg))) => msg, - // Treat EOF error as [`NegotiationError::Failed`], not as - // [`NegotiationError::ProtocolError`], allowing dropping or closing an I/O - // stream as a permissible way to "gracefully" fail a negotiation. - // - // This is e.g. important when a listener rejects a protocol with - // [`Message::NotAvailable`] and the dialer does not have alternative - // protocols to propose. Then the dialer will stop the negotiation and drop - // the corresponding stream. As a listener this EOF should be interpreted as - // a failed negotiation. - Poll::Ready(None) => return Poll::Ready(Err(NegotiationError::Failed)), - Poll::Pending => { - *this.state = State::RecvMessage { io }; - return Poll::Pending; - }, - Poll::Ready(Some(Err(err))) => { - if *this.last_sent_na { - // When we read garbage or EOF after having already rejected a - // protocol, the dialer is most likely using `V1Lazy` and has - // optimistically settled on this protocol, so this is really a - // failed negotiation, not a protocol violation. In this case - // the dialer also raises `NegotiationError::Failed` when finally - // reading the `N/A` response. - if let ProtocolError::InvalidMessage = &err { - tracing::trace!( - target: LOG_TARGET, - "Listener: Negotiation failed with invalid \ - message after protocol rejection." - ); - return Poll::Ready(Err(NegotiationError::Failed)); - } - if let ProtocolError::IoError(e) = &err { - if e.kind() == std::io::ErrorKind::UnexpectedEof { - tracing::trace!( - target: LOG_TARGET, - "Listener: Negotiation failed with EOF \ - after protocol rejection." - ); - return Poll::Ready(Err(NegotiationError::Failed)); - } - } - } - - return Poll::Ready(Err(From::from(err))); - }, - }; - - match msg { - Message::ListProtocols => { - let supported = - this.protocols.iter().map(|(_, p)| p).cloned().collect(); - let message = Message::Protocols(supported); - *this.state = State::SendMessage { io, message, protocol: None } - }, - Message::Protocol(p) => { - let protocol = this.protocols.iter().find_map(|(name, proto)| { - if &p == proto { - Some(name.clone()) - } else { - None - } - }); - - let message = if protocol.is_some() { - tracing::debug!("Listener: confirming protocol: {}", p); - Message::Protocol(p.clone()) - } else { - tracing::debug!( - "Listener: rejecting protocol: {}", - String::from_utf8_lossy(p.as_ref()) - ); - Message::NotAvailable - }; - - *this.state = State::SendMessage { io, message, protocol }; - }, - _ => return Poll::Ready(Err(ProtocolError::InvalidMessage.into())), - } - }, - - State::SendMessage { mut io, message, protocol } => { - match Pin::new(&mut io).poll_ready(cx) { - Poll::Pending => { - *this.state = State::SendMessage { io, message, protocol }; - return Poll::Pending; - }, - Poll::Ready(Ok(())) => {}, - Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))), - } - - if let Message::NotAvailable = &message { - *this.last_sent_na = true; - } else { - *this.last_sent_na = false; - } - - if let Err(err) = Pin::new(&mut io).start_send(message) { - return Poll::Ready(Err(From::from(err))); - } - - *this.state = State::Flush { io, protocol }; - }, - - State::Flush { mut io, protocol } => { - match Pin::new(&mut io).poll_flush(cx) { - Poll::Pending => { - *this.state = State::Flush { io, protocol }; - return Poll::Pending; - }, - Poll::Ready(Ok(())) => { - // If a protocol has been selected, finish negotiation. - // Otherwise expect to receive another message. - match protocol { - Some(protocol) => { - tracing::debug!( - "Listener: sent confirmed protocol: {}", - String::from_utf8_lossy(protocol.as_ref()) - ); - let io = Negotiated::completed(io.into_inner()); - return Poll::Ready(Ok((protocol, io))); - }, - None => *this.state = State::RecvMessage { io }, - } - }, - Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))), - } - }, - - State::Done => panic!("State::poll called after completion"), - } - } - } + type Output = Result<(N, Negotiated), NegotiationError>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + + loop { + match mem::replace(this.state, State::Done) { + State::RecvHeader { mut io } => { + match io.poll_next_unpin(cx) { + Poll::Ready(Some(Ok(Message::Header(h)))) => match h { + HeaderLine::V1 => *this.state = State::SendHeader { io }, + }, + Poll::Ready(Some(Ok(_))) => + return Poll::Ready(Err(ProtocolError::InvalidMessage.into())), + Poll::Ready(Some(Err(err))) => return Poll::Ready(Err(From::from(err))), + // Treat EOF error as [`NegotiationError::Failed`], not as + // [`NegotiationError::ProtocolError`], allowing dropping or closing an I/O + // stream as a permissible way to "gracefully" fail a negotiation. + Poll::Ready(None) => return Poll::Ready(Err(NegotiationError::Failed)), + Poll::Pending => { + *this.state = State::RecvHeader { io }; + return Poll::Pending; + } + } + } + + State::SendHeader { mut io } => { + match Pin::new(&mut io).poll_ready(cx) { + Poll::Pending => { + *this.state = State::SendHeader { io }; + return Poll::Pending; + } + Poll::Ready(Ok(())) => {} + Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))), + } + + let msg = Message::Header(HeaderLine::V1); + if let Err(err) = Pin::new(&mut io).start_send(msg) { + return Poll::Ready(Err(From::from(err))); + } + + *this.state = State::Flush { io, protocol: None }; + } + + State::RecvMessage { mut io } => { + let msg = match Pin::new(&mut io).poll_next(cx) { + Poll::Ready(Some(Ok(msg))) => msg, + // Treat EOF error as [`NegotiationError::Failed`], not as + // [`NegotiationError::ProtocolError`], allowing dropping or closing an I/O + // stream as a permissible way to "gracefully" fail a negotiation. + // + // This is e.g. important when a listener rejects a protocol with + // [`Message::NotAvailable`] and the dialer does not have alternative + // protocols to propose. Then the dialer will stop the negotiation and drop + // the corresponding stream. As a listener this EOF should be interpreted as + // a failed negotiation. + Poll::Ready(None) => return Poll::Ready(Err(NegotiationError::Failed)), + Poll::Pending => { + *this.state = State::RecvMessage { io }; + return Poll::Pending; + } + Poll::Ready(Some(Err(err))) => { + if *this.last_sent_na { + // When we read garbage or EOF after having already rejected a + // protocol, the dialer is most likely using `V1Lazy` and has + // optimistically settled on this protocol, so this is really a + // failed negotiation, not a protocol violation. In this case + // the dialer also raises `NegotiationError::Failed` when finally + // reading the `N/A` response. + if let ProtocolError::InvalidMessage = &err { + tracing::trace!( + target: LOG_TARGET, + "Listener: Negotiation failed with invalid \ + message after protocol rejection." + ); + return Poll::Ready(Err(NegotiationError::Failed)); + } + if let ProtocolError::IoError(e) = &err { + if e.kind() == std::io::ErrorKind::UnexpectedEof { + tracing::trace!( + target: LOG_TARGET, + "Listener: Negotiation failed with EOF \ + after protocol rejection." + ); + return Poll::Ready(Err(NegotiationError::Failed)); + } + } + } + + return Poll::Ready(Err(From::from(err))); + } + }; + + match msg { + Message::ListProtocols => { + let supported = + this.protocols.iter().map(|(_, p)| p).cloned().collect(); + let message = Message::Protocols(supported); + *this.state = State::SendMessage { + io, + message, + protocol: None, + } + } + Message::Protocol(p) => { + let protocol = this.protocols.iter().find_map(|(name, proto)| { + if &p == proto { + Some(name.clone()) + } else { + None + } + }); + + let message = if protocol.is_some() { + tracing::debug!("Listener: confirming protocol: {}", p); + Message::Protocol(p.clone()) + } else { + tracing::debug!( + "Listener: rejecting protocol: {}", + String::from_utf8_lossy(p.as_ref()) + ); + Message::NotAvailable + }; + + *this.state = State::SendMessage { + io, + message, + protocol, + }; + } + _ => return Poll::Ready(Err(ProtocolError::InvalidMessage.into())), + } + } + + State::SendMessage { + mut io, + message, + protocol, + } => { + match Pin::new(&mut io).poll_ready(cx) { + Poll::Pending => { + *this.state = State::SendMessage { + io, + message, + protocol, + }; + return Poll::Pending; + } + Poll::Ready(Ok(())) => {} + Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))), + } + + if let Message::NotAvailable = &message { + *this.last_sent_na = true; + } else { + *this.last_sent_na = false; + } + + if let Err(err) = Pin::new(&mut io).start_send(message) { + return Poll::Ready(Err(From::from(err))); + } + + *this.state = State::Flush { io, protocol }; + } + + State::Flush { mut io, protocol } => { + match Pin::new(&mut io).poll_flush(cx) { + Poll::Pending => { + *this.state = State::Flush { io, protocol }; + return Poll::Pending; + } + Poll::Ready(Ok(())) => { + // If a protocol has been selected, finish negotiation. + // Otherwise expect to receive another message. + match protocol { + Some(protocol) => { + tracing::debug!( + "Listener: sent confirmed protocol: {}", + String::from_utf8_lossy(protocol.as_ref()) + ); + let io = Negotiated::completed(io.into_inner()); + return Poll::Ready(Ok((protocol, io))); + } + None => *this.state = State::RecvMessage { io }, + } + } + Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))), + } + } + + State::Done => panic!("State::poll called after completion"), + } + } + } } /// Negotiate protocols for listener. @@ -297,68 +328,68 @@ where /// locally available protocols. If a match is found, return an encoded multistream-select /// response and the negotiated protocol. If parsing fails or no match is found, return an error. pub fn listener_negotiate<'a>( - supported_protocols: &'a mut impl Iterator, - payload: Bytes, + supported_protocols: &'a mut impl Iterator, + payload: Bytes, ) -> crate::Result<(ProtocolName, BytesMut)> { - let Message::Protocols(protocols) = Message::decode(payload).map_err(|_| Error::InvalidData)? - else { - return Err(Error::NegotiationError(error::NegotiationError::MultistreamSelectError( - NegotiationError::Failed, - ))); - }; - - // skip the multistream-select header because it's not part of user protocols but verify it's - // present - let mut protocol_iter = protocols.into_iter(); - let header = - Protocol::try_from(&b"/multistream/1.0.0"[..]).expect("valid multitstream-select header"); - - if !std::matches!(protocol_iter.next(), Some(header)) { - return Err(Error::NegotiationError(error::NegotiationError::MultistreamSelectError( - NegotiationError::Failed, - ))); - } - - for protocol in protocol_iter { - for supported in &mut *supported_protocols { - if protocol.as_ref() == supported.as_bytes() { - // encode `/multistream-select/1.0.0` header - let mut bytes = BytesMut::with_capacity(64); - let message = Message::Header(HeaderLine::V1); - let _ = message.encode(&mut bytes).map_err(|_| Error::InvalidData)?; - let mut header = UnsignedVarint::encode(bytes)?; - - // encode negotiated protocol - let mut proto_bytes = BytesMut::with_capacity(512); - let message = Message::Protocol(protocol); - let _ = message.encode(&mut proto_bytes).map_err(|_| Error::InvalidData)?; - let proto_bytes = UnsignedVarint::encode(proto_bytes)?; - - header.append(&mut proto_bytes.into()); - - return Ok((supported.clone(), BytesMut::from(&header[..]))); - } - } - } - - Err(Error::NegotiationError(error::NegotiationError::MultistreamSelectError( - NegotiationError::Failed, - ))) + let Message::Protocols(protocols) = Message::decode(payload).map_err(|_| Error::InvalidData)? + else { + return Err(Error::NegotiationError( + error::NegotiationError::MultistreamSelectError(NegotiationError::Failed), + )); + }; + + // skip the multistream-select header because it's not part of user protocols but verify it's + // present + let mut protocol_iter = protocols.into_iter(); + let header = + Protocol::try_from(&b"/multistream/1.0.0"[..]).expect("valid multitstream-select header"); + + if !std::matches!(protocol_iter.next(), Some(header)) { + return Err(Error::NegotiationError( + error::NegotiationError::MultistreamSelectError(NegotiationError::Failed), + )); + } + + for protocol in protocol_iter { + for supported in &mut *supported_protocols { + if protocol.as_ref() == supported.as_bytes() { + // encode `/multistream-select/1.0.0` header + let mut bytes = BytesMut::with_capacity(64); + let message = Message::Header(HeaderLine::V1); + let _ = message.encode(&mut bytes).map_err(|_| Error::InvalidData)?; + let mut header = UnsignedVarint::encode(bytes)?; + + // encode negotiated protocol + let mut proto_bytes = BytesMut::with_capacity(512); + let message = Message::Protocol(protocol); + let _ = message.encode(&mut proto_bytes).map_err(|_| Error::InvalidData)?; + let proto_bytes = UnsignedVarint::encode(proto_bytes)?; + + header.append(&mut proto_bytes.into()); + + return Ok((supported.clone(), BytesMut::from(&header[..]))); + } + } + } + + Err(Error::NegotiationError( + error::NegotiationError::MultistreamSelectError(NegotiationError::Failed), + )) } #[cfg(test)] mod tests { - use super::*; + use super::*; - #[test] - fn listener_negotiate_works() {} + #[test] + fn listener_negotiate_works() {} - #[test] - fn invalid_message_offered() {} + #[test] + fn invalid_message_offered() {} - #[test] - fn no_supported_protocol() {} + #[test] + fn no_supported_protocol() {} - #[test] - fn multistream_select_header_missing() {} + #[test] + fn multistream_select_header_missing() {} } diff --git a/src/multistream_select/mod.rs b/src/multistream_select/mod.rs index f20b816e..a6949397 100644 --- a/src/multistream_select/mod.rs +++ b/src/multistream_select/mod.rs @@ -76,56 +76,56 @@ mod negotiated; mod protocol; pub use crate::multistream_select::{ - dialer_select::{dialer_select_proto, DialerSelectFuture, DialerState, HandshakeResult}, - listener_select::{listener_negotiate, listener_select_proto, ListenerSelectFuture}, - negotiated::{Negotiated, NegotiatedComplete, NegotiationError}, - protocol::{HeaderLine, Message, Protocol, ProtocolError}, + dialer_select::{dialer_select_proto, DialerSelectFuture, DialerState, HandshakeResult}, + listener_select::{listener_negotiate, listener_select_proto, ListenerSelectFuture}, + negotiated::{Negotiated, NegotiatedComplete, NegotiationError}, + protocol::{HeaderLine, Message, Protocol, ProtocolError}, }; /// Supported multistream-select versions. #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum Version { - /// Version 1 of the multistream-select protocol. See [1] and [2]. - /// - /// [1]: https://github.com/libp2p/specs/blob/master/connections/README.md#protocol-negotiation - /// [2]: https://github.com/multiformats/multistream-select - V1, - /// A "lazy" variant of version 1 that is identical on the wire but whereby - /// the dialer delays flushing protocol negotiation data in order to combine - /// it with initial application data, thus performing 0-RTT negotiation. - /// - /// This strategy is only applicable for the node with the role of "dialer" - /// in the negotiation and only if the dialer supports just a single - /// application protocol. In that case the dialer immedidately "settles" - /// on that protocol, buffering the negotiation messages to be sent - /// with the first round of application protocol data (or an attempt - /// is made to read from the `Negotiated` I/O stream). - /// - /// A listener will behave identically to `V1`. This ensures interoperability with `V1`. - /// Notably, it will immediately send the multistream header as well as the protocol - /// confirmation, resulting in multiple frames being sent on the underlying transport. - /// Nevertheless, if the listener supports the protocol that the dialer optimistically - /// settled on, it can be a 0-RTT negotiation. - /// - /// > **Note**: `V1Lazy` is specific to `rust-libp2p`. The wire protocol is identical to `V1` - /// > and generally interoperable with peers only supporting `V1`. Nevertheless, there is a - /// > pitfall that is rarely encountered: When nesting multiple protocol negotiations, the - /// > listener should either be known to support all of the dialer's optimistically chosen - /// > protocols or there is must be no intermediate protocol without a payload and none of - /// > the protocol payloads must have the potential for being mistaken for a multistream-select - /// > protocol message. This avoids rare edge-cases whereby the listener may not recognize - /// > upgrade boundaries and erroneously process a request despite not supporting one of - /// > the intermediate protocols that the dialer committed to. See [1] and [2]. - /// - /// [1]: https://github.com/multiformats/go-multistream/issues/20 - /// [2]: https://github.com/libp2p/rust-libp2p/pull/1212 - V1Lazy, - // Draft: https://github.com/libp2p/specs/pull/95 - // V2, + /// Version 1 of the multistream-select protocol. See [1] and [2]. + /// + /// [1]: https://github.com/libp2p/specs/blob/master/connections/README.md#protocol-negotiation + /// [2]: https://github.com/multiformats/multistream-select + V1, + /// A "lazy" variant of version 1 that is identical on the wire but whereby + /// the dialer delays flushing protocol negotiation data in order to combine + /// it with initial application data, thus performing 0-RTT negotiation. + /// + /// This strategy is only applicable for the node with the role of "dialer" + /// in the negotiation and only if the dialer supports just a single + /// application protocol. In that case the dialer immedidately "settles" + /// on that protocol, buffering the negotiation messages to be sent + /// with the first round of application protocol data (or an attempt + /// is made to read from the `Negotiated` I/O stream). + /// + /// A listener will behave identically to `V1`. This ensures interoperability with `V1`. + /// Notably, it will immediately send the multistream header as well as the protocol + /// confirmation, resulting in multiple frames being sent on the underlying transport. + /// Nevertheless, if the listener supports the protocol that the dialer optimistically + /// settled on, it can be a 0-RTT negotiation. + /// + /// > **Note**: `V1Lazy` is specific to `rust-libp2p`. The wire protocol is identical to `V1` + /// > and generally interoperable with peers only supporting `V1`. Nevertheless, there is a + /// > pitfall that is rarely encountered: When nesting multiple protocol negotiations, the + /// > listener should either be known to support all of the dialer's optimistically chosen + /// > protocols or there is must be no intermediate protocol without a payload and none of + /// > the protocol payloads must have the potential for being mistaken for a multistream-select + /// > protocol message. This avoids rare edge-cases whereby the listener may not recognize + /// > upgrade boundaries and erroneously process a request despite not supporting one of + /// > the intermediate protocols that the dialer committed to. See [1] and [2]. + /// + /// [1]: https://github.com/multiformats/go-multistream/issues/20 + /// [2]: https://github.com/libp2p/rust-libp2p/pull/1212 + V1Lazy, + // Draft: https://github.com/libp2p/specs/pull/95 + // V2, } impl Default for Version { - fn default() -> Self { - Version::V1 - } + fn default() -> Self { + Version::V1 + } } diff --git a/src/multistream_select/negotiated.rs b/src/multistream_select/negotiated.rs index c4d5077d..450bdae3 100644 --- a/src/multistream_select/negotiated.rs +++ b/src/multistream_select/negotiated.rs @@ -19,20 +19,20 @@ // DEALINGS IN THE SOFTWARE. use crate::multistream_select::protocol::{ - HeaderLine, Message, MessageReader, Protocol, ProtocolError, + HeaderLine, Message, MessageReader, Protocol, ProtocolError, }; use futures::{ - io::{IoSlice, IoSliceMut}, - prelude::*, - ready, + io::{IoSlice, IoSliceMut}, + prelude::*, + ready, }; use pin_project::pin_project; use std::{ - error::Error, - fmt, io, mem, - pin::Pin, - task::{Context, Poll}, + error::Error, + fmt, io, mem, + pin::Pin, + task::{Context, Poll}, }; const LOG_TARGET: &str = "litep2p::multistream-select"; @@ -51,324 +51,346 @@ const LOG_TARGET: &str = "litep2p::multistream-select"; #[pin_project] #[derive(Debug)] pub struct Negotiated { - #[pin] - state: State, + #[pin] + state: State, } /// A `Future` that waits on the completion of protocol negotiation. #[derive(Debug)] pub struct NegotiatedComplete { - inner: Option>, + inner: Option>, } impl Future for NegotiatedComplete where - // `Unpin` is required not because of - // implementation details but because we produce - // the `Negotiated` as the output of the - // future. - TInner: AsyncRead + AsyncWrite + Unpin, + // `Unpin` is required not because of + // implementation details but because we produce + // the `Negotiated` as the output of the + // future. + TInner: AsyncRead + AsyncWrite + Unpin, { - type Output = Result, NegotiationError>; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut io = self.inner.take().expect("NegotiatedFuture called after completion."); - match Negotiated::poll(Pin::new(&mut io), cx) { - Poll::Pending => { - self.inner = Some(io); - Poll::Pending - }, - Poll::Ready(Ok(())) => Poll::Ready(Ok(io)), - Poll::Ready(Err(err)) => { - self.inner = Some(io); - Poll::Ready(Err(err)) - }, - } - } + type Output = Result, NegotiationError>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut io = self.inner.take().expect("NegotiatedFuture called after completion."); + match Negotiated::poll(Pin::new(&mut io), cx) { + Poll::Pending => { + self.inner = Some(io); + Poll::Pending + } + Poll::Ready(Ok(())) => Poll::Ready(Ok(io)), + Poll::Ready(Err(err)) => { + self.inner = Some(io); + Poll::Ready(Err(err)) + } + } + } } impl Negotiated { - /// Creates a `Negotiated` in state [`State::Completed`]. - pub(crate) fn completed(io: TInner) -> Self { - Negotiated { state: State::Completed { io } } - } - - /// Creates a `Negotiated` in state [`State::Expecting`] that is still - /// expecting confirmation of the given `protocol`. - pub(crate) fn expecting( - io: MessageReader, - protocol: Protocol, - header: Option, - ) -> Self { - Negotiated { state: State::Expecting { io, protocol, header } } - } - - pub fn inner(self) -> TInner { - match self.state { - State::Completed { io } => io, - _ => panic!("stream is not negotiated"), - } - } - - /// Polls the `Negotiated` for completion. - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> - where - TInner: AsyncRead + AsyncWrite + Unpin, - { - // Flush any pending negotiation data. - match self.as_mut().poll_flush(cx) { - Poll::Ready(Ok(())) => {}, - Poll::Pending => return Poll::Pending, - Poll::Ready(Err(e)) => { - // If the remote closed the stream, it is important to still - // continue reading the data that was sent, if any. - if e.kind() != io::ErrorKind::WriteZero { - return Poll::Ready(Err(e.into())); - } - }, - } - - let mut this = self.project(); - - if let StateProj::Completed { .. } = this.state.as_mut().project() { - return Poll::Ready(Ok(())); - } - - // Read outstanding protocol negotiation messages. - loop { - match mem::replace(&mut *this.state, State::Invalid) { - State::Expecting { mut io, header, protocol } => { - let msg = match Pin::new(&mut io).poll_next(cx)? { - Poll::Ready(Some(msg)) => msg, - Poll::Pending => { - *this.state = State::Expecting { io, header, protocol }; - return Poll::Pending; - }, - Poll::Ready(None) => { - return Poll::Ready(Err(ProtocolError::IoError( - io::ErrorKind::UnexpectedEof.into(), - ) - .into())); - }, - }; - - if let Message::Header(h) = &msg { - if Some(h) == header.as_ref() { - *this.state = State::Expecting { io, protocol, header: None }; - continue; - } - } - - if let Message::Protocol(p) = &msg { - if p.as_ref() == protocol.as_ref() { - tracing::debug!( - target: LOG_TARGET, - "Negotiated: Received confirmation for protocol: {}", - p - ); - *this.state = State::Completed { io: io.into_inner() }; - return Poll::Ready(Ok(())); - } - } - - return Poll::Ready(Err(NegotiationError::Failed)); - }, - - _ => panic!("Negotiated: Invalid state"), - } - } - } - - /// Returns a [`NegotiatedComplete`] future that waits for protocol - /// negotiation to complete. - pub fn complete(self) -> NegotiatedComplete { - NegotiatedComplete { inner: Some(self) } - } + /// Creates a `Negotiated` in state [`State::Completed`]. + pub(crate) fn completed(io: TInner) -> Self { + Negotiated { + state: State::Completed { io }, + } + } + + /// Creates a `Negotiated` in state [`State::Expecting`] that is still + /// expecting confirmation of the given `protocol`. + pub(crate) fn expecting( + io: MessageReader, + protocol: Protocol, + header: Option, + ) -> Self { + Negotiated { + state: State::Expecting { + io, + protocol, + header, + }, + } + } + + pub fn inner(self) -> TInner { + match self.state { + State::Completed { io } => io, + _ => panic!("stream is not negotiated"), + } + } + + /// Polls the `Negotiated` for completion. + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> + where + TInner: AsyncRead + AsyncWrite + Unpin, + { + // Flush any pending negotiation data. + match self.as_mut().poll_flush(cx) { + Poll::Ready(Ok(())) => {} + Poll::Pending => return Poll::Pending, + Poll::Ready(Err(e)) => { + // If the remote closed the stream, it is important to still + // continue reading the data that was sent, if any. + if e.kind() != io::ErrorKind::WriteZero { + return Poll::Ready(Err(e.into())); + } + } + } + + let mut this = self.project(); + + if let StateProj::Completed { .. } = this.state.as_mut().project() { + return Poll::Ready(Ok(())); + } + + // Read outstanding protocol negotiation messages. + loop { + match mem::replace(&mut *this.state, State::Invalid) { + State::Expecting { + mut io, + header, + protocol, + } => { + let msg = match Pin::new(&mut io).poll_next(cx)? { + Poll::Ready(Some(msg)) => msg, + Poll::Pending => { + *this.state = State::Expecting { + io, + header, + protocol, + }; + return Poll::Pending; + } + Poll::Ready(None) => { + return Poll::Ready(Err(ProtocolError::IoError( + io::ErrorKind::UnexpectedEof.into(), + ) + .into())); + } + }; + + if let Message::Header(h) = &msg { + if Some(h) == header.as_ref() { + *this.state = State::Expecting { + io, + protocol, + header: None, + }; + continue; + } + } + + if let Message::Protocol(p) = &msg { + if p.as_ref() == protocol.as_ref() { + tracing::debug!( + target: LOG_TARGET, + "Negotiated: Received confirmation for protocol: {}", + p + ); + *this.state = State::Completed { + io: io.into_inner(), + }; + return Poll::Ready(Ok(())); + } + } + + return Poll::Ready(Err(NegotiationError::Failed)); + } + + _ => panic!("Negotiated: Invalid state"), + } + } + } + + /// Returns a [`NegotiatedComplete`] future that waits for protocol + /// negotiation to complete. + pub fn complete(self) -> NegotiatedComplete { + NegotiatedComplete { inner: Some(self) } + } } /// The states of a `Negotiated` I/O stream. #[pin_project(project = StateProj)] #[derive(Debug)] enum State { - /// In this state, a `Negotiated` is still expecting to - /// receive confirmation of the protocol it has optimistically - /// settled on. - Expecting { - /// The underlying I/O stream. - #[pin] - io: MessageReader, - /// The expected negotiation header/preamble (i.e. multistream-select version), - /// if one is still expected to be received. - header: Option, - /// The expected application protocol (i.e. name and version). - protocol: Protocol, - }, - - /// In this state, a protocol has been agreed upon and I/O - /// on the underlying stream can commence. - Completed { - #[pin] - io: R, - }, - - /// Temporary state while moving the `io` resource from - /// `Expecting` to `Completed`. - Invalid, + /// In this state, a `Negotiated` is still expecting to + /// receive confirmation of the protocol it has optimistically + /// settled on. + Expecting { + /// The underlying I/O stream. + #[pin] + io: MessageReader, + /// The expected negotiation header/preamble (i.e. multistream-select version), + /// if one is still expected to be received. + header: Option, + /// The expected application protocol (i.e. name and version). + protocol: Protocol, + }, + + /// In this state, a protocol has been agreed upon and I/O + /// on the underlying stream can commence. + Completed { + #[pin] + io: R, + }, + + /// Temporary state while moving the `io` resource from + /// `Expecting` to `Completed`. + Invalid, } impl AsyncRead for Negotiated where - TInner: AsyncRead + AsyncWrite + Unpin, + TInner: AsyncRead + AsyncWrite + Unpin, { - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { - loop { - if let StateProj::Completed { io } = self.as_mut().project().state.project() { - // If protocol negotiation is complete, commence with reading. - return io.poll_read(cx, buf); - } - - // Poll the `Negotiated`, driving protocol negotiation to completion, - // including flushing of any remaining data. - match self.as_mut().poll(cx) { - Poll::Ready(Ok(())) => {}, - Poll::Pending => return Poll::Pending, - Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))), - } - } - } - - // TODO: implement once method is stabilized in the futures crate - /*unsafe fn initializer(&self) -> Initializer { - match &self.state { - State::Completed { io, .. } => io.initializer(), - State::Expecting { io, .. } => io.inner_ref().initializer(), - State::Invalid => panic!("Negotiated: Invalid state"), - } - }*/ - - fn poll_read_vectored( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - bufs: &mut [IoSliceMut<'_>], - ) -> Poll> { - loop { - if let StateProj::Completed { io } = self.as_mut().project().state.project() { - // If protocol negotiation is complete, commence with reading. - return io.poll_read_vectored(cx, bufs); - } - - // Poll the `Negotiated`, driving protocol negotiation to completion, - // including flushing of any remaining data. - match self.as_mut().poll(cx) { - Poll::Ready(Ok(())) => {}, - Poll::Pending => return Poll::Pending, - Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))), - } - } - } + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + loop { + if let StateProj::Completed { io } = self.as_mut().project().state.project() { + // If protocol negotiation is complete, commence with reading. + return io.poll_read(cx, buf); + } + + // Poll the `Negotiated`, driving protocol negotiation to completion, + // including flushing of any remaining data. + match self.as_mut().poll(cx) { + Poll::Ready(Ok(())) => {} + Poll::Pending => return Poll::Pending, + Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))), + } + } + } + + // TODO: implement once method is stabilized in the futures crate + /*unsafe fn initializer(&self) -> Initializer { + match &self.state { + State::Completed { io, .. } => io.initializer(), + State::Expecting { io, .. } => io.inner_ref().initializer(), + State::Invalid => panic!("Negotiated: Invalid state"), + } + }*/ + + fn poll_read_vectored( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &mut [IoSliceMut<'_>], + ) -> Poll> { + loop { + if let StateProj::Completed { io } = self.as_mut().project().state.project() { + // If protocol negotiation is complete, commence with reading. + return io.poll_read_vectored(cx, bufs); + } + + // Poll the `Negotiated`, driving protocol negotiation to completion, + // including flushing of any remaining data. + match self.as_mut().poll(cx) { + Poll::Ready(Ok(())) => {} + Poll::Pending => return Poll::Pending, + Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))), + } + } + } } impl AsyncWrite for Negotiated where - TInner: AsyncWrite + AsyncRead + Unpin, + TInner: AsyncWrite + AsyncRead + Unpin, { - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - match self.project().state.project() { - StateProj::Completed { io } => io.poll_write(cx, buf), - StateProj::Expecting { io, .. } => io.poll_write(cx, buf), - StateProj::Invalid => panic!("Negotiated: Invalid state"), - } - } - - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match self.project().state.project() { - StateProj::Completed { io } => io.poll_flush(cx), - StateProj::Expecting { io, .. } => io.poll_flush(cx), - StateProj::Invalid => panic!("Negotiated: Invalid state"), - } - } - - fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - // Ensure all data has been flushed and expected negotiation messages - // have been received. - ready!(self.as_mut().poll(cx).map_err(Into::::into)?); - ready!(self.as_mut().poll_flush(cx).map_err(Into::::into)?); - - // Continue with the shutdown of the underlying I/O stream. - match self.project().state.project() { - StateProj::Completed { io, .. } => io.poll_close(cx), - StateProj::Expecting { io, .. } => io.poll_close(cx), - StateProj::Invalid => panic!("Negotiated: Invalid state"), - } - } - - fn poll_write_vectored( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - bufs: &[IoSlice<'_>], - ) -> Poll> { - match self.project().state.project() { - StateProj::Completed { io } => io.poll_write_vectored(cx, bufs), - StateProj::Expecting { io, .. } => io.poll_write_vectored(cx, bufs), - StateProj::Invalid => panic!("Negotiated: Invalid state"), - } - } + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + match self.project().state.project() { + StateProj::Completed { io } => io.poll_write(cx, buf), + StateProj::Expecting { io, .. } => io.poll_write(cx, buf), + StateProj::Invalid => panic!("Negotiated: Invalid state"), + } + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.project().state.project() { + StateProj::Completed { io } => io.poll_flush(cx), + StateProj::Expecting { io, .. } => io.poll_flush(cx), + StateProj::Invalid => panic!("Negotiated: Invalid state"), + } + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + // Ensure all data has been flushed and expected negotiation messages + // have been received. + ready!(self.as_mut().poll(cx).map_err(Into::::into)?); + ready!(self.as_mut().poll_flush(cx).map_err(Into::::into)?); + + // Continue with the shutdown of the underlying I/O stream. + match self.project().state.project() { + StateProj::Completed { io, .. } => io.poll_close(cx), + StateProj::Expecting { io, .. } => io.poll_close(cx), + StateProj::Invalid => panic!("Negotiated: Invalid state"), + } + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { + match self.project().state.project() { + StateProj::Completed { io } => io.poll_write_vectored(cx, bufs), + StateProj::Expecting { io, .. } => io.poll_write_vectored(cx, bufs), + StateProj::Invalid => panic!("Negotiated: Invalid state"), + } + } } /// Error that can happen when negotiating a protocol with the remote. #[derive(Debug)] pub enum NegotiationError { - /// A protocol error occurred during the negotiation. - ProtocolError(ProtocolError), + /// A protocol error occurred during the negotiation. + ProtocolError(ProtocolError), - /// Protocol negotiation failed because no protocol could be agreed upon. - Failed, + /// Protocol negotiation failed because no protocol could be agreed upon. + Failed, } impl From for NegotiationError { - fn from(err: ProtocolError) -> NegotiationError { - NegotiationError::ProtocolError(err) - } + fn from(err: ProtocolError) -> NegotiationError { + NegotiationError::ProtocolError(err) + } } impl From for NegotiationError { - fn from(err: io::Error) -> NegotiationError { - ProtocolError::from(err).into() - } + fn from(err: io::Error) -> NegotiationError { + ProtocolError::from(err).into() + } } impl From for io::Error { - fn from(err: NegotiationError) -> io::Error { - if let NegotiationError::ProtocolError(e) = err { - return e.into(); - } - io::Error::new(io::ErrorKind::Other, err) - } + fn from(err: NegotiationError) -> io::Error { + if let NegotiationError::ProtocolError(e) = err { + return e.into(); + } + io::Error::new(io::ErrorKind::Other, err) + } } impl Error for NegotiationError { - fn source(&self) -> Option<&(dyn Error + 'static)> { - match self { - NegotiationError::ProtocolError(err) => Some(err), - _ => None, - } - } + fn source(&self) -> Option<&(dyn Error + 'static)> { + match self { + NegotiationError::ProtocolError(err) => Some(err), + _ => None, + } + } } impl fmt::Display for NegotiationError { - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { - match self { - NegotiationError::ProtocolError(p) => - fmt.write_fmt(format_args!("Protocol error: {p}")), - NegotiationError::Failed => fmt.write_str("Protocol negotiation failed."), - } - } + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + match self { + NegotiationError::ProtocolError(p) => + fmt.write_fmt(format_args!("Protocol error: {p}")), + NegotiationError::Failed => fmt.write_str("Protocol negotiation failed."), + } + } } diff --git a/src/multistream_select/protocol.rs b/src/multistream_select/protocol.rs index 515bb683..bf710850 100644 --- a/src/multistream_select/protocol.rs +++ b/src/multistream_select/protocol.rs @@ -26,18 +26,18 @@ //! `MessageReader`. use crate::multistream_select::{ - length_delimited::{LengthDelimited, LengthDelimitedReader}, - Version, + length_delimited::{LengthDelimited, LengthDelimitedReader}, + Version, }; use bytes::{BufMut, Bytes, BytesMut}; use futures::{io::IoSlice, prelude::*, ready}; use std::{ - convert::TryFrom, - error::Error, - fmt, io, - pin::Pin, - task::{Context, Poll}, + convert::TryFrom, + error::Error, + fmt, io, + pin::Pin, + task::{Context, Poll}, }; use unsigned_varint as uvi; @@ -58,16 +58,16 @@ const LOG_TARGET: &str = "litep2p::multistream-select"; /// Every [`Version`] has a corresponding header line. #[derive(Copy, Clone, Debug, PartialEq, Eq)] pub enum HeaderLine { - /// The `/multistream/1.0.0` header line. - V1, + /// The `/multistream/1.0.0` header line. + V1, } impl From for HeaderLine { - fn from(v: Version) -> HeaderLine { - match v { - Version::V1 | Version::V1Lazy => HeaderLine::V1, - } - } + fn from(v: Version) -> HeaderLine { + match v { + Version::V1 | Version::V1Lazy => HeaderLine::V1, + } + } } /// A protocol (name) exchanged during protocol negotiation. @@ -75,34 +75,34 @@ impl From for HeaderLine { pub struct Protocol(Bytes); impl AsRef<[u8]> for Protocol { - fn as_ref(&self) -> &[u8] { - self.0.as_ref() - } + fn as_ref(&self) -> &[u8] { + self.0.as_ref() + } } impl TryFrom for Protocol { - type Error = ProtocolError; - - fn try_from(value: Bytes) -> Result { - if !value.as_ref().starts_with(b"/") { - return Err(ProtocolError::InvalidProtocol); - } - Ok(Protocol(value)) - } + type Error = ProtocolError; + + fn try_from(value: Bytes) -> Result { + if !value.as_ref().starts_with(b"/") { + return Err(ProtocolError::InvalidProtocol); + } + Ok(Protocol(value)) + } } impl TryFrom<&[u8]> for Protocol { - type Error = ProtocolError; + type Error = ProtocolError; - fn try_from(value: &[u8]) -> Result { - Self::try_from(Bytes::copy_from_slice(value)) - } + fn try_from(value: &[u8]) -> Result { + Self::try_from(Bytes::copy_from_slice(value)) + } } impl fmt::Display for Protocol { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}", String::from_utf8_lossy(&self.0)) - } + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", String::from_utf8_lossy(&self.0)) + } } /// A multistream-select protocol message. @@ -111,198 +111,202 @@ impl fmt::Display for Protocol { /// of agreeing on a application-layer protocol to use on an I/O stream. #[derive(Debug, Clone, PartialEq, Eq)] pub enum Message { - /// A header message identifies the multistream-select protocol - /// that the sender wishes to speak. - Header(HeaderLine), - /// A protocol message identifies a protocol request or acknowledgement. - Protocol(Protocol), - /// A message through which a peer requests the complete list of - /// supported protocols from the remote. - ListProtocols, - /// A message listing all supported protocols of a peer. - Protocols(Vec), - /// A message signaling that a requested protocol is not available. - NotAvailable, + /// A header message identifies the multistream-select protocol + /// that the sender wishes to speak. + Header(HeaderLine), + /// A protocol message identifies a protocol request or acknowledgement. + Protocol(Protocol), + /// A message through which a peer requests the complete list of + /// supported protocols from the remote. + ListProtocols, + /// A message listing all supported protocols of a peer. + Protocols(Vec), + /// A message signaling that a requested protocol is not available. + NotAvailable, } impl Message { - /// Encodes a `Message` into its byte representation. - pub fn encode(&self, dest: &mut BytesMut) -> Result<(), ProtocolError> { - match self { - Message::Header(HeaderLine::V1) => { - dest.reserve(MSG_MULTISTREAM_1_0.len()); - dest.put(MSG_MULTISTREAM_1_0); - Ok(()) - }, - Message::Protocol(p) => { - let len = p.0.as_ref().len() + 1; // + 1 for \n - dest.reserve(len); - dest.put(p.0.as_ref()); - dest.put_u8(b'\n'); - Ok(()) - }, - Message::ListProtocols => { - dest.reserve(MSG_LS.len()); - dest.put(MSG_LS); - Ok(()) - }, - Message::Protocols(ps) => { - let mut buf = uvi::encode::usize_buffer(); - let mut encoded = Vec::with_capacity(ps.len()); - for p in ps { - encoded.extend(uvi::encode::usize(p.0.as_ref().len() + 1, &mut buf)); // +1 for '\n' - encoded.extend_from_slice(p.0.as_ref()); - encoded.push(b'\n') - } - encoded.push(b'\n'); - dest.reserve(encoded.len()); - dest.put(encoded.as_ref()); - Ok(()) - }, - Message::NotAvailable => { - dest.reserve(MSG_PROTOCOL_NA.len()); - dest.put(MSG_PROTOCOL_NA); - Ok(()) - }, - } - } - - /// Decodes a `Message` from its byte representation. - pub fn decode(mut msg: Bytes) -> Result { - if msg == MSG_MULTISTREAM_1_0 { - return Ok(Message::Header(HeaderLine::V1)); - } - - if msg == MSG_PROTOCOL_NA { - return Ok(Message::NotAvailable); - } - - if msg == MSG_LS { - return Ok(Message::ListProtocols); - } - - // If it starts with a `/`, ends with a line feed without any - // other line feeds in-between, it must be a protocol name. - if msg.first() == Some(&b'/') && - msg.last() == Some(&b'\n') && - !msg[..msg.len() - 1].contains(&b'\n') - { - let p = Protocol::try_from(msg.split_to(msg.len() - 1))?; - return Ok(Message::Protocol(p)); - } - - // At this point, it must be an `ls` response, i.e. one or more - // length-prefixed, newline-delimited protocol names. - let mut protocols = Vec::new(); - let mut remaining: &[u8] = &msg; - loop { - // A well-formed message must be terminated with a newline. - // TODO: don't do this - if remaining == [b'\n'] || remaining.is_empty() { - break; - } else if protocols.len() == MAX_PROTOCOLS { - return Err(ProtocolError::TooManyProtocols); - } - - // Decode the length of the next protocol name and check that - // it ends with a line feed. - let (len, tail) = uvi::decode::usize(remaining)?; - if len == 0 || len > tail.len() || tail[len - 1] != b'\n' { - return Err(ProtocolError::InvalidMessage); - } - - // Parse the protocol name. - let p = Protocol::try_from(Bytes::copy_from_slice(&tail[..len - 1]))?; - protocols.push(p); - - // Skip ahead to the next protocol. - remaining = &tail[len..]; - } - - Ok(Message::Protocols(protocols)) - } + /// Encodes a `Message` into its byte representation. + pub fn encode(&self, dest: &mut BytesMut) -> Result<(), ProtocolError> { + match self { + Message::Header(HeaderLine::V1) => { + dest.reserve(MSG_MULTISTREAM_1_0.len()); + dest.put(MSG_MULTISTREAM_1_0); + Ok(()) + } + Message::Protocol(p) => { + let len = p.0.as_ref().len() + 1; // + 1 for \n + dest.reserve(len); + dest.put(p.0.as_ref()); + dest.put_u8(b'\n'); + Ok(()) + } + Message::ListProtocols => { + dest.reserve(MSG_LS.len()); + dest.put(MSG_LS); + Ok(()) + } + Message::Protocols(ps) => { + let mut buf = uvi::encode::usize_buffer(); + let mut encoded = Vec::with_capacity(ps.len()); + for p in ps { + encoded.extend(uvi::encode::usize(p.0.as_ref().len() + 1, &mut buf)); // +1 for '\n' + encoded.extend_from_slice(p.0.as_ref()); + encoded.push(b'\n') + } + encoded.push(b'\n'); + dest.reserve(encoded.len()); + dest.put(encoded.as_ref()); + Ok(()) + } + Message::NotAvailable => { + dest.reserve(MSG_PROTOCOL_NA.len()); + dest.put(MSG_PROTOCOL_NA); + Ok(()) + } + } + } + + /// Decodes a `Message` from its byte representation. + pub fn decode(mut msg: Bytes) -> Result { + if msg == MSG_MULTISTREAM_1_0 { + return Ok(Message::Header(HeaderLine::V1)); + } + + if msg == MSG_PROTOCOL_NA { + return Ok(Message::NotAvailable); + } + + if msg == MSG_LS { + return Ok(Message::ListProtocols); + } + + // If it starts with a `/`, ends with a line feed without any + // other line feeds in-between, it must be a protocol name. + if msg.first() == Some(&b'/') + && msg.last() == Some(&b'\n') + && !msg[..msg.len() - 1].contains(&b'\n') + { + let p = Protocol::try_from(msg.split_to(msg.len() - 1))?; + return Ok(Message::Protocol(p)); + } + + // At this point, it must be an `ls` response, i.e. one or more + // length-prefixed, newline-delimited protocol names. + let mut protocols = Vec::new(); + let mut remaining: &[u8] = &msg; + loop { + // A well-formed message must be terminated with a newline. + // TODO: don't do this + if remaining == [b'\n'] || remaining.is_empty() { + break; + } else if protocols.len() == MAX_PROTOCOLS { + return Err(ProtocolError::TooManyProtocols); + } + + // Decode the length of the next protocol name and check that + // it ends with a line feed. + let (len, tail) = uvi::decode::usize(remaining)?; + if len == 0 || len > tail.len() || tail[len - 1] != b'\n' { + return Err(ProtocolError::InvalidMessage); + } + + // Parse the protocol name. + let p = Protocol::try_from(Bytes::copy_from_slice(&tail[..len - 1]))?; + protocols.push(p); + + // Skip ahead to the next protocol. + remaining = &tail[len..]; + } + + Ok(Message::Protocols(protocols)) + } } /// A `MessageIO` implements a [`Stream`] and [`Sink`] of [`Message`]s. #[pin_project::pin_project] pub struct MessageIO { - #[pin] - inner: LengthDelimited, + #[pin] + inner: LengthDelimited, } impl MessageIO { - /// Constructs a new `MessageIO` resource wrapping the given I/O stream. - pub fn new(inner: R) -> MessageIO - where - R: AsyncRead + AsyncWrite, - { - Self { inner: LengthDelimited::new(inner) } - } - - /// Converts the [`MessageIO`] into a [`MessageReader`], dropping the - /// [`Message`]-oriented `Sink` in favour of direct `AsyncWrite` access - /// to the underlying I/O stream. - /// - /// This is typically done if further negotiation messages are expected to be - /// received but no more messages are written, allowing the writing of - /// follow-up protocol data to commence. - pub fn into_reader(self) -> MessageReader { - MessageReader { inner: self.inner.into_reader() } - } - - /// Drops the [`MessageIO`] resource, yielding the underlying I/O stream. - /// - /// # Panics - /// - /// Panics if the read buffer or write buffer is not empty, meaning that an incoming - /// protocol negotiation frame has been partially read or an outgoing frame - /// has not yet been flushed. The read buffer is guaranteed to be empty whenever - /// `MessageIO::poll` returned a message. The write buffer is guaranteed to be empty - /// when the sink has been flushed. - pub fn into_inner(self) -> R { - self.inner.into_inner() - } + /// Constructs a new `MessageIO` resource wrapping the given I/O stream. + pub fn new(inner: R) -> MessageIO + where + R: AsyncRead + AsyncWrite, + { + Self { + inner: LengthDelimited::new(inner), + } + } + + /// Converts the [`MessageIO`] into a [`MessageReader`], dropping the + /// [`Message`]-oriented `Sink` in favour of direct `AsyncWrite` access + /// to the underlying I/O stream. + /// + /// This is typically done if further negotiation messages are expected to be + /// received but no more messages are written, allowing the writing of + /// follow-up protocol data to commence. + pub fn into_reader(self) -> MessageReader { + MessageReader { + inner: self.inner.into_reader(), + } + } + + /// Drops the [`MessageIO`] resource, yielding the underlying I/O stream. + /// + /// # Panics + /// + /// Panics if the read buffer or write buffer is not empty, meaning that an incoming + /// protocol negotiation frame has been partially read or an outgoing frame + /// has not yet been flushed. The read buffer is guaranteed to be empty whenever + /// `MessageIO::poll` returned a message. The write buffer is guaranteed to be empty + /// when the sink has been flushed. + pub fn into_inner(self) -> R { + self.inner.into_inner() + } } impl Sink for MessageIO where - R: AsyncWrite, + R: AsyncWrite, { - type Error = ProtocolError; + type Error = ProtocolError; - fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().inner.poll_ready(cx).map_err(From::from) - } + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().inner.poll_ready(cx).map_err(From::from) + } - fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> { - let mut buf = BytesMut::new(); - item.encode(&mut buf)?; - self.project().inner.start_send(buf.freeze()).map_err(From::from) - } + fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> { + let mut buf = BytesMut::new(); + item.encode(&mut buf)?; + self.project().inner.start_send(buf.freeze()).map_err(From::from) + } - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().inner.poll_flush(cx).map_err(From::from) - } + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().inner.poll_flush(cx).map_err(From::from) + } - fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().inner.poll_close(cx).map_err(From::from) - } + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().inner.poll_close(cx).map_err(From::from) + } } impl Stream for MessageIO where - R: AsyncRead, + R: AsyncRead, { - type Item = Result; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match poll_stream(self.project().inner, cx) { - Poll::Pending => Poll::Pending, - Poll::Ready(None) => Poll::Ready(None), - Poll::Ready(Some(Ok(m))) => Poll::Ready(Some(Ok(m))), - Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))), - } - } + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match poll_stream(self.project().inner, cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(None) => Poll::Ready(None), + Poll::Ready(Some(Ok(m))) => Poll::Ready(Some(Ok(m))), + Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))), + } + } } /// A `MessageReader` implements a `Stream` of `Message`s on an underlying @@ -310,141 +314,141 @@ where #[pin_project::pin_project] #[derive(Debug)] pub struct MessageReader { - #[pin] - inner: LengthDelimitedReader, + #[pin] + inner: LengthDelimitedReader, } impl MessageReader { - /// Drops the `MessageReader` resource, yielding the underlying I/O stream - /// together with the remaining write buffer containing the protocol - /// negotiation frame data that has not yet been written to the I/O stream. - /// - /// # Panics - /// - /// Panics if the read buffer or write buffer is not empty, meaning that either - /// an incoming protocol negotiation frame has been partially read, or an - /// outgoing frame has not yet been flushed. The read buffer is guaranteed to - /// be empty whenever `MessageReader::poll` returned a message. The write - /// buffer is guaranteed to be empty whenever the sink has been flushed. - pub fn into_inner(self) -> R { - self.inner.into_inner() - } + /// Drops the `MessageReader` resource, yielding the underlying I/O stream + /// together with the remaining write buffer containing the protocol + /// negotiation frame data that has not yet been written to the I/O stream. + /// + /// # Panics + /// + /// Panics if the read buffer or write buffer is not empty, meaning that either + /// an incoming protocol negotiation frame has been partially read, or an + /// outgoing frame has not yet been flushed. The read buffer is guaranteed to + /// be empty whenever `MessageReader::poll` returned a message. The write + /// buffer is guaranteed to be empty whenever the sink has been flushed. + pub fn into_inner(self) -> R { + self.inner.into_inner() + } } impl Stream for MessageReader where - R: AsyncRead, + R: AsyncRead, { - type Item = Result; + type Item = Result; - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - poll_stream(self.project().inner, cx) - } + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + poll_stream(self.project().inner, cx) + } } impl AsyncWrite for MessageReader where - TInner: AsyncWrite, + TInner: AsyncWrite, { - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - self.project().inner.poll_write(cx, buf) - } - - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().inner.poll_flush(cx) - } - - fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().inner.poll_close(cx) - } - - fn poll_write_vectored( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - bufs: &[IoSlice<'_>], - ) -> Poll> { - self.project().inner.poll_write_vectored(cx, bufs) - } + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + self.project().inner.poll_write(cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().inner.poll_flush(cx) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().inner.poll_close(cx) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { + self.project().inner.poll_write_vectored(cx, bufs) + } } fn poll_stream( - stream: Pin<&mut S>, - cx: &mut Context<'_>, + stream: Pin<&mut S>, + cx: &mut Context<'_>, ) -> Poll>> where - S: Stream>, + S: Stream>, { - let msg = if let Some(msg) = ready!(stream.poll_next(cx)?) { - match Message::decode(msg) { - Ok(m) => m, - Err(err) => return Poll::Ready(Some(Err(err))), - } - } else { - return Poll::Ready(None); - }; - - tracing::trace!(target: LOG_TARGET, "Received message: {:?}", msg); - - Poll::Ready(Some(Ok(msg))) + let msg = if let Some(msg) = ready!(stream.poll_next(cx)?) { + match Message::decode(msg) { + Ok(m) => m, + Err(err) => return Poll::Ready(Some(Err(err))), + } + } else { + return Poll::Ready(None); + }; + + tracing::trace!(target: LOG_TARGET, "Received message: {:?}", msg); + + Poll::Ready(Some(Ok(msg))) } /// A protocol error. #[derive(Debug)] pub enum ProtocolError { - /// I/O error. - IoError(io::Error), + /// I/O error. + IoError(io::Error), - /// Received an invalid message from the remote. - InvalidMessage, + /// Received an invalid message from the remote. + InvalidMessage, - /// A protocol (name) is invalid. - InvalidProtocol, + /// A protocol (name) is invalid. + InvalidProtocol, - /// Too many protocols have been returned by the remote. - TooManyProtocols, + /// Too many protocols have been returned by the remote. + TooManyProtocols, } impl From for ProtocolError { - fn from(err: io::Error) -> ProtocolError { - ProtocolError::IoError(err) - } + fn from(err: io::Error) -> ProtocolError { + ProtocolError::IoError(err) + } } impl From for io::Error { - fn from(err: ProtocolError) -> Self { - if let ProtocolError::IoError(e) = err { - return e; - } - io::ErrorKind::InvalidData.into() - } + fn from(err: ProtocolError) -> Self { + if let ProtocolError::IoError(e) = err { + return e; + } + io::ErrorKind::InvalidData.into() + } } impl From for ProtocolError { - fn from(err: uvi::decode::Error) -> ProtocolError { - Self::from(io::Error::new(io::ErrorKind::InvalidData, err.to_string())) - } + fn from(err: uvi::decode::Error) -> ProtocolError { + Self::from(io::Error::new(io::ErrorKind::InvalidData, err.to_string())) + } } impl Error for ProtocolError { - fn source(&self) -> Option<&(dyn Error + 'static)> { - match *self { - ProtocolError::IoError(ref err) => Some(err), - _ => None, - } - } + fn source(&self) -> Option<&(dyn Error + 'static)> { + match *self { + ProtocolError::IoError(ref err) => Some(err), + _ => None, + } + } } impl fmt::Display for ProtocolError { - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { - match self { - ProtocolError::IoError(e) => write!(fmt, "I/O error: {e}"), - ProtocolError::InvalidMessage => write!(fmt, "Received an invalid message."), - ProtocolError::InvalidProtocol => write!(fmt, "A protocol (name) is invalid."), - ProtocolError::TooManyProtocols => write!(fmt, "Too many protocols received."), - } - } + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + match self { + ProtocolError::IoError(e) => write!(fmt, "I/O error: {e}"), + ProtocolError::InvalidMessage => write!(fmt, "Received an invalid message."), + ProtocolError::InvalidProtocol => write!(fmt, "A protocol (name) is invalid."), + ProtocolError::TooManyProtocols => write!(fmt, "Too many protocols received."), + } + } } diff --git a/src/peer_id.rs b/src/peer_id.rs index 2c408b27..69a70557 100644 --- a/src/peer_id.rs +++ b/src/peer_id.rs @@ -41,308 +41,311 @@ const MAX_INLINE_KEY_LENGTH: usize = 42; /// as specified in [specs/peer-ids](https://github.com/libp2p/specs/blob/master/peer-ids/peer-ids.md). #[derive(Clone, Copy, Eq, Hash, Ord, PartialEq, PartialOrd)] pub struct PeerId { - multihash: Multihash, + multihash: Multihash, } impl fmt::Debug for PeerId { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_tuple("PeerId").field(&self.to_base58()).finish() - } + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("PeerId").field(&self.to_base58()).finish() + } } impl fmt::Display for PeerId { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - self.to_base58().fmt(f) - } + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.to_base58().fmt(f) + } } impl PeerId { - /// Builds a `PeerId` from a public key. - pub fn from_public_key(key: &PublicKey) -> PeerId { - let key_enc = key.to_protobuf_encoding(); - - let hash_algorithm = - if key_enc.len() <= MAX_INLINE_KEY_LENGTH { Code::Identity } else { Code::Sha2_256 }; - - let multihash = hash_algorithm.digest(&key_enc); - - PeerId { multihash } - } - - /// Parses a `PeerId` from bytes. - pub fn from_bytes(data: &[u8]) -> Result { - PeerId::from_multihash(Multihash::from_bytes(data)?) - .map_err(|mh| Error::UnsupportedCode(mh.code())) - } - - /// Tries to turn a `Multihash` into a `PeerId`. - /// - /// If the multihash does not use a valid hashing algorithm for peer IDs, - /// or the hash value does not satisfy the constraints for a hashed - /// peer ID, it is returned as an `Err`. - pub fn from_multihash(multihash: Multihash) -> Result { - match Code::try_from(multihash.code()) { - Ok(Code::Sha2_256) => Ok(PeerId { multihash }), - Ok(Code::Identity) if multihash.digest().len() <= MAX_INLINE_KEY_LENGTH => - Ok(PeerId { multihash }), - _ => Err(multihash), - } - } - - /// Tries to extract a [`PeerId`] from the given [`Multiaddr`]. - /// - /// In case the given [`Multiaddr`] ends with `/p2p/`, this function - /// will return the encapsulated [`PeerId`], otherwise it will return `None`. - pub fn try_from_multiaddr(address: &Multiaddr) -> Option { - address.iter().last().and_then(|p| match p { - Protocol::P2p(hash) => PeerId::from_multihash(hash).ok(), - _ => None, - }) - } - - /// Generates a random peer ID from a cryptographically secure PRNG. - /// - /// This is useful for randomly walking on a DHT, or for testing purposes. - pub fn random() -> PeerId { - let peer_id = rand::thread_rng().gen::<[u8; 32]>(); - PeerId { - multihash: Multihash::wrap(Code::Identity.into(), &peer_id) - .expect("The digest size is never too large"), - } - } - - /// Returns a raw bytes representation of this `PeerId`. - pub fn to_bytes(&self) -> Vec { - self.multihash.to_bytes() - } - - /// Returns a base-58 encoded string of this `PeerId`. - pub fn to_base58(&self) -> String { - bs58::encode(self.to_bytes()).into_string() - } - - /// Checks whether the public key passed as parameter matches the public key of this `PeerId`. - /// - /// Returns `None` if this `PeerId`s hash algorithm is not supported when encoding the - /// given public key, otherwise `Some` boolean as the result of an equality check. - pub fn is_public_key(&self, public_key: &PublicKey) -> Option { - let alg = Code::try_from(self.multihash.code()) - .expect("Internal multihash is always a valid `Code`"); - let enc = public_key.to_protobuf_encoding(); - Some(alg.digest(&enc) == self.multihash) - } + /// Builds a `PeerId` from a public key. + pub fn from_public_key(key: &PublicKey) -> PeerId { + let key_enc = key.to_protobuf_encoding(); + + let hash_algorithm = if key_enc.len() <= MAX_INLINE_KEY_LENGTH { + Code::Identity + } else { + Code::Sha2_256 + }; + + let multihash = hash_algorithm.digest(&key_enc); + + PeerId { multihash } + } + + /// Parses a `PeerId` from bytes. + pub fn from_bytes(data: &[u8]) -> Result { + PeerId::from_multihash(Multihash::from_bytes(data)?) + .map_err(|mh| Error::UnsupportedCode(mh.code())) + } + + /// Tries to turn a `Multihash` into a `PeerId`. + /// + /// If the multihash does not use a valid hashing algorithm for peer IDs, + /// or the hash value does not satisfy the constraints for a hashed + /// peer ID, it is returned as an `Err`. + pub fn from_multihash(multihash: Multihash) -> Result { + match Code::try_from(multihash.code()) { + Ok(Code::Sha2_256) => Ok(PeerId { multihash }), + Ok(Code::Identity) if multihash.digest().len() <= MAX_INLINE_KEY_LENGTH => + Ok(PeerId { multihash }), + _ => Err(multihash), + } + } + + /// Tries to extract a [`PeerId`] from the given [`Multiaddr`]. + /// + /// In case the given [`Multiaddr`] ends with `/p2p/`, this function + /// will return the encapsulated [`PeerId`], otherwise it will return `None`. + pub fn try_from_multiaddr(address: &Multiaddr) -> Option { + address.iter().last().and_then(|p| match p { + Protocol::P2p(hash) => PeerId::from_multihash(hash).ok(), + _ => None, + }) + } + + /// Generates a random peer ID from a cryptographically secure PRNG. + /// + /// This is useful for randomly walking on a DHT, or for testing purposes. + pub fn random() -> PeerId { + let peer_id = rand::thread_rng().gen::<[u8; 32]>(); + PeerId { + multihash: Multihash::wrap(Code::Identity.into(), &peer_id) + .expect("The digest size is never too large"), + } + } + + /// Returns a raw bytes representation of this `PeerId`. + pub fn to_bytes(&self) -> Vec { + self.multihash.to_bytes() + } + + /// Returns a base-58 encoded string of this `PeerId`. + pub fn to_base58(&self) -> String { + bs58::encode(self.to_bytes()).into_string() + } + + /// Checks whether the public key passed as parameter matches the public key of this `PeerId`. + /// + /// Returns `None` if this `PeerId`s hash algorithm is not supported when encoding the + /// given public key, otherwise `Some` boolean as the result of an equality check. + pub fn is_public_key(&self, public_key: &PublicKey) -> Option { + let alg = Code::try_from(self.multihash.code()) + .expect("Internal multihash is always a valid `Code`"); + let enc = public_key.to_protobuf_encoding(); + Some(alg.digest(&enc) == self.multihash) + } } impl From for PeerId { - fn from(key: PublicKey) -> PeerId { - PeerId::from_public_key(&key) - } + fn from(key: PublicKey) -> PeerId { + PeerId::from_public_key(&key) + } } impl From<&PublicKey> for PeerId { - fn from(key: &PublicKey) -> PeerId { - PeerId::from_public_key(key) - } + fn from(key: &PublicKey) -> PeerId { + PeerId::from_public_key(key) + } } impl TryFrom> for PeerId { - type Error = Vec; + type Error = Vec; - fn try_from(value: Vec) -> Result { - PeerId::from_bytes(&value).map_err(|_| value) - } + fn try_from(value: Vec) -> Result { + PeerId::from_bytes(&value).map_err(|_| value) + } } impl TryFrom for PeerId { - type Error = Multihash; + type Error = Multihash; - fn try_from(value: Multihash) -> Result { - PeerId::from_multihash(value) - } + fn try_from(value: Multihash) -> Result { + PeerId::from_multihash(value) + } } impl AsRef for PeerId { - fn as_ref(&self) -> &Multihash { - &self.multihash - } + fn as_ref(&self) -> &Multihash { + &self.multihash + } } impl From for Multihash { - fn from(peer_id: PeerId) -> Self { - peer_id.multihash - } + fn from(peer_id: PeerId) -> Self { + peer_id.multihash + } } impl From for Vec { - fn from(peer_id: PeerId) -> Self { - peer_id.to_bytes() - } + fn from(peer_id: PeerId) -> Self { + peer_id.to_bytes() + } } impl Serialize for PeerId { - fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer, - { - if serializer.is_human_readable() { - serializer.serialize_str(&self.to_base58()) - } else { - serializer.serialize_bytes(&self.to_bytes()[..]) - } - } + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + if serializer.is_human_readable() { + serializer.serialize_str(&self.to_base58()) + } else { + serializer.serialize_bytes(&self.to_bytes()[..]) + } + } } impl<'de> Deserialize<'de> for PeerId { - fn deserialize(deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { - use serde::de::*; - - struct PeerIdVisitor; - - impl<'de> Visitor<'de> for PeerIdVisitor { - type Value = PeerId; - - fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "valid peer id") - } - - fn visit_bytes(self, v: &[u8]) -> Result - where - E: Error, - { - PeerId::from_bytes(v).map_err(|_| Error::invalid_value(Unexpected::Bytes(v), &self)) - } - - fn visit_str(self, v: &str) -> Result - where - E: Error, - { - PeerId::from_str(v).map_err(|_| Error::invalid_value(Unexpected::Str(v), &self)) - } - } - - if deserializer.is_human_readable() { - deserializer.deserialize_str(PeerIdVisitor) - } else { - deserializer.deserialize_bytes(PeerIdVisitor) - } - } + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + use serde::de::*; + + struct PeerIdVisitor; + + impl<'de> Visitor<'de> for PeerIdVisitor { + type Value = PeerId; + + fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "valid peer id") + } + + fn visit_bytes(self, v: &[u8]) -> Result + where + E: Error, + { + PeerId::from_bytes(v).map_err(|_| Error::invalid_value(Unexpected::Bytes(v), &self)) + } + + fn visit_str(self, v: &str) -> Result + where + E: Error, + { + PeerId::from_str(v).map_err(|_| Error::invalid_value(Unexpected::Str(v), &self)) + } + } + + if deserializer.is_human_readable() { + deserializer.deserialize_str(PeerIdVisitor) + } else { + deserializer.deserialize_bytes(PeerIdVisitor) + } + } } #[derive(Debug, Error)] pub enum ParseError { - #[error("base-58 decode error: {0}")] - B58(#[from] bs58::decode::Error), - #[error("decoding multihash failed")] - MultiHash, + #[error("base-58 decode error: {0}")] + B58(#[from] bs58::decode::Error), + #[error("decoding multihash failed")] + MultiHash, } impl FromStr for PeerId { - type Err = ParseError; + type Err = ParseError; - #[inline] - fn from_str(s: &str) -> Result { - let bytes = bs58::decode(s).into_vec()?; - PeerId::from_bytes(&bytes).map_err(|_| ParseError::MultiHash) - } + #[inline] + fn from_str(s: &str) -> Result { + let bytes = bs58::decode(s).into_vec()?; + PeerId::from_bytes(&bytes).map_err(|_| ParseError::MultiHash) + } } #[cfg(test)] mod tests { - use crate::{crypto::ed25519::Keypair, PeerId}; - use multiaddr::{Multiaddr, Protocol}; - use multihash::Multihash; - - #[test] - fn peer_id_is_public_key() { - let key = Keypair::generate().public(); - let peer_id = key.to_peer_id(); - assert_eq!(peer_id.is_public_key(&key.into()), Some(true)); - } - - #[test] - fn peer_id_into_bytes_then_from_bytes() { - let peer_id = Keypair::generate().public().to_peer_id(); - let second = PeerId::from_bytes(&peer_id.to_bytes()).unwrap(); - assert_eq!(peer_id, second); - } - - #[test] - fn peer_id_to_base58_then_back() { - let peer_id = Keypair::generate().public().to_peer_id(); - let second: PeerId = peer_id.to_base58().parse().unwrap(); - assert_eq!(peer_id, second); - } - - #[test] - fn random_peer_id_is_valid() { - for _ in 0..5000 { - let peer_id = PeerId::random(); - assert_eq!(peer_id, PeerId::from_bytes(&peer_id.to_bytes()).unwrap()); - } - } - - #[test] - fn peer_id_from_multiaddr() { - let address = "[::1]:1337".parse::().unwrap(); - let peer = PeerId::random(); - let address = Multiaddr::empty() - .with(Protocol::from(address.ip())) - .with(Protocol::Tcp(address.port())) - .with(Protocol::P2p(Multihash::from(peer))); - - assert_eq!(peer, PeerId::try_from_multiaddr(&address).unwrap()); - } - - #[test] - fn peer_id_from_multiaddr_no_peer_id() { - let address = "[::1]:1337".parse::().unwrap(); - let address = Multiaddr::empty() - .with(Protocol::from(address.ip())) - .with(Protocol::Tcp(address.port())); - - assert!(PeerId::try_from_multiaddr(&address).is_none()); - } - - #[test] - fn peer_id_from_bytes() { - let peer = PeerId::random(); - let bytes = peer.to_bytes(); - - assert_eq!(PeerId::try_from(bytes).unwrap(), peer); - } - - #[test] - fn peer_id_as_multihash() { - let peer = PeerId::random(); - let multihash = Multihash::from(peer); - - assert_eq!(&multihash, peer.as_ref()); - assert_eq!(PeerId::try_from(multihash).unwrap(), peer); - } - - #[test] - fn serialize_deserialize() { - let peer = PeerId::random(); - let serialized = serde_json::to_string(&peer).unwrap(); - let deserialized = serde_json::from_str(&serialized).unwrap(); - - assert_eq!(peer, deserialized); - } - - #[test] - fn invalid_multihash() { - fn test() -> crate::Result { - let bytes = [ - 0x16, 0x20, 0x64, 0x4b, 0xcc, 0x7e, 0x56, 0x43, 0x73, 0x04, 0x09, 0x99, 0xaa, 0xc8, - 0x9e, 0x76, 0x22, 0xf3, 0xca, 0x71, 0xfb, 0xa1, 0xd9, 0x72, 0xfd, 0x94, 0xa3, 0x1c, - 0x3b, 0xfb, 0xf2, 0x4e, 0x39, 0x38, - ]; - - PeerId::from_multihash(Multihash::from_bytes(&bytes).unwrap()).map_err(From::from) - } - let _error = test().unwrap_err(); - } + use crate::{crypto::ed25519::Keypair, PeerId}; + use multiaddr::{Multiaddr, Protocol}; + use multihash::Multihash; + + #[test] + fn peer_id_is_public_key() { + let key = Keypair::generate().public(); + let peer_id = key.to_peer_id(); + assert_eq!(peer_id.is_public_key(&key.into()), Some(true)); + } + + #[test] + fn peer_id_into_bytes_then_from_bytes() { + let peer_id = Keypair::generate().public().to_peer_id(); + let second = PeerId::from_bytes(&peer_id.to_bytes()).unwrap(); + assert_eq!(peer_id, second); + } + + #[test] + fn peer_id_to_base58_then_back() { + let peer_id = Keypair::generate().public().to_peer_id(); + let second: PeerId = peer_id.to_base58().parse().unwrap(); + assert_eq!(peer_id, second); + } + + #[test] + fn random_peer_id_is_valid() { + for _ in 0..5000 { + let peer_id = PeerId::random(); + assert_eq!(peer_id, PeerId::from_bytes(&peer_id.to_bytes()).unwrap()); + } + } + + #[test] + fn peer_id_from_multiaddr() { + let address = "[::1]:1337".parse::().unwrap(); + let peer = PeerId::random(); + let address = Multiaddr::empty() + .with(Protocol::from(address.ip())) + .with(Protocol::Tcp(address.port())) + .with(Protocol::P2p(Multihash::from(peer))); + + assert_eq!(peer, PeerId::try_from_multiaddr(&address).unwrap()); + } + + #[test] + fn peer_id_from_multiaddr_no_peer_id() { + let address = "[::1]:1337".parse::().unwrap(); + let address = Multiaddr::empty() + .with(Protocol::from(address.ip())) + .with(Protocol::Tcp(address.port())); + + assert!(PeerId::try_from_multiaddr(&address).is_none()); + } + + #[test] + fn peer_id_from_bytes() { + let peer = PeerId::random(); + let bytes = peer.to_bytes(); + + assert_eq!(PeerId::try_from(bytes).unwrap(), peer); + } + + #[test] + fn peer_id_as_multihash() { + let peer = PeerId::random(); + let multihash = Multihash::from(peer); + + assert_eq!(&multihash, peer.as_ref()); + assert_eq!(PeerId::try_from(multihash).unwrap(), peer); + } + + #[test] + fn serialize_deserialize() { + let peer = PeerId::random(); + let serialized = serde_json::to_string(&peer).unwrap(); + let deserialized = serde_json::from_str(&serialized).unwrap(); + + assert_eq!(peer, deserialized); + } + + #[test] + fn invalid_multihash() { + fn test() -> crate::Result { + let bytes = [ + 0x16, 0x20, 0x64, 0x4b, 0xcc, 0x7e, 0x56, 0x43, 0x73, 0x04, 0x09, 0x99, 0xaa, 0xc8, + 0x9e, 0x76, 0x22, 0xf3, 0xca, 0x71, 0xfb, 0xa1, 0xd9, 0x72, 0xfd, 0x94, 0xa3, 0x1c, + 0x3b, 0xfb, 0xf2, 0x4e, 0x39, 0x38, + ]; + + PeerId::from_multihash(Multihash::from_bytes(&bytes).unwrap()).map_err(From::from) + } + let _error = test().unwrap_err(); + } } diff --git a/src/protocol/connection.rs b/src/protocol/connection.rs index f711b7cc..6d2bdcdf 100644 --- a/src/protocol/connection.rs +++ b/src/protocol/connection.rs @@ -21,9 +21,9 @@ //! Connection-related helper code. use crate::{ - error::Error, - protocol::protocol_set::ProtocolCommand, - types::{protocol::ProtocolName, ConnectionId, SubstreamId}, + error::Error, + protocol::protocol_set::ProtocolCommand, + types::{protocol::ProtocolName, ConnectionId, SubstreamId}, }; use tokio::sync::mpsc::{error::TrySendError, Sender, WeakSender}; @@ -31,209 +31,215 @@ use tokio::sync::mpsc::{error::TrySendError, Sender, WeakSender}; /// Connection type, from the point of view of the protocol. #[derive(Debug, Clone)] enum ConnectionType { - /// Connection is actively kept open. - Active(Sender), + /// Connection is actively kept open. + Active(Sender), - /// Connection is considered inactive as far as the protocol is concerned - /// and if no substreams are being opened and no protocol is interested in - /// keeping the connection open, it will be closed. - Inactive(WeakSender), + /// Connection is considered inactive as far as the protocol is concerned + /// and if no substreams are being opened and no protocol is interested in + /// keeping the connection open, it will be closed. + Inactive(WeakSender), } /// Type representing a handle to connection which allows protocols to communicate with the /// connection. #[derive(Debug, Clone)] pub struct ConnectionHandle { - /// Connection type. - connection: ConnectionType, + /// Connection type. + connection: ConnectionType, - /// Connection ID. - connection_id: ConnectionId, + /// Connection ID. + connection_id: ConnectionId, } impl ConnectionHandle { - /// Create new [`ConnectionHandle`]. - /// - /// By default the connection is set as `Active` to give protocols time to open a substream if - /// they wish. - pub fn new(connection_id: ConnectionId, connection: Sender) -> Self { - Self { connection_id, connection: ConnectionType::Active(connection) } - } - - /// Get active sender from the [`ConnectionHandle`] and then downgrade it to an inactive - /// connection. - /// - /// This function is only called once when the connection is established to remote peer and that - /// one time the connection type must be `Active`, unless there is a logic bug in `litep2p`. - pub fn downgrade(&mut self) -> Self { - let connection = match &self.connection { - ConnectionType::Active(connection) => { - let handle = Self::new(self.connection_id, connection.clone()); - self.connection = ConnectionType::Inactive(connection.downgrade()); - - handle - }, - ConnectionType::Inactive(_) => { - panic!("state mismatch: tried to downgrade an inactive connection") - }, - }; - - connection - } - - /// Get reference to connection ID. - pub fn connection_id(&self) -> &ConnectionId { - &self.connection_id - } - - /// Mark connection as closed. - pub fn close(&mut self) { - if let ConnectionType::Active(connection) = &self.connection { - self.connection = ConnectionType::Inactive(connection.downgrade()); - } - } - - /// Attempt to acquire permit which will keep the connection open for indefinite time. - pub fn try_get_permit(&self) -> Option { - match &self.connection { - ConnectionType::Active(active) => Some(Permit::new(active.clone())), - ConnectionType::Inactive(inactive) => Some(Permit::new(inactive.upgrade()?)), - } - } - - /// Open substream to remote peer over `protocol` and send the acquired permit to the - /// transport so it can be given to the opened substream. - pub fn open_substream( - &mut self, - protocol: ProtocolName, - fallback_names: Vec, - substream_id: SubstreamId, - permit: Permit, - ) -> crate::Result<()> { - match &self.connection { - ConnectionType::Active(active) => active.clone(), - ConnectionType::Inactive(inactive) => - inactive.upgrade().ok_or(Error::ConnectionClosed)?, - } - .try_send(ProtocolCommand::OpenSubstream { - protocol: protocol.clone(), - fallback_names, - substream_id, - permit, - }) - .map_err(|error| match error { - TrySendError::Full(_) => Error::ChannelClogged, - TrySendError::Closed(_) => Error::ConnectionClosed, - }) - } - - /// Force close connection. - pub fn force_close(&mut self) -> crate::Result<()> { - match &self.connection { - ConnectionType::Active(active) => active.clone(), - ConnectionType::Inactive(inactive) => - inactive.upgrade().ok_or(Error::ConnectionClosed)?, - } - .try_send(ProtocolCommand::ForceClose) - .map_err(|error| match error { - TrySendError::Full(_) => Error::ChannelClogged, - TrySendError::Closed(_) => Error::ConnectionClosed, - }) - } + /// Create new [`ConnectionHandle`]. + /// + /// By default the connection is set as `Active` to give protocols time to open a substream if + /// they wish. + pub fn new(connection_id: ConnectionId, connection: Sender) -> Self { + Self { + connection_id, + connection: ConnectionType::Active(connection), + } + } + + /// Get active sender from the [`ConnectionHandle`] and then downgrade it to an inactive + /// connection. + /// + /// This function is only called once when the connection is established to remote peer and that + /// one time the connection type must be `Active`, unless there is a logic bug in `litep2p`. + pub fn downgrade(&mut self) -> Self { + let connection = match &self.connection { + ConnectionType::Active(connection) => { + let handle = Self::new(self.connection_id, connection.clone()); + self.connection = ConnectionType::Inactive(connection.downgrade()); + + handle + } + ConnectionType::Inactive(_) => { + panic!("state mismatch: tried to downgrade an inactive connection") + } + }; + + connection + } + + /// Get reference to connection ID. + pub fn connection_id(&self) -> &ConnectionId { + &self.connection_id + } + + /// Mark connection as closed. + pub fn close(&mut self) { + if let ConnectionType::Active(connection) = &self.connection { + self.connection = ConnectionType::Inactive(connection.downgrade()); + } + } + + /// Attempt to acquire permit which will keep the connection open for indefinite time. + pub fn try_get_permit(&self) -> Option { + match &self.connection { + ConnectionType::Active(active) => Some(Permit::new(active.clone())), + ConnectionType::Inactive(inactive) => Some(Permit::new(inactive.upgrade()?)), + } + } + + /// Open substream to remote peer over `protocol` and send the acquired permit to the + /// transport so it can be given to the opened substream. + pub fn open_substream( + &mut self, + protocol: ProtocolName, + fallback_names: Vec, + substream_id: SubstreamId, + permit: Permit, + ) -> crate::Result<()> { + match &self.connection { + ConnectionType::Active(active) => active.clone(), + ConnectionType::Inactive(inactive) => + inactive.upgrade().ok_or(Error::ConnectionClosed)?, + } + .try_send(ProtocolCommand::OpenSubstream { + protocol: protocol.clone(), + fallback_names, + substream_id, + permit, + }) + .map_err(|error| match error { + TrySendError::Full(_) => Error::ChannelClogged, + TrySendError::Closed(_) => Error::ConnectionClosed, + }) + } + + /// Force close connection. + pub fn force_close(&mut self) -> crate::Result<()> { + match &self.connection { + ConnectionType::Active(active) => active.clone(), + ConnectionType::Inactive(inactive) => + inactive.upgrade().ok_or(Error::ConnectionClosed)?, + } + .try_send(ProtocolCommand::ForceClose) + .map_err(|error| match error { + TrySendError::Full(_) => Error::ChannelClogged, + TrySendError::Closed(_) => Error::ConnectionClosed, + }) + } } /// Type which allows the connection to be kept open. #[derive(Debug)] pub struct Permit { - /// Active connection. - _connection: Sender, + /// Active connection. + _connection: Sender, } impl Permit { - /// Create new [`Permit`] which allows the connection to be kept open. - pub fn new(_connection: Sender) -> Self { - Self { _connection } - } + /// Create new [`Permit`] which allows the connection to be kept open. + pub fn new(_connection: Sender) -> Self { + Self { _connection } + } } #[cfg(test)] mod tests { - use super::*; - use tokio::sync::mpsc::channel; - - #[test] - #[should_panic] - fn downgrade_inactive_connection() { - let (tx, _rx) = channel(1); - let mut handle = ConnectionHandle::new(ConnectionId::new(), tx); - - let mut new_handle = handle.downgrade(); - assert!(std::matches!(new_handle.connection, ConnectionType::Inactive(_))); - - // try to downgrade an already-downgraded connection - let _handle = new_handle.downgrade(); - } - - #[tokio::test] - async fn open_substream_open_downgraded_connection() { - let (tx, mut rx) = channel(1); - let mut handle = ConnectionHandle::new(ConnectionId::new(), tx); - let mut handle = handle.downgrade(); - let permit = handle.try_get_permit().unwrap(); - - let result = handle.open_substream( - ProtocolName::from("/protocol/1"), - Vec::new(), - SubstreamId::new(), - permit, - ); - - assert!(result.is_ok()); - assert!(rx.recv().await.is_some()); - } - - #[tokio::test] - async fn open_substream_closed_downgraded_connection() { - let (tx, _rx) = channel(1); - let mut handle = ConnectionHandle::new(ConnectionId::new(), tx); - let mut handle = handle.downgrade(); - let permit = handle.try_get_permit().unwrap(); - drop(_rx); - - let result = handle.open_substream( - ProtocolName::from("/protocol/1"), - Vec::new(), - SubstreamId::new(), - permit, - ); - - assert!(result.is_err()); - } - - #[tokio::test] - async fn open_substream_channel_clogged() { - let (tx, _rx) = channel(1); - let mut handle = ConnectionHandle::new(ConnectionId::new(), tx); - let mut handle = handle.downgrade(); - let permit = handle.try_get_permit().unwrap(); - - let result = handle.open_substream( - ProtocolName::from("/protocol/1"), - Vec::new(), - SubstreamId::new(), - permit, - ); - assert!(result.is_ok()); - - let permit = handle.try_get_permit().unwrap(); - match handle.open_substream( - ProtocolName::from("/protocol/1"), - Vec::new(), - SubstreamId::new(), - permit, - ) { - Err(Error::ChannelClogged) => {}, - error => panic!("invalid error: {error:?}"), - } - } + use super::*; + use tokio::sync::mpsc::channel; + + #[test] + #[should_panic] + fn downgrade_inactive_connection() { + let (tx, _rx) = channel(1); + let mut handle = ConnectionHandle::new(ConnectionId::new(), tx); + + let mut new_handle = handle.downgrade(); + assert!(std::matches!( + new_handle.connection, + ConnectionType::Inactive(_) + )); + + // try to downgrade an already-downgraded connection + let _handle = new_handle.downgrade(); + } + + #[tokio::test] + async fn open_substream_open_downgraded_connection() { + let (tx, mut rx) = channel(1); + let mut handle = ConnectionHandle::new(ConnectionId::new(), tx); + let mut handle = handle.downgrade(); + let permit = handle.try_get_permit().unwrap(); + + let result = handle.open_substream( + ProtocolName::from("/protocol/1"), + Vec::new(), + SubstreamId::new(), + permit, + ); + + assert!(result.is_ok()); + assert!(rx.recv().await.is_some()); + } + + #[tokio::test] + async fn open_substream_closed_downgraded_connection() { + let (tx, _rx) = channel(1); + let mut handle = ConnectionHandle::new(ConnectionId::new(), tx); + let mut handle = handle.downgrade(); + let permit = handle.try_get_permit().unwrap(); + drop(_rx); + + let result = handle.open_substream( + ProtocolName::from("/protocol/1"), + Vec::new(), + SubstreamId::new(), + permit, + ); + + assert!(result.is_err()); + } + + #[tokio::test] + async fn open_substream_channel_clogged() { + let (tx, _rx) = channel(1); + let mut handle = ConnectionHandle::new(ConnectionId::new(), tx); + let mut handle = handle.downgrade(); + let permit = handle.try_get_permit().unwrap(); + + let result = handle.open_substream( + ProtocolName::from("/protocol/1"), + Vec::new(), + SubstreamId::new(), + permit, + ); + assert!(result.is_ok()); + + let permit = handle.try_get_permit().unwrap(); + match handle.open_substream( + ProtocolName::from("/protocol/1"), + Vec::new(), + SubstreamId::new(), + permit, + ) { + Err(Error::ChannelClogged) => {} + error => panic!("invalid error: {error:?}"), + } + } } diff --git a/src/protocol/libp2p/bitswap/config.rs b/src/protocol/libp2p/bitswap/config.rs index a524e97f..55e86e4a 100644 --- a/src/protocol/libp2p/bitswap/config.rs +++ b/src/protocol/libp2p/bitswap/config.rs @@ -19,10 +19,10 @@ // DEALINGS IN THE SOFTWARE. use crate::{ - codec::ProtocolCodec, - protocol::libp2p::bitswap::{BitswapCommand, BitswapEvent, BitswapHandle}, - types::protocol::ProtocolName, - DEFAULT_CHANNEL_SIZE, + codec::ProtocolCodec, + protocol::libp2p::bitswap::{BitswapCommand, BitswapEvent, BitswapHandle}, + types::protocol::ProtocolName, + DEFAULT_CHANNEL_SIZE, }; use tokio::sync::mpsc::{channel, Receiver, Sender}; @@ -36,33 +36,33 @@ const MAX_PAYLOAD_SIZE: usize = 2_097_152; /// Bitswap configuration. #[derive(Debug)] pub struct Config { - /// Protocol name. - pub(crate) protocol: ProtocolName, + /// Protocol name. + pub(crate) protocol: ProtocolName, - /// Protocol codec. - pub(crate) codec: ProtocolCodec, + /// Protocol codec. + pub(crate) codec: ProtocolCodec, - /// TX channel for sending events to the user protocol. - pub(super) event_tx: Sender, + /// TX channel for sending events to the user protocol. + pub(super) event_tx: Sender, - /// RX channel for receiving commands from the user. - pub(super) cmd_rx: Receiver, + /// RX channel for receiving commands from the user. + pub(super) cmd_rx: Receiver, } impl Config { - /// Create new [`Config`]. - pub fn new() -> (Self, BitswapHandle) { - let (event_tx, event_rx) = channel(DEFAULT_CHANNEL_SIZE); - let (cmd_tx, cmd_rx) = channel(DEFAULT_CHANNEL_SIZE); + /// Create new [`Config`]. + pub fn new() -> (Self, BitswapHandle) { + let (event_tx, event_rx) = channel(DEFAULT_CHANNEL_SIZE); + let (cmd_tx, cmd_rx) = channel(DEFAULT_CHANNEL_SIZE); - ( - Self { - cmd_rx, - event_tx, - protocol: ProtocolName::from(PROTOCOL_NAME), - codec: ProtocolCodec::UnsignedVarint(Some(MAX_PAYLOAD_SIZE)), - }, - BitswapHandle::new(event_rx, cmd_tx), - ) - } + ( + Self { + cmd_rx, + event_tx, + protocol: ProtocolName::from(PROTOCOL_NAME), + codec: ProtocolCodec::UnsignedVarint(Some(MAX_PAYLOAD_SIZE)), + }, + BitswapHandle::new(event_rx, cmd_tx), + ) + } } diff --git a/src/protocol/libp2p/bitswap/handle.rs b/src/protocol/libp2p/bitswap/handle.rs index 911968ad..4841582e 100644 --- a/src/protocol/libp2p/bitswap/handle.rs +++ b/src/protocol/libp2p/bitswap/handle.rs @@ -21,98 +21,98 @@ //! Bitswap handle for communicating with the bitswap protocol implementation. use crate::{ - protocol::libp2p::bitswap::{BlockPresenceType, WantType}, - PeerId, + protocol::libp2p::bitswap::{BlockPresenceType, WantType}, + PeerId, }; use cid::Cid; use tokio::sync::mpsc::{Receiver, Sender}; use std::{ - pin::Pin, - task::{Context, Poll}, + pin::Pin, + task::{Context, Poll}, }; /// Events emitted by the bitswap protocol. #[derive(Debug)] pub enum BitswapEvent { - /// Bitswap request. - Request { - /// Peer ID. - peer: PeerId, - - /// Requested CIDs. - cids: Vec<(Cid, WantType)>, - }, + /// Bitswap request. + Request { + /// Peer ID. + peer: PeerId, + + /// Requested CIDs. + cids: Vec<(Cid, WantType)>, + }, } /// Response type for received bitswap request. #[derive(Debug)] pub enum ResponseType { - /// Block. - Block { - /// CID. - cid: Cid, - - /// Found block. - block: Vec, - }, - - /// Presense. - Presence { - /// CID. - cid: Cid, - - /// Whether the requested block exists or not. - presence: BlockPresenceType, - }, + /// Block. + Block { + /// CID. + cid: Cid, + + /// Found block. + block: Vec, + }, + + /// Presense. + Presence { + /// CID. + cid: Cid, + + /// Whether the requested block exists or not. + presence: BlockPresenceType, + }, } /// Commands sent from the user to `Bitswap`. #[derive(Debug)] pub(super) enum BitswapCommand { - /// Send bitswap response. - SendResponse { - /// Peer ID. - peer: PeerId, - - /// CIDs. - responses: Vec, - }, + /// Send bitswap response. + SendResponse { + /// Peer ID. + peer: PeerId, + + /// CIDs. + responses: Vec, + }, } /// Handle for communicating with the bitswap protocol. pub struct BitswapHandle { - /// RX channel for receiving bitswap events. - event_rx: Receiver, + /// RX channel for receiving bitswap events. + event_rx: Receiver, - /// TX channel for sending commads to `Bitswap`. - cmd_tx: Sender, + /// TX channel for sending commads to `Bitswap`. + cmd_tx: Sender, } impl BitswapHandle { - /// Create new [`BitswapHandle`]. - pub(super) fn new(event_rx: Receiver, cmd_tx: Sender) -> Self { - Self { event_rx, cmd_tx } - } - - /// Send `request` to `peer`. - /// - /// Not supported by the current implementation. - pub async fn send_request(&self, _peer: PeerId, _request: Vec) { - unimplemented!("bitswap requests are not supported"); - } - - /// Send `response` to `peer`. - pub async fn send_response(&self, peer: PeerId, responses: Vec) { - let _ = self.cmd_tx.send(BitswapCommand::SendResponse { peer, responses }).await; - } + /// Create new [`BitswapHandle`]. + pub(super) fn new(event_rx: Receiver, cmd_tx: Sender) -> Self { + Self { event_rx, cmd_tx } + } + + /// Send `request` to `peer`. + /// + /// Not supported by the current implementation. + pub async fn send_request(&self, _peer: PeerId, _request: Vec) { + unimplemented!("bitswap requests are not supported"); + } + + /// Send `response` to `peer`. + pub async fn send_response(&self, peer: PeerId, responses: Vec) { + let _ = self.cmd_tx.send(BitswapCommand::SendResponse { peer, responses }).await; + } } impl futures::Stream for BitswapHandle { - type Item = BitswapEvent; + type Item = BitswapEvent; - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.event_rx).poll_recv(cx) - } + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.event_rx).poll_recv(cx) + } } diff --git a/src/protocol/libp2p/bitswap/mod.rs b/src/protocol/libp2p/bitswap/mod.rs index 221f6d3f..d82ea891 100644 --- a/src/protocol/libp2p/bitswap/mod.rs +++ b/src/protocol/libp2p/bitswap/mod.rs @@ -21,13 +21,13 @@ //! [`/ipfs/bitswap/1.2.0`](https://github.com/ipfs/specs/blob/main/BITSWAP.md) implementation. use crate::{ - error::Error, - protocol::{ - libp2p::bitswap::handle::BitswapCommand, Direction, TransportEvent, TransportService, - }, - substream::Substream, - types::SubstreamId, - PeerId, + error::Error, + protocol::{ + libp2p::bitswap::handle::BitswapCommand, Direction, TransportEvent, TransportService, + }, + substream::Substream, + types::SubstreamId, + PeerId, }; use cid::{multihash::Code, Version}; @@ -46,9 +46,9 @@ mod config; mod handle; mod schema { - pub(super) mod bitswap { - include!(concat!(env!("OUT_DIR"), "/bitswap.rs")); - } + pub(super) mod bitswap { + include!(concat!(env!("OUT_DIR"), "/bitswap.rs")); + } } /// Log target for the file. @@ -57,194 +57,197 @@ const LOG_TARGET: &str = "litep2p::ipfs::bitswap"; /// Bitswap metadata. #[derive(Debug)] struct Prefix { - /// CID version. - version: Version, + /// CID version. + version: Version, - /// CID codec. - codec: u64, + /// CID codec. + codec: u64, - /// CID multihash type. - multihash_type: u64, + /// CID multihash type. + multihash_type: u64, - /// CID multihash length. - multihash_len: u8, + /// CID multihash length. + multihash_len: u8, } impl Prefix { - /// Convert the prefix to encoded bytes. - pub fn to_bytes(&self) -> Vec { - let mut res = Vec::with_capacity(4 * 10); - - let mut buf = unsigned_varint::encode::u64_buffer(); - let version = unsigned_varint::encode::u64(self.version.into(), &mut buf); - res.extend_from_slice(version); - - let mut buf = unsigned_varint::encode::u64_buffer(); - let codec = unsigned_varint::encode::u64(self.codec, &mut buf); - res.extend_from_slice(codec); - - let mut buf = unsigned_varint::encode::u64_buffer(); - let multihash_type = unsigned_varint::encode::u64(self.multihash_type, &mut buf); - res.extend_from_slice(multihash_type); - - let mut buf = unsigned_varint::encode::u64_buffer(); - let multihash_len = unsigned_varint::encode::u64(self.multihash_len as u64, &mut buf); - res.extend_from_slice(multihash_len); - res - } + /// Convert the prefix to encoded bytes. + pub fn to_bytes(&self) -> Vec { + let mut res = Vec::with_capacity(4 * 10); + + let mut buf = unsigned_varint::encode::u64_buffer(); + let version = unsigned_varint::encode::u64(self.version.into(), &mut buf); + res.extend_from_slice(version); + + let mut buf = unsigned_varint::encode::u64_buffer(); + let codec = unsigned_varint::encode::u64(self.codec, &mut buf); + res.extend_from_slice(codec); + + let mut buf = unsigned_varint::encode::u64_buffer(); + let multihash_type = unsigned_varint::encode::u64(self.multihash_type, &mut buf); + res.extend_from_slice(multihash_type); + + let mut buf = unsigned_varint::encode::u64_buffer(); + let multihash_len = unsigned_varint::encode::u64(self.multihash_len as u64, &mut buf); + res.extend_from_slice(multihash_len); + res + } } /// Bitswap protocol. pub(crate) struct Bitswap { - // Connection service. - service: TransportService, + // Connection service. + service: TransportService, - /// TX channel for sending events to the user protocol. - event_tx: Sender, + /// TX channel for sending events to the user protocol. + event_tx: Sender, - /// RX channel for receiving commands from `BitswapHandle`. - cmd_rx: Receiver, + /// RX channel for receiving commands from `BitswapHandle`. + cmd_rx: Receiver, - /// Pending outbound substreams. - pending_outbound: HashMap>, + /// Pending outbound substreams. + pending_outbound: HashMap>, - /// Pending inbound substreams. - pending_inbound: - FuturesUnordered)>>>, + /// Pending inbound substreams. + pending_inbound: + FuturesUnordered)>>>, } impl Bitswap { - /// Create new [`Bitswap`] protocol. - pub(crate) fn new(service: TransportService, config: Config) -> Self { - Self { - service, - cmd_rx: config.cmd_rx, - event_tx: config.event_tx, - pending_outbound: HashMap::new(), - pending_inbound: FuturesUnordered::new(), - } - } - - /// Substream opened to remote peer. - fn on_inbound_substream(&mut self, peer: PeerId, mut substream: Substream) { - tracing::debug!(target: LOG_TARGET, ?peer, "handle inbound substream"); - - self.pending_inbound.push(Box::pin(async move { - let message = substream.next().await.ok_or(Error::ConnectionClosed)??; - let message = schema::bitswap::Message::decode(message)?; - - let Some(wantlist) = message.wantlist else { - tracing::debug!(target: LOG_TARGET, "bitswap message doesn't contain `WantList`"); - return Err(Error::InvalidData); - }; - - Ok(( - peer, - wantlist - .entries - .into_iter() - .filter_map(|entry| { - let cid = Cid::read_bytes(entry.block.as_slice()).ok()?; - - let want_type = match entry.want_type { - 0 => WantType::Block, - 1 => WantType::Have, - _ => return None, - }; - - (cid.version() == cid::Version::V1 && - cid.hash().code() == u64::from(Code::Blake2b256) && - cid.hash().size() == 32) - .then_some((cid, want_type)) - }) - .collect::>(), - )) - })); - } - - /// Send response to bitswap request. - async fn on_outbound_substream( - &mut self, - peer: PeerId, - substream_id: SubstreamId, - mut substream: Substream, - ) { - let Some(entries) = self.pending_outbound.remove(&substream_id) else { - tracing::warn!(target: LOG_TARGET, ?peer, ?substream_id, "pending outbound entry doesn't exist"); - return; - }; - - let mut response = schema::bitswap::Message::default(); - - for entry in entries { - match entry { - ResponseType::Block { cid, block } => { - let prefix = Prefix { - version: cid.version(), - codec: cid.codec(), - multihash_type: cid.hash().code(), - multihash_len: cid.hash().size(), - } - .to_bytes(); - - response.payload.push(schema::bitswap::Block { prefix, data: block }); - }, - ResponseType::Presence { cid, presence } => { - response.block_presences.push(schema::bitswap::BlockPresence { - cid: cid.to_bytes(), - r#type: presence as i32, - }); - }, - } - } - - let _ = substream.send_framed(response.encode_to_vec().into()).await; - } - - /// Handle bitswap response. - fn on_bitswap_response(&mut self, peer: PeerId, responses: Vec) { - match self.service.open_substream(peer) { - Err(error) => { - tracing::debug!(target: LOG_TARGET, ?peer, ?error, "failed to open substream to peer") - }, - Ok(substream_id) => { - self.pending_outbound.insert(substream_id, responses); - }, - } - } - - /// Start [`Bitswap`] event loop. - pub async fn run(mut self) { - tracing::debug!(target: LOG_TARGET, "starting bitswap event loop"); - - loop { - tokio::select! { - event = self.service.next() => match event { - Some(TransportEvent::SubstreamOpened { - peer, - substream, - direction, - .. - }) => match direction { - Direction::Inbound => self.on_inbound_substream(peer, substream), - Direction::Outbound(substream_id) => - self.on_outbound_substream(peer, substream_id, substream).await, - }, - None => return, - event => tracing::trace!(target: LOG_TARGET, ?event, "unhandled event"), - }, - command = self.cmd_rx.recv() => match command { - Some(BitswapCommand::SendResponse { peer, responses }) => { - self.on_bitswap_response(peer, responses); - } - None => return, - }, - event = self.pending_inbound.next(), if !self.pending_inbound.is_empty() => { - if let Some(Ok((peer, cids))) = event { - let _ = self.event_tx.send(BitswapEvent::Request { peer, cids }).await; - } - } - } - } - } + /// Create new [`Bitswap`] protocol. + pub(crate) fn new(service: TransportService, config: Config) -> Self { + Self { + service, + cmd_rx: config.cmd_rx, + event_tx: config.event_tx, + pending_outbound: HashMap::new(), + pending_inbound: FuturesUnordered::new(), + } + } + + /// Substream opened to remote peer. + fn on_inbound_substream(&mut self, peer: PeerId, mut substream: Substream) { + tracing::debug!(target: LOG_TARGET, ?peer, "handle inbound substream"); + + self.pending_inbound.push(Box::pin(async move { + let message = substream.next().await.ok_or(Error::ConnectionClosed)??; + let message = schema::bitswap::Message::decode(message)?; + + let Some(wantlist) = message.wantlist else { + tracing::debug!(target: LOG_TARGET, "bitswap message doesn't contain `WantList`"); + return Err(Error::InvalidData); + }; + + Ok(( + peer, + wantlist + .entries + .into_iter() + .filter_map(|entry| { + let cid = Cid::read_bytes(entry.block.as_slice()).ok()?; + + let want_type = match entry.want_type { + 0 => WantType::Block, + 1 => WantType::Have, + _ => return None, + }; + + (cid.version() == cid::Version::V1 + && cid.hash().code() == u64::from(Code::Blake2b256) + && cid.hash().size() == 32) + .then_some((cid, want_type)) + }) + .collect::>(), + )) + })); + } + + /// Send response to bitswap request. + async fn on_outbound_substream( + &mut self, + peer: PeerId, + substream_id: SubstreamId, + mut substream: Substream, + ) { + let Some(entries) = self.pending_outbound.remove(&substream_id) else { + tracing::warn!(target: LOG_TARGET, ?peer, ?substream_id, "pending outbound entry doesn't exist"); + return; + }; + + let mut response = schema::bitswap::Message::default(); + + for entry in entries { + match entry { + ResponseType::Block { cid, block } => { + let prefix = Prefix { + version: cid.version(), + codec: cid.codec(), + multihash_type: cid.hash().code(), + multihash_len: cid.hash().size(), + } + .to_bytes(); + + response.payload.push(schema::bitswap::Block { + prefix, + data: block, + }); + } + ResponseType::Presence { cid, presence } => { + response.block_presences.push(schema::bitswap::BlockPresence { + cid: cid.to_bytes(), + r#type: presence as i32, + }); + } + } + } + + let _ = substream.send_framed(response.encode_to_vec().into()).await; + } + + /// Handle bitswap response. + fn on_bitswap_response(&mut self, peer: PeerId, responses: Vec) { + match self.service.open_substream(peer) { + Err(error) => { + tracing::debug!(target: LOG_TARGET, ?peer, ?error, "failed to open substream to peer") + } + Ok(substream_id) => { + self.pending_outbound.insert(substream_id, responses); + } + } + } + + /// Start [`Bitswap`] event loop. + pub async fn run(mut self) { + tracing::debug!(target: LOG_TARGET, "starting bitswap event loop"); + + loop { + tokio::select! { + event = self.service.next() => match event { + Some(TransportEvent::SubstreamOpened { + peer, + substream, + direction, + .. + }) => match direction { + Direction::Inbound => self.on_inbound_substream(peer, substream), + Direction::Outbound(substream_id) => + self.on_outbound_substream(peer, substream_id, substream).await, + }, + None => return, + event => tracing::trace!(target: LOG_TARGET, ?event, "unhandled event"), + }, + command = self.cmd_rx.recv() => match command { + Some(BitswapCommand::SendResponse { peer, responses }) => { + self.on_bitswap_response(peer, responses); + } + None => return, + }, + event = self.pending_inbound.next(), if !self.pending_inbound.is_empty() => { + if let Some(Ok((peer, cids))) = event { + let _ = self.event_tx.send(BitswapEvent::Request { peer, cids }).await; + } + } + } + } + } } diff --git a/src/protocol/libp2p/identify.rs b/src/protocol/libp2p/identify.rs index 986efe7d..dc02ed62 100644 --- a/src/protocol/libp2p/identify.rs +++ b/src/protocol/libp2p/identify.rs @@ -21,14 +21,14 @@ //! [`/ipfs/identify/1.0.0`](https://github.com/libp2p/specs/blob/master/identify/README.md) implementation. use crate::{ - codec::ProtocolCodec, - crypto::PublicKey, - error::{Error, SubstreamError}, - protocol::{Direction, TransportEvent, TransportService}, - substream::Substream, - transport::Endpoint, - types::{protocol::ProtocolName, SubstreamId}, - PeerId, DEFAULT_CHANNEL_SIZE, + codec::ProtocolCodec, + crypto::PublicKey, + error::{Error, SubstreamError}, + protocol::{Direction, TransportEvent, TransportService}, + substream::Substream, + transport::Endpoint, + types::{protocol::ProtocolName, SubstreamId}, + PeerId, DEFAULT_CHANNEL_SIZE, }; use futures::{future::BoxFuture, stream::FuturesUnordered, Stream, StreamExt}; @@ -38,8 +38,8 @@ use tokio::sync::mpsc::{channel, Sender}; use tokio_stream::wrappers::ReceiverStream; use std::{ - collections::{HashMap, HashSet}, - time::Duration, + collections::{HashMap, HashSet}, + time::Duration, }; /// Log target for the file. @@ -59,361 +59,361 @@ const DEFAULT_AGENT: &str = "litep2p/1.0.0"; const IDENTIFY_PAYLOAD_SIZE: usize = 4096; mod identify_schema { - include!(concat!(env!("OUT_DIR"), "/identify.rs")); + include!(concat!(env!("OUT_DIR"), "/identify.rs")); } /// Identify configuration. pub struct Config { - /// Protocol name. - pub(crate) protocol: ProtocolName, + /// Protocol name. + pub(crate) protocol: ProtocolName, - /// Codec used by the protocol. - pub(crate) codec: ProtocolCodec, + /// Codec used by the protocol. + pub(crate) codec: ProtocolCodec, - /// TX channel for sending events to the user protocol. - tx_event: Sender, + /// TX channel for sending events to the user protocol. + tx_event: Sender, - // Public key of the local node, filled by `Litep2p`. - pub(crate) public: Option, + // Public key of the local node, filled by `Litep2p`. + pub(crate) public: Option, - /// Protocols supported by the local node, filled by `Litep2p`. - pub(crate) protocols: Vec, + /// Protocols supported by the local node, filled by `Litep2p`. + pub(crate) protocols: Vec, - /// Public addresses. - pub(crate) public_addresses: Vec, + /// Public addresses. + pub(crate) public_addresses: Vec, - /// Protocol version. - pub(crate) protocol_version: String, + /// Protocol version. + pub(crate) protocol_version: String, - /// User agent. - pub(crate) user_agent: Option, + /// User agent. + pub(crate) user_agent: Option, } impl Config { - /// Create new [`Config`]. - /// - /// Returns a config that is given to `Litep2pConfig` and an event stream for - /// [`IdentifyEvent`]s. - pub fn new( - protocol_version: String, - user_agent: Option, - public_addresses: Vec, - ) -> (Self, Box + Send + Unpin>) { - let (tx_event, rx_event) = channel(DEFAULT_CHANNEL_SIZE); - - ( - Self { - tx_event, - public: None, - public_addresses, - protocol_version, - user_agent, - codec: ProtocolCodec::UnsignedVarint(Some(IDENTIFY_PAYLOAD_SIZE)), - protocols: Vec::new(), - protocol: ProtocolName::from(PROTOCOL_NAME), - }, - Box::new(ReceiverStream::new(rx_event)), - ) - } + /// Create new [`Config`]. + /// + /// Returns a config that is given to `Litep2pConfig` and an event stream for + /// [`IdentifyEvent`]s. + pub fn new( + protocol_version: String, + user_agent: Option, + public_addresses: Vec, + ) -> (Self, Box + Send + Unpin>) { + let (tx_event, rx_event) = channel(DEFAULT_CHANNEL_SIZE); + + ( + Self { + tx_event, + public: None, + public_addresses, + protocol_version, + user_agent, + codec: ProtocolCodec::UnsignedVarint(Some(IDENTIFY_PAYLOAD_SIZE)), + protocols: Vec::new(), + protocol: ProtocolName::from(PROTOCOL_NAME), + }, + Box::new(ReceiverStream::new(rx_event)), + ) + } } /// Events emitted by Identify protocol. #[derive(Debug)] pub enum IdentifyEvent { - /// Peer identified. - PeerIdentified { - /// Peer ID. - peer: PeerId, + /// Peer identified. + PeerIdentified { + /// Peer ID. + peer: PeerId, - /// Protocol version. - protocol_version: Option, + /// Protocol version. + protocol_version: Option, - /// User agent. - user_agent: Option, + /// User agent. + user_agent: Option, - /// Supported protocols. - supported_protocols: HashSet, + /// Supported protocols. + supported_protocols: HashSet, - /// Observed address. - observed_address: Multiaddr, + /// Observed address. + observed_address: Multiaddr, - /// Listen addresses. - listen_addresses: Vec, - }, + /// Listen addresses. + listen_addresses: Vec, + }, } /// Identify response received from remote. struct IdentifyResponse { - /// Remote peer ID. - peer: PeerId, + /// Remote peer ID. + peer: PeerId, - /// Protocol version. - protocol_version: Option, + /// Protocol version. + protocol_version: Option, - /// User agent. - user_agent: Option, + /// User agent. + user_agent: Option, - /// Protocols supported by remote. - supported_protocols: HashSet, + /// Protocols supported by remote. + supported_protocols: HashSet, - /// Remote's listen addresses. - listen_addresses: Vec, + /// Remote's listen addresses. + listen_addresses: Vec, - /// Observed address. - observed_address: Option, + /// Observed address. + observed_address: Option, } pub(crate) struct Identify { - // Connection service. - service: TransportService, + // Connection service. + service: TransportService, - /// TX channel for sending events to the user protocol. - tx: Sender, + /// TX channel for sending events to the user protocol. + tx: Sender, - /// Connected peers and their observed addresses. - peers: HashMap, + /// Connected peers and their observed addresses. + peers: HashMap, - // Public key of the local node, filled by `Litep2p`. - public: PublicKey, + // Public key of the local node, filled by `Litep2p`. + public: PublicKey, - /// Protocol version. - protocol_version: String, + /// Protocol version. + protocol_version: String, - /// User agent. - user_agent: String, + /// User agent. + user_agent: String, - /// Public addresses. - listen_addresses: HashSet, + /// Public addresses. + listen_addresses: HashSet, - /// Protocols supported by the local node, filled by `Litep2p`. - protocols: Vec, + /// Protocols supported by the local node, filled by `Litep2p`. + protocols: Vec, - /// Pending outbound substreams. - pending_opens: HashMap, + /// Pending outbound substreams. + pending_opens: HashMap, - /// Pending outbound substreams. - pending_outbound: FuturesUnordered>>, + /// Pending outbound substreams. + pending_outbound: FuturesUnordered>>, - /// Pending inbound substreams. - pending_inbound: FuturesUnordered>, + /// Pending inbound substreams. + pending_inbound: FuturesUnordered>, } impl Identify { - /// Create new [`Identify`] protocol. - pub(crate) fn new( - service: TransportService, - config: Config, - listen_addresses: Vec, - ) -> Self { - Self { - service, - tx: config.tx_event, - peers: HashMap::new(), - listen_addresses: config - .public_addresses - .into_iter() - .chain(listen_addresses.into_iter()) - .collect(), - public: config.public.expect("public key to be supplied"), - protocol_version: config.protocol_version, - user_agent: config.user_agent.unwrap_or(DEFAULT_AGENT.to_string()), - pending_opens: HashMap::new(), - pending_inbound: FuturesUnordered::new(), - pending_outbound: FuturesUnordered::new(), - protocols: config.protocols.iter().map(|protocol| protocol.to_string()).collect(), - } - } - - /// Connection established to remote peer. - fn on_connection_established(&mut self, peer: PeerId, endpoint: Endpoint) -> crate::Result<()> { - tracing::trace!(target: LOG_TARGET, ?peer, ?endpoint, "connection established"); - - let substream_id = self.service.open_substream(peer)?; - self.pending_opens.insert(substream_id, peer); - self.peers.insert(peer, endpoint); - - Ok(()) - } - - /// Connection closed to remote peer. - fn on_connection_closed(&mut self, peer: PeerId) { - tracing::trace!(target: LOG_TARGET, ?peer, "connection closed"); - - self.peers.remove(&peer); - } - - /// Inbound substream opened. - fn on_inbound_substream( - &mut self, - peer: PeerId, - protocol: ProtocolName, - mut substream: Substream, - ) { - tracing::trace!( - target: LOG_TARGET, - ?peer, - ?protocol, - "inbound substream opened" - ); - - let observed_addr = match self.peers.get(&peer) { - Some(endpoint) => Some(endpoint.address().to_vec()), - None => { - tracing::warn!( - target: LOG_TARGET, - ?peer, - %protocol, - "inbound identify substream opened for peer who doesn't exist", - ); - None - }, - }; - - let identify = identify_schema::Identify { - protocol_version: Some(self.protocol_version.clone()), - agent_version: Some(self.user_agent.clone()), - public_key: Some(self.public.to_protobuf_encoding()), - listen_addrs: self - .listen_addresses - .iter() - .map(|address| address.to_vec()) - .collect::>(), - observed_addr, - protocols: self.protocols.clone(), - }; - - tracing::trace!( - target: LOG_TARGET, - ?peer, - ?identify, - "sending identify response", - ); - - let mut msg = Vec::with_capacity(identify.encoded_len()); - identify.encode(&mut msg).expect("`msg` to have enough capacity"); - - self.pending_inbound.push(Box::pin(async move { - match tokio::time::timeout(Duration::from_secs(10), substream.send_framed(msg.into())) - .await - { - Err(error) => { - tracing::debug!( - target: LOG_TARGET, - ?peer, - ?error, - "timed out while sending ipfs identify response", - ); - }, - Ok(Err(error)) => { - tracing::debug!( - target: LOG_TARGET, - ?peer, - ?error, - "failed to send ipfs identify response", - ); - }, - Ok(_) => {}, - } - })) - } - - /// Outbound substream opened. - fn on_outbound_substream( - &mut self, - peer: PeerId, - protocol: ProtocolName, - substream_id: SubstreamId, - mut substream: Substream, - ) { - tracing::trace!( - target: LOG_TARGET, - ?peer, - ?protocol, - ?substream_id, - "outbound substream opened" - ); - - self.pending_outbound.push(Box::pin(async move { - let payload = - match tokio::time::timeout(Duration::from_secs(10), substream.next()).await { - Err(_) => return Err(Error::Timeout), - Ok(None) => - return Err(Error::SubstreamError(SubstreamError::ReadFailure(Some( - substream_id, - )))), - Ok(Some(Err(error))) => return Err(error), - Ok(Some(Ok(payload))) => payload, - }; - - let info = identify_schema::Identify::decode(payload.to_vec().as_slice())?; - - tracing::trace!(target: LOG_TARGET, ?peer, ?info, "peer identified"); - - let listen_addresses = info - .listen_addrs - .iter() - .filter_map(|address| Multiaddr::try_from(address.clone()).ok()) - .collect(); - let observed_address = - info.observed_addr.map(|address| Multiaddr::try_from(address).ok()).flatten(); - let protocol_version = info.protocol_version; - let user_agent = info.agent_version; - - Ok(IdentifyResponse { - peer, - protocol_version, - user_agent, - supported_protocols: HashSet::from_iter(info.protocols), - observed_address, - listen_addresses, - }) - })); - } - - /// Start [`Identify`] event loop. - pub async fn run(mut self) { - tracing::debug!(target: LOG_TARGET, "starting identify event loop"); - - loop { - tokio::select! { - event = self.service.next() => match event { - None => return, - Some(TransportEvent::ConnectionEstablished { peer, endpoint }) => { - let _ = self.on_connection_established(peer, endpoint); - } - Some(TransportEvent::ConnectionClosed { peer }) => { - self.on_connection_closed(peer); - } - Some(TransportEvent::SubstreamOpened { - peer, - protocol, - direction, - substream, - .. - }) => match direction { - Direction::Inbound => self.on_inbound_substream(peer, protocol, substream), - Direction::Outbound(substream_id) => self.on_outbound_substream(peer, protocol, substream_id, substream), - }, - _ => {} - }, - _ = self.pending_inbound.next(), if !self.pending_inbound.is_empty() => {} - event = self.pending_outbound.next(), if !self.pending_outbound.is_empty() => match event { - Some(Ok(response)) => { - let _ = self.tx - .send(IdentifyEvent::PeerIdentified { - peer: response.peer, - protocol_version: response.protocol_version, - user_agent: response.user_agent, - supported_protocols: response.supported_protocols.into_iter().map(From::from).collect(), - observed_address: response.observed_address.map_or(Multiaddr::empty(), |address| address), - listen_addresses: response.listen_addresses, - }) - .await; - } - Some(Err(error)) => tracing::debug!(target: LOG_TARGET, ?error, "failed to read ipfs identify response"), - None => return, - } - } - } - } + /// Create new [`Identify`] protocol. + pub(crate) fn new( + service: TransportService, + config: Config, + listen_addresses: Vec, + ) -> Self { + Self { + service, + tx: config.tx_event, + peers: HashMap::new(), + listen_addresses: config + .public_addresses + .into_iter() + .chain(listen_addresses.into_iter()) + .collect(), + public: config.public.expect("public key to be supplied"), + protocol_version: config.protocol_version, + user_agent: config.user_agent.unwrap_or(DEFAULT_AGENT.to_string()), + pending_opens: HashMap::new(), + pending_inbound: FuturesUnordered::new(), + pending_outbound: FuturesUnordered::new(), + protocols: config.protocols.iter().map(|protocol| protocol.to_string()).collect(), + } + } + + /// Connection established to remote peer. + fn on_connection_established(&mut self, peer: PeerId, endpoint: Endpoint) -> crate::Result<()> { + tracing::trace!(target: LOG_TARGET, ?peer, ?endpoint, "connection established"); + + let substream_id = self.service.open_substream(peer)?; + self.pending_opens.insert(substream_id, peer); + self.peers.insert(peer, endpoint); + + Ok(()) + } + + /// Connection closed to remote peer. + fn on_connection_closed(&mut self, peer: PeerId) { + tracing::trace!(target: LOG_TARGET, ?peer, "connection closed"); + + self.peers.remove(&peer); + } + + /// Inbound substream opened. + fn on_inbound_substream( + &mut self, + peer: PeerId, + protocol: ProtocolName, + mut substream: Substream, + ) { + tracing::trace!( + target: LOG_TARGET, + ?peer, + ?protocol, + "inbound substream opened" + ); + + let observed_addr = match self.peers.get(&peer) { + Some(endpoint) => Some(endpoint.address().to_vec()), + None => { + tracing::warn!( + target: LOG_TARGET, + ?peer, + %protocol, + "inbound identify substream opened for peer who doesn't exist", + ); + None + } + }; + + let identify = identify_schema::Identify { + protocol_version: Some(self.protocol_version.clone()), + agent_version: Some(self.user_agent.clone()), + public_key: Some(self.public.to_protobuf_encoding()), + listen_addrs: self + .listen_addresses + .iter() + .map(|address| address.to_vec()) + .collect::>(), + observed_addr, + protocols: self.protocols.clone(), + }; + + tracing::trace!( + target: LOG_TARGET, + ?peer, + ?identify, + "sending identify response", + ); + + let mut msg = Vec::with_capacity(identify.encoded_len()); + identify.encode(&mut msg).expect("`msg` to have enough capacity"); + + self.pending_inbound.push(Box::pin(async move { + match tokio::time::timeout(Duration::from_secs(10), substream.send_framed(msg.into())) + .await + { + Err(error) => { + tracing::debug!( + target: LOG_TARGET, + ?peer, + ?error, + "timed out while sending ipfs identify response", + ); + } + Ok(Err(error)) => { + tracing::debug!( + target: LOG_TARGET, + ?peer, + ?error, + "failed to send ipfs identify response", + ); + } + Ok(_) => {} + } + })) + } + + /// Outbound substream opened. + fn on_outbound_substream( + &mut self, + peer: PeerId, + protocol: ProtocolName, + substream_id: SubstreamId, + mut substream: Substream, + ) { + tracing::trace!( + target: LOG_TARGET, + ?peer, + ?protocol, + ?substream_id, + "outbound substream opened" + ); + + self.pending_outbound.push(Box::pin(async move { + let payload = + match tokio::time::timeout(Duration::from_secs(10), substream.next()).await { + Err(_) => return Err(Error::Timeout), + Ok(None) => + return Err(Error::SubstreamError(SubstreamError::ReadFailure(Some( + substream_id, + )))), + Ok(Some(Err(error))) => return Err(error), + Ok(Some(Ok(payload))) => payload, + }; + + let info = identify_schema::Identify::decode(payload.to_vec().as_slice())?; + + tracing::trace!(target: LOG_TARGET, ?peer, ?info, "peer identified"); + + let listen_addresses = info + .listen_addrs + .iter() + .filter_map(|address| Multiaddr::try_from(address.clone()).ok()) + .collect(); + let observed_address = + info.observed_addr.map(|address| Multiaddr::try_from(address).ok()).flatten(); + let protocol_version = info.protocol_version; + let user_agent = info.agent_version; + + Ok(IdentifyResponse { + peer, + protocol_version, + user_agent, + supported_protocols: HashSet::from_iter(info.protocols), + observed_address, + listen_addresses, + }) + })); + } + + /// Start [`Identify`] event loop. + pub async fn run(mut self) { + tracing::debug!(target: LOG_TARGET, "starting identify event loop"); + + loop { + tokio::select! { + event = self.service.next() => match event { + None => return, + Some(TransportEvent::ConnectionEstablished { peer, endpoint }) => { + let _ = self.on_connection_established(peer, endpoint); + } + Some(TransportEvent::ConnectionClosed { peer }) => { + self.on_connection_closed(peer); + } + Some(TransportEvent::SubstreamOpened { + peer, + protocol, + direction, + substream, + .. + }) => match direction { + Direction::Inbound => self.on_inbound_substream(peer, protocol, substream), + Direction::Outbound(substream_id) => self.on_outbound_substream(peer, protocol, substream_id, substream), + }, + _ => {} + }, + _ = self.pending_inbound.next(), if !self.pending_inbound.is_empty() => {} + event = self.pending_outbound.next(), if !self.pending_outbound.is_empty() => match event { + Some(Ok(response)) => { + let _ = self.tx + .send(IdentifyEvent::PeerIdentified { + peer: response.peer, + protocol_version: response.protocol_version, + user_agent: response.user_agent, + supported_protocols: response.supported_protocols.into_iter().map(From::from).collect(), + observed_address: response.observed_address.map_or(Multiaddr::empty(), |address| address), + listen_addresses: response.listen_addresses, + }) + .await; + } + Some(Err(error)) => tracing::debug!(target: LOG_TARGET, ?error, "failed to read ipfs identify response"), + None => return, + } + } + } + } } diff --git a/src/protocol/libp2p/kademlia/bucket.rs b/src/protocol/libp2p/kademlia/bucket.rs index ed55e7c6..88fd9266 100644 --- a/src/protocol/libp2p/kademlia/bucket.rs +++ b/src/protocol/libp2p/kademlia/bucket.rs @@ -22,163 +22,169 @@ //! Kademlia k-bucket implementation. use crate::{ - protocol::libp2p::kademlia::types::{ConnectionType, KademliaPeer, Key}, - PeerId, + protocol::libp2p::kademlia::types::{ConnectionType, KademliaPeer, Key}, + PeerId, }; /// K-bucket entry. #[derive(Debug, PartialEq, Eq)] pub enum KBucketEntry<'a> { - /// Entry points to local node. - LocalNode, + /// Entry points to local node. + LocalNode, - /// Occupied entry to a connected node. - Occupied(&'a mut KademliaPeer), + /// Occupied entry to a connected node. + Occupied(&'a mut KademliaPeer), - /// Vacant entry. - Vacant(&'a mut KademliaPeer), + /// Vacant entry. + Vacant(&'a mut KademliaPeer), - /// Entry not found and any present entry cannot be replaced. - NoSlot, + /// Entry not found and any present entry cannot be replaced. + NoSlot, } impl<'a> KBucketEntry<'a> { - /// Insert new entry into the entry if possible. - pub fn insert(&'a mut self, new: KademliaPeer) { - if let KBucketEntry::Vacant(old) = self { - old.peer = new.peer; - old.key = Key::from(new.peer); - old.addresses = new.addresses; - old.connection = new.connection; - } - } + /// Insert new entry into the entry if possible. + pub fn insert(&'a mut self, new: KademliaPeer) { + if let KBucketEntry::Vacant(old) = self { + old.peer = new.peer; + old.key = Key::from(new.peer); + old.addresses = new.addresses; + old.connection = new.connection; + } + } } /// Kademlia k-bucket. pub struct KBucket { - // TODO: store peers in a btreemap with increasing distance from local key? - nodes: Vec, + // TODO: store peers in a btreemap with increasing distance from local key? + nodes: Vec, } impl KBucket { - /// Create new [`KBucket`]. - pub fn new() -> Self { - Self { nodes: Vec::with_capacity(20) } - } - - /// Get entry into the bucket. - // TODO: this is horrible code - pub fn entry<'a, K: Clone>(&'a mut self, key: Key) -> KBucketEntry<'a> { - for i in 0..self.nodes.len() { - if &self.nodes[i].key == &key { - return KBucketEntry::Occupied(&mut self.nodes[i]); - } - } - - if self.nodes.len() < 20 { - self.nodes.push(KademliaPeer::new( - PeerId::random(), - vec![], - ConnectionType::NotConnected, - )); - let len = self.nodes.len() - 1; - return KBucketEntry::Vacant(&mut self.nodes[len]); - } - - for i in 0..self.nodes.len() { - match self.nodes[i].connection { - ConnectionType::NotConnected | ConnectionType::CannotConnect => { - return KBucketEntry::Vacant(&mut self.nodes[i]); - }, - _ => continue, - } - } - - KBucketEntry::NoSlot - } - - /// Get iterator over the k-bucket, sorting the k-bucket entries in increasing order - /// by distance. - pub fn closest_iter(&self, target: &Key) -> impl Iterator { - let mut nodes = self.nodes.clone(); - nodes.sort_by(|a, b| target.distance(&a.key).cmp(&target.distance(&b.key))); - nodes.into_iter().filter(|peer| !peer.addresses.is_empty()) - } + /// Create new [`KBucket`]. + pub fn new() -> Self { + Self { + nodes: Vec::with_capacity(20), + } + } + + /// Get entry into the bucket. + // TODO: this is horrible code + pub fn entry<'a, K: Clone>(&'a mut self, key: Key) -> KBucketEntry<'a> { + for i in 0..self.nodes.len() { + if &self.nodes[i].key == &key { + return KBucketEntry::Occupied(&mut self.nodes[i]); + } + } + + if self.nodes.len() < 20 { + self.nodes.push(KademliaPeer::new( + PeerId::random(), + vec![], + ConnectionType::NotConnected, + )); + let len = self.nodes.len() - 1; + return KBucketEntry::Vacant(&mut self.nodes[len]); + } + + for i in 0..self.nodes.len() { + match self.nodes[i].connection { + ConnectionType::NotConnected | ConnectionType::CannotConnect => { + return KBucketEntry::Vacant(&mut self.nodes[i]); + } + _ => continue, + } + } + + KBucketEntry::NoSlot + } + + /// Get iterator over the k-bucket, sorting the k-bucket entries in increasing order + /// by distance. + pub fn closest_iter(&self, target: &Key) -> impl Iterator { + let mut nodes = self.nodes.clone(); + nodes.sort_by(|a, b| target.distance(&a.key).cmp(&target.distance(&b.key))); + nodes.into_iter().filter(|peer| !peer.addresses.is_empty()) + } } #[cfg(test)] mod tests { - use super::*; - - #[test] - fn closest_iter() { - let mut bucket = KBucket::new(); - - // add some random nodes to the bucket - let _ = (0..10) - .map(|_| { - let peer = PeerId::random(); - bucket.nodes.push(KademliaPeer::new(peer, vec![], ConnectionType::Connected)); - - peer - }) - .collect::>(); - - let target = Key::from(PeerId::random()); - let mut iter = bucket.closest_iter(&target); - let mut prev = None; - - while let Some(node) = iter.next() { - if let Some(distance) = prev { - assert!(distance < target.distance(&node.key)); - } - - prev = Some(target.distance(&node.key)); - } - } - - #[test] - fn ignore_peers_with_no_addresses() { - let mut bucket = KBucket::new(); - - // add peers with no addresses to the bucket - let _ = (0..10) - .map(|_| { - let peer = PeerId::random(); - bucket.nodes.push(KademliaPeer::new(peer, vec![], ConnectionType::NotConnected)); - - peer - }) - .collect::>(); - - // add three peers with an address - let _ = (0..3) - .map(|_| { - let peer = PeerId::random(); - bucket.nodes.push(KademliaPeer::new( - peer, - vec!["/ip6/::/tcp/0".parse().unwrap()], - ConnectionType::Connected, - )); - - peer - }) - .collect::>(); - - let target = Key::from(PeerId::random()); - let mut iter = bucket.closest_iter(&target); - let mut prev = None; - let mut num_peers = 0usize; - - while let Some(node) = iter.next() { - if let Some(distance) = prev { - assert!(distance < target.distance(&node.key)); - } - - num_peers += 1; - prev = Some(target.distance(&node.key)); - } - - assert_eq!(num_peers, 3usize); - } + use super::*; + + #[test] + fn closest_iter() { + let mut bucket = KBucket::new(); + + // add some random nodes to the bucket + let _ = (0..10) + .map(|_| { + let peer = PeerId::random(); + bucket.nodes.push(KademliaPeer::new(peer, vec![], ConnectionType::Connected)); + + peer + }) + .collect::>(); + + let target = Key::from(PeerId::random()); + let mut iter = bucket.closest_iter(&target); + let mut prev = None; + + while let Some(node) = iter.next() { + if let Some(distance) = prev { + assert!(distance < target.distance(&node.key)); + } + + prev = Some(target.distance(&node.key)); + } + } + + #[test] + fn ignore_peers_with_no_addresses() { + let mut bucket = KBucket::new(); + + // add peers with no addresses to the bucket + let _ = (0..10) + .map(|_| { + let peer = PeerId::random(); + bucket.nodes.push(KademliaPeer::new( + peer, + vec![], + ConnectionType::NotConnected, + )); + + peer + }) + .collect::>(); + + // add three peers with an address + let _ = (0..3) + .map(|_| { + let peer = PeerId::random(); + bucket.nodes.push(KademliaPeer::new( + peer, + vec!["/ip6/::/tcp/0".parse().unwrap()], + ConnectionType::Connected, + )); + + peer + }) + .collect::>(); + + let target = Key::from(PeerId::random()); + let mut iter = bucket.closest_iter(&target); + let mut prev = None; + let mut num_peers = 0usize; + + while let Some(node) = iter.next() { + if let Some(distance) = prev { + assert!(distance < target.distance(&node.key)); + } + + num_peers += 1; + prev = Some(target.distance(&node.key)); + } + + assert_eq!(num_peers, 3usize); + } } diff --git a/src/protocol/libp2p/kademlia/config.rs b/src/protocol/libp2p/kademlia/config.rs index df1b6095..0b7ca3d8 100644 --- a/src/protocol/libp2p/kademlia/config.rs +++ b/src/protocol/libp2p/kademlia/config.rs @@ -19,12 +19,12 @@ // DEALINGS IN THE SOFTWARE. use crate::{ - codec::ProtocolCodec, - protocol::libp2p::kademlia::handle::{ - KademliaCommand, KademliaEvent, KademliaHandle, RoutingTableUpdateMode, - }, - types::protocol::ProtocolName, - PeerId, DEFAULT_CHANNEL_SIZE, + codec::ProtocolCodec, + protocol::libp2p::kademlia::handle::{ + KademliaCommand, KademliaEvent, KademliaHandle, RoutingTableUpdateMode, + }, + types::protocol::ProtocolName, + PeerId, DEFAULT_CHANNEL_SIZE, }; use multiaddr::Multiaddr; @@ -41,132 +41,137 @@ const REPLICATION_FACTOR: usize = 20usize; /// Kademlia configuration. #[derive(Debug)] pub struct Config { - // Protocol name. - // pub(crate) protocol: ProtocolName, - /// Protocol names. - pub(crate) protocol_names: Vec, + // Protocol name. + // pub(crate) protocol: ProtocolName, + /// Protocol names. + pub(crate) protocol_names: Vec, - /// Protocol codec. - pub(crate) codec: ProtocolCodec, + /// Protocol codec. + pub(crate) codec: ProtocolCodec, - /// Replication factor. - #[allow(unused)] - pub(super) replication_factor: usize, + /// Replication factor. + #[allow(unused)] + pub(super) replication_factor: usize, - /// Known peers. - pub(super) known_peers: HashMap>, + /// Known peers. + pub(super) known_peers: HashMap>, - /// Routing table update mode. - pub(super) update_mode: RoutingTableUpdateMode, + /// Routing table update mode. + pub(super) update_mode: RoutingTableUpdateMode, - /// TX channel for sending events to `KademliaHandle`. - pub(super) event_tx: Sender, + /// TX channel for sending events to `KademliaHandle`. + pub(super) event_tx: Sender, - /// RX channel for receiving commands from `KademliaHandle`. - pub(super) cmd_rx: Receiver, + /// RX channel for receiving commands from `KademliaHandle`. + pub(super) cmd_rx: Receiver, } impl Config { - fn new( - replication_factor: usize, - known_peers: HashMap>, - mut protocol_names: Vec, - update_mode: RoutingTableUpdateMode, - ) -> (Self, KademliaHandle) { - let (cmd_tx, cmd_rx) = channel(DEFAULT_CHANNEL_SIZE); - let (event_tx, event_rx) = channel(DEFAULT_CHANNEL_SIZE); - - // if no protocol names were provided, use the default protocol - if protocol_names.is_empty() { - protocol_names.push(ProtocolName::from(PROTOCOL_NAME)); - } - - ( - Config { - protocol_names, - update_mode, - codec: ProtocolCodec::UnsignedVarint(None), - replication_factor, - known_peers, - cmd_rx, - event_tx, - }, - KademliaHandle::new(cmd_tx, event_rx), - ) - } - - /// Build default Kademlia configuration. - pub fn default() -> (Self, KademliaHandle) { - Self::new(REPLICATION_FACTOR, HashMap::new(), Vec::new(), RoutingTableUpdateMode::Automatic) - } + fn new( + replication_factor: usize, + known_peers: HashMap>, + mut protocol_names: Vec, + update_mode: RoutingTableUpdateMode, + ) -> (Self, KademliaHandle) { + let (cmd_tx, cmd_rx) = channel(DEFAULT_CHANNEL_SIZE); + let (event_tx, event_rx) = channel(DEFAULT_CHANNEL_SIZE); + + // if no protocol names were provided, use the default protocol + if protocol_names.is_empty() { + protocol_names.push(ProtocolName::from(PROTOCOL_NAME)); + } + + ( + Config { + protocol_names, + update_mode, + codec: ProtocolCodec::UnsignedVarint(None), + replication_factor, + known_peers, + cmd_rx, + event_tx, + }, + KademliaHandle::new(cmd_tx, event_rx), + ) + } + + /// Build default Kademlia configuration. + pub fn default() -> (Self, KademliaHandle) { + Self::new( + REPLICATION_FACTOR, + HashMap::new(), + Vec::new(), + RoutingTableUpdateMode::Automatic, + ) + } } /// Configuration builder for Kademlia. #[derive(Debug)] pub struct ConfigBuilder { - /// Replication factor. - pub(super) replication_factor: usize, + /// Replication factor. + pub(super) replication_factor: usize, - /// Routing table update mode. - pub(super) update_mode: RoutingTableUpdateMode, + /// Routing table update mode. + pub(super) update_mode: RoutingTableUpdateMode, - /// Known peers. - pub(super) known_peers: HashMap>, + /// Known peers. + pub(super) known_peers: HashMap>, - /// Protocol names. - pub(super) protocol_names: Vec, + /// Protocol names. + pub(super) protocol_names: Vec, } impl ConfigBuilder { - /// Create new [`ConfigBuilder`]. - pub fn new() -> Self { - Self { - replication_factor: REPLICATION_FACTOR, - known_peers: HashMap::new(), - protocol_names: Vec::new(), - update_mode: RoutingTableUpdateMode::Automatic, - } - } - - /// Set replication factor. - pub fn with_replication_factor(mut self, replication_factor: usize) -> Self { - self.replication_factor = replication_factor; - self - } - - /// Seed Kademlia with one or more known peers. - pub fn with_known_peers(mut self, peers: HashMap>) -> Self { - self.known_peers = peers; - self - } - - /// Set routing table update mode. - pub fn with_routing_table_update_mode(mut self, mode: RoutingTableUpdateMode) -> Self { - self.update_mode = mode; - self - } - - /// Set Kademlia protocol names, overriding the default protocol name. - /// - /// The order of the protocol names signifies preference so if, for example, there are two - /// protocols: - /// * `/kad/2.0.0` - /// * `/kad/1.0.0` - /// - /// Where `/kad/2.0.0` is the preferred version, then that should be in `protocol_names` before - /// `/kad/1.0.0`. - pub fn with_protocol_names(mut self, protocol_names: Vec) -> Self { - self.protocol_names = protocol_names; - self - } - - /// Build Kademlia [`Config`]. - pub fn build(self) -> (Config, KademliaHandle) { - Config::new( - self.replication_factor, - self.known_peers, - self.protocol_names, - self.update_mode, - ) - } + /// Create new [`ConfigBuilder`]. + pub fn new() -> Self { + Self { + replication_factor: REPLICATION_FACTOR, + known_peers: HashMap::new(), + protocol_names: Vec::new(), + update_mode: RoutingTableUpdateMode::Automatic, + } + } + + /// Set replication factor. + pub fn with_replication_factor(mut self, replication_factor: usize) -> Self { + self.replication_factor = replication_factor; + self + } + + /// Seed Kademlia with one or more known peers. + pub fn with_known_peers(mut self, peers: HashMap>) -> Self { + self.known_peers = peers; + self + } + + /// Set routing table update mode. + pub fn with_routing_table_update_mode(mut self, mode: RoutingTableUpdateMode) -> Self { + self.update_mode = mode; + self + } + + /// Set Kademlia protocol names, overriding the default protocol name. + /// + /// The order of the protocol names signifies preference so if, for example, there are two + /// protocols: + /// * `/kad/2.0.0` + /// * `/kad/1.0.0` + /// + /// Where `/kad/2.0.0` is the preferred version, then that should be in `protocol_names` before + /// `/kad/1.0.0`. + pub fn with_protocol_names(mut self, protocol_names: Vec) -> Self { + self.protocol_names = protocol_names; + self + } + + /// Build Kademlia [`Config`]. + pub fn build(self) -> (Config, KademliaHandle) { + Config::new( + self.replication_factor, + self.known_peers, + self.protocol_names, + self.update_mode, + ) + } } diff --git a/src/protocol/libp2p/kademlia/executor.rs b/src/protocol/libp2p/kademlia/executor.rs index 8c51cb96..02701c76 100644 --- a/src/protocol/libp2p/kademlia/executor.rs +++ b/src/protocol/libp2p/kademlia/executor.rs @@ -24,9 +24,9 @@ use bytes::{Bytes, BytesMut}; use futures::{future::BoxFuture, stream::FuturesUnordered, Stream, StreamExt}; use std::{ - pin::Pin, - task::{Context, Poll}, - time::Duration, + pin::Pin, + task::{Context, Poll}, + time::Duration, }; /// Read timeout for inbound messages. @@ -35,298 +35,346 @@ const READ_TIMEOUT: Duration = Duration::from_secs(15); /// Query result. #[derive(Debug)] pub enum QueryResult { - /// Message was sent to remote peer successfully. - SendSuccess { - /// Substream. - substream: Substream, - }, - - /// Message was read from the remote peer successfully. - ReadSuccess { - /// Substream. - substream: Substream, - - /// Read message. - message: BytesMut, - }, - - /// Timeout while reading a response from the substream. - Timeout, - - /// Substream was closed wile reading/writing message to remote peer. - SubstreamClosed, + /// Message was sent to remote peer successfully. + SendSuccess { + /// Substream. + substream: Substream, + }, + + /// Message was read from the remote peer successfully. + ReadSuccess { + /// Substream. + substream: Substream, + + /// Read message. + message: BytesMut, + }, + + /// Timeout while reading a response from the substream. + Timeout, + + /// Substream was closed wile reading/writing message to remote peer. + SubstreamClosed, } /// Query result. #[derive(Debug)] pub struct QueryContext { - /// Peer ID. - pub peer: PeerId, + /// Peer ID. + pub peer: PeerId, - /// Query ID. - pub query_id: Option, + /// Query ID. + pub query_id: Option, - /// Query result. - pub result: QueryResult, + /// Query result. + pub result: QueryResult, } /// Query executor. pub struct QueryExecutor { - /// Pending futures. - futures: FuturesUnordered>, + /// Pending futures. + futures: FuturesUnordered>, } impl QueryExecutor { - /// Create new [`QueryExecutor`] - pub fn new() -> Self { - Self { futures: FuturesUnordered::new() } - } - - /// Send message to remote peer. - pub fn send_message(&mut self, peer: PeerId, message: Bytes, mut substream: Substream) { - self.futures.push(Box::pin(async move { - match substream.send_framed(message).await { - Ok(_) => - return QueryContext { - peer, - query_id: None, - result: QueryResult::SendSuccess { substream }, - }, - Err(_) => - return QueryContext { - peer, - query_id: None, - result: QueryResult::SubstreamClosed, - }, - } - })); - } - - /// Read message from remote peer with timeout. - pub fn read_message( - &mut self, - peer: PeerId, - query_id: Option, - mut substream: Substream, - ) { - self.futures.push(Box::pin(async move { - match tokio::time::timeout(READ_TIMEOUT, substream.next()).await { - Err(_) => return QueryContext { peer, query_id, result: QueryResult::Timeout }, - Ok(Some(Ok(message))) => - return QueryContext { - peer, - query_id, - result: QueryResult::ReadSuccess { substream, message }, - }, - Ok(None) | Ok(Some(Err(_))) => - return QueryContext { peer, query_id, result: QueryResult::SubstreamClosed }, - } - })); - } - - /// Send request to remote peer and read response. - pub fn send_request_read_response( - &mut self, - peer: PeerId, - query_id: Option, - message: Bytes, - mut substream: Substream, - ) { - self.futures.push(Box::pin(async move { - if let Err(_) = substream.send_framed(message).await { - let _ = substream.close().await; - return QueryContext { peer, query_id, result: QueryResult::SubstreamClosed }; - } - - match tokio::time::timeout(READ_TIMEOUT, substream.next()).await { - Err(_) => return QueryContext { peer, query_id, result: QueryResult::Timeout }, - Ok(Some(Ok(message))) => - return QueryContext { - peer, - query_id, - result: QueryResult::ReadSuccess { substream, message }, - }, - Ok(None) | Ok(Some(Err(_))) => - return QueryContext { peer, query_id, result: QueryResult::SubstreamClosed }, - } - })); - } + /// Create new [`QueryExecutor`] + pub fn new() -> Self { + Self { + futures: FuturesUnordered::new(), + } + } + + /// Send message to remote peer. + pub fn send_message(&mut self, peer: PeerId, message: Bytes, mut substream: Substream) { + self.futures.push(Box::pin(async move { + match substream.send_framed(message).await { + Ok(_) => + return QueryContext { + peer, + query_id: None, + result: QueryResult::SendSuccess { substream }, + }, + Err(_) => + return QueryContext { + peer, + query_id: None, + result: QueryResult::SubstreamClosed, + }, + } + })); + } + + /// Read message from remote peer with timeout. + pub fn read_message( + &mut self, + peer: PeerId, + query_id: Option, + mut substream: Substream, + ) { + self.futures.push(Box::pin(async move { + match tokio::time::timeout(READ_TIMEOUT, substream.next()).await { + Err(_) => + return QueryContext { + peer, + query_id, + result: QueryResult::Timeout, + }, + Ok(Some(Ok(message))) => + return QueryContext { + peer, + query_id, + result: QueryResult::ReadSuccess { substream, message }, + }, + Ok(None) | Ok(Some(Err(_))) => + return QueryContext { + peer, + query_id, + result: QueryResult::SubstreamClosed, + }, + } + })); + } + + /// Send request to remote peer and read response. + pub fn send_request_read_response( + &mut self, + peer: PeerId, + query_id: Option, + message: Bytes, + mut substream: Substream, + ) { + self.futures.push(Box::pin(async move { + if let Err(_) = substream.send_framed(message).await { + let _ = substream.close().await; + return QueryContext { + peer, + query_id, + result: QueryResult::SubstreamClosed, + }; + } + + match tokio::time::timeout(READ_TIMEOUT, substream.next()).await { + Err(_) => + return QueryContext { + peer, + query_id, + result: QueryResult::Timeout, + }, + Ok(Some(Ok(message))) => + return QueryContext { + peer, + query_id, + result: QueryResult::ReadSuccess { substream, message }, + }, + Ok(None) | Ok(Some(Err(_))) => + return QueryContext { + peer, + query_id, + result: QueryResult::SubstreamClosed, + }, + } + })); + } } impl Stream for QueryExecutor { - type Item = QueryContext; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match self.futures.is_empty() { - true => Poll::Pending, - false => self.futures.poll_next_unpin(cx), - } - } + type Item = QueryContext; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.futures.is_empty() { + true => Poll::Pending, + false => self.futures.poll_next_unpin(cx), + } + } } #[cfg(test)] mod tests { - use super::*; - use crate::{mock::substream::MockSubstream, types::SubstreamId}; - - #[tokio::test] - async fn substream_read_timeout() { - let mut executor = QueryExecutor::new(); - let peer = PeerId::random(); - let mut substream = MockSubstream::new(); - substream.expect_poll_next().returning(|_| Poll::Pending); - let substream = Substream::new_mock(peer, SubstreamId::from(0usize), Box::new(substream)); - - executor.read_message(peer, None, substream); - - match tokio::time::timeout(Duration::from_secs(20), executor.next()).await { - Ok(Some(QueryContext { peer: queried_peer, query_id, result })) => { - assert_eq!(peer, queried_peer); - assert!(query_id.is_none()); - assert!(std::matches!(result, QueryResult::Timeout)); - }, - result => panic!("invalid result received: {result:?}"), - } - } - - #[tokio::test] - async fn substream_read_substream_closed() { - let mut executor = QueryExecutor::new(); - let peer = PeerId::random(); - let mut substream = MockSubstream::new(); - substream - .expect_poll_next() - .times(1) - .return_once(|_| Poll::Ready(Some(Err(crate::Error::Unknown)))); - - executor.read_message( - peer, - Some(QueryId(1338)), - Substream::new_mock(peer, SubstreamId::from(0usize), Box::new(substream)), - ); - - match tokio::time::timeout(Duration::from_secs(20), executor.next()).await { - Ok(Some(QueryContext { peer: queried_peer, query_id, result })) => { - assert_eq!(peer, queried_peer); - assert_eq!(query_id, Some(QueryId(1338))); - assert!(std::matches!(result, QueryResult::SubstreamClosed)); - }, - result => panic!("invalid result received: {result:?}"), - } - } - - #[tokio::test] - async fn send_succeeds_no_message_read() { - let mut executor = QueryExecutor::new(); - let peer = PeerId::random(); - - // prepare substream which succeeds in sending the message but closes right after - let mut substream = MockSubstream::new(); - substream.expect_poll_ready().times(1).return_once(|_| Poll::Ready(Ok(()))); - substream.expect_start_send().times(1).return_once(|_| Ok(())); - substream.expect_poll_flush().times(1).return_once(|_| Poll::Ready(Ok(()))); - substream - .expect_poll_next() - .times(1) - .return_once(|_| Poll::Ready(Some(Err(crate::Error::Unknown)))); - - executor.send_request_read_response( - peer, - Some(QueryId(1337)), - Bytes::from_static(b"hello, world"), - Substream::new_mock(peer, SubstreamId::from(0usize), Box::new(substream)), - ); - - match tokio::time::timeout(Duration::from_secs(20), executor.next()).await { - Ok(Some(QueryContext { peer: queried_peer, query_id, result })) => { - assert_eq!(peer, queried_peer); - assert_eq!(query_id, Some(QueryId(1337))); - assert!(std::matches!(result, QueryResult::SubstreamClosed)); - }, - result => panic!("invalid result received: {result:?}"), - } - } - - #[tokio::test] - async fn send_fails_no_message_read() { - let mut executor = QueryExecutor::new(); - let peer = PeerId::random(); - - // prepare substream which succeeds in sending the message but closes right after - let mut substream = MockSubstream::new(); - substream - .expect_poll_ready() - .times(1) - .return_once(|_| Poll::Ready(Err(crate::Error::Unknown))); - substream.expect_poll_close().times(1).return_once(|_| Poll::Ready(Ok(()))); - - executor.send_request_read_response( - peer, - Some(QueryId(1337)), - Bytes::from_static(b"hello, world"), - Substream::new_mock(peer, SubstreamId::from(0usize), Box::new(substream)), - ); - - match tokio::time::timeout(Duration::from_secs(20), executor.next()).await { - Ok(Some(QueryContext { peer: queried_peer, query_id, result })) => { - assert_eq!(peer, queried_peer); - assert_eq!(query_id, Some(QueryId(1337))); - assert!(std::matches!(result, QueryResult::SubstreamClosed)); - }, - result => panic!("invalid result received: {result:?}"), - } - } - - #[tokio::test] - async fn read_message_timeout() { - let mut executor = QueryExecutor::new(); - let peer = PeerId::random(); - - // prepare substream which succeeds in sending the message but closes right after - let mut substream = MockSubstream::new(); - substream.expect_poll_next().returning(|_| Poll::Pending); - - executor.read_message( - peer, - Some(QueryId(1336)), - Substream::new_mock(peer, SubstreamId::from(0usize), Box::new(substream)), - ); - - match tokio::time::timeout(Duration::from_secs(20), executor.next()).await { - Ok(Some(QueryContext { peer: queried_peer, query_id, result })) => { - assert_eq!(peer, queried_peer); - assert_eq!(query_id, Some(QueryId(1336))); - assert!(std::matches!(result, QueryResult::Timeout)); - }, - result => panic!("invalid result received: {result:?}"), - } - } - - #[tokio::test] - async fn read_message_substream_closed() { - let mut executor = QueryExecutor::new(); - let peer = PeerId::random(); - - // prepare substream which succeeds in sending the message but closes right after - let mut substream = MockSubstream::new(); - substream - .expect_poll_next() - .times(1) - .return_once(|_| Poll::Ready(Some(Err(crate::Error::Unknown)))); - - executor.read_message( - peer, - Some(QueryId(1335)), - Substream::new_mock(peer, SubstreamId::from(0usize), Box::new(substream)), - ); - - match tokio::time::timeout(Duration::from_secs(20), executor.next()).await { - Ok(Some(QueryContext { peer: queried_peer, query_id, result })) => { - assert_eq!(peer, queried_peer); - assert_eq!(query_id, Some(QueryId(1335))); - assert!(std::matches!(result, QueryResult::SubstreamClosed)); - }, - result => panic!("invalid result received: {result:?}"), - } - } + use super::*; + use crate::{mock::substream::MockSubstream, types::SubstreamId}; + + #[tokio::test] + async fn substream_read_timeout() { + let mut executor = QueryExecutor::new(); + let peer = PeerId::random(); + let mut substream = MockSubstream::new(); + substream.expect_poll_next().returning(|_| Poll::Pending); + let substream = Substream::new_mock(peer, SubstreamId::from(0usize), Box::new(substream)); + + executor.read_message(peer, None, substream); + + match tokio::time::timeout(Duration::from_secs(20), executor.next()).await { + Ok(Some(QueryContext { + peer: queried_peer, + query_id, + result, + })) => { + assert_eq!(peer, queried_peer); + assert!(query_id.is_none()); + assert!(std::matches!(result, QueryResult::Timeout)); + } + result => panic!("invalid result received: {result:?}"), + } + } + + #[tokio::test] + async fn substream_read_substream_closed() { + let mut executor = QueryExecutor::new(); + let peer = PeerId::random(); + let mut substream = MockSubstream::new(); + substream + .expect_poll_next() + .times(1) + .return_once(|_| Poll::Ready(Some(Err(crate::Error::Unknown)))); + + executor.read_message( + peer, + Some(QueryId(1338)), + Substream::new_mock(peer, SubstreamId::from(0usize), Box::new(substream)), + ); + + match tokio::time::timeout(Duration::from_secs(20), executor.next()).await { + Ok(Some(QueryContext { + peer: queried_peer, + query_id, + result, + })) => { + assert_eq!(peer, queried_peer); + assert_eq!(query_id, Some(QueryId(1338))); + assert!(std::matches!(result, QueryResult::SubstreamClosed)); + } + result => panic!("invalid result received: {result:?}"), + } + } + + #[tokio::test] + async fn send_succeeds_no_message_read() { + let mut executor = QueryExecutor::new(); + let peer = PeerId::random(); + + // prepare substream which succeeds in sending the message but closes right after + let mut substream = MockSubstream::new(); + substream.expect_poll_ready().times(1).return_once(|_| Poll::Ready(Ok(()))); + substream.expect_start_send().times(1).return_once(|_| Ok(())); + substream.expect_poll_flush().times(1).return_once(|_| Poll::Ready(Ok(()))); + substream + .expect_poll_next() + .times(1) + .return_once(|_| Poll::Ready(Some(Err(crate::Error::Unknown)))); + + executor.send_request_read_response( + peer, + Some(QueryId(1337)), + Bytes::from_static(b"hello, world"), + Substream::new_mock(peer, SubstreamId::from(0usize), Box::new(substream)), + ); + + match tokio::time::timeout(Duration::from_secs(20), executor.next()).await { + Ok(Some(QueryContext { + peer: queried_peer, + query_id, + result, + })) => { + assert_eq!(peer, queried_peer); + assert_eq!(query_id, Some(QueryId(1337))); + assert!(std::matches!(result, QueryResult::SubstreamClosed)); + } + result => panic!("invalid result received: {result:?}"), + } + } + + #[tokio::test] + async fn send_fails_no_message_read() { + let mut executor = QueryExecutor::new(); + let peer = PeerId::random(); + + // prepare substream which succeeds in sending the message but closes right after + let mut substream = MockSubstream::new(); + substream + .expect_poll_ready() + .times(1) + .return_once(|_| Poll::Ready(Err(crate::Error::Unknown))); + substream.expect_poll_close().times(1).return_once(|_| Poll::Ready(Ok(()))); + + executor.send_request_read_response( + peer, + Some(QueryId(1337)), + Bytes::from_static(b"hello, world"), + Substream::new_mock(peer, SubstreamId::from(0usize), Box::new(substream)), + ); + + match tokio::time::timeout(Duration::from_secs(20), executor.next()).await { + Ok(Some(QueryContext { + peer: queried_peer, + query_id, + result, + })) => { + assert_eq!(peer, queried_peer); + assert_eq!(query_id, Some(QueryId(1337))); + assert!(std::matches!(result, QueryResult::SubstreamClosed)); + } + result => panic!("invalid result received: {result:?}"), + } + } + + #[tokio::test] + async fn read_message_timeout() { + let mut executor = QueryExecutor::new(); + let peer = PeerId::random(); + + // prepare substream which succeeds in sending the message but closes right after + let mut substream = MockSubstream::new(); + substream.expect_poll_next().returning(|_| Poll::Pending); + + executor.read_message( + peer, + Some(QueryId(1336)), + Substream::new_mock(peer, SubstreamId::from(0usize), Box::new(substream)), + ); + + match tokio::time::timeout(Duration::from_secs(20), executor.next()).await { + Ok(Some(QueryContext { + peer: queried_peer, + query_id, + result, + })) => { + assert_eq!(peer, queried_peer); + assert_eq!(query_id, Some(QueryId(1336))); + assert!(std::matches!(result, QueryResult::Timeout)); + } + result => panic!("invalid result received: {result:?}"), + } + } + + #[tokio::test] + async fn read_message_substream_closed() { + let mut executor = QueryExecutor::new(); + let peer = PeerId::random(); + + // prepare substream which succeeds in sending the message but closes right after + let mut substream = MockSubstream::new(); + substream + .expect_poll_next() + .times(1) + .return_once(|_| Poll::Ready(Some(Err(crate::Error::Unknown)))); + + executor.read_message( + peer, + Some(QueryId(1335)), + Substream::new_mock(peer, SubstreamId::from(0usize), Box::new(substream)), + ); + + match tokio::time::timeout(Duration::from_secs(20), executor.next()).await { + Ok(Some(QueryContext { + peer: queried_peer, + query_id, + result, + })) => { + assert_eq!(peer, queried_peer); + assert_eq!(query_id, Some(QueryId(1335))); + assert!(std::matches!(result, QueryResult::SubstreamClosed)); + } + result => panic!("invalid result received: {result:?}"), + } + } } diff --git a/src/protocol/libp2p/kademlia/handle.rs b/src/protocol/libp2p/kademlia/handle.rs index 51be3795..2d8e913c 100644 --- a/src/protocol/libp2p/kademlia/handle.rs +++ b/src/protocol/libp2p/kademlia/handle.rs @@ -19,8 +19,8 @@ // DEALINGS IN THE SOFTWARE. use crate::{ - protocol::libp2p::kademlia::{QueryId, Record, RecordKey}, - PeerId, + protocol::libp2p::kademlia::{QueryId, Record, RecordKey}, + PeerId, }; use futures::Stream; @@ -28,9 +28,9 @@ use multiaddr::Multiaddr; use tokio::sync::mpsc::{Receiver, Sender}; use std::{ - num::NonZeroUsize, - pin::Pin, - task::{Context, Poll}, + num::NonZeroUsize, + pin::Pin, + task::{Context, Poll}, }; /// Quorum. @@ -39,217 +39,232 @@ use std::{ /// in order for the query to be considered successful. #[derive(Debug, Copy, Clone)] pub enum Quorum { - /// All peers must be successfully contacted. - All, + /// All peers must be successfully contacted. + All, - /// One peer must be successfully contacted. - One, + /// One peer must be successfully contacted. + One, - /// `N` peer must be successfully contacted. - N(NonZeroUsize), + /// `N` peer must be successfully contacted. + N(NonZeroUsize), } /// Routing table update mode. #[derive(Debug, Copy, Clone)] pub enum RoutingTableUpdateMode { - /// Don't insert discovered peers automatically to the routing tables but - /// allow user to do that by calling [`KademliaHandle::add_known_peer()`]. - Manual, + /// Don't insert discovered peers automatically to the routing tables but + /// allow user to do that by calling [`KademliaHandle::add_known_peer()`]. + Manual, - /// Automatically add all discovered peers to routing tables. - Automatic, + /// Automatically add all discovered peers to routing tables. + Automatic, } /// Kademlia commands. #[derive(Debug)] pub(crate) enum KademliaCommand { - /// Add known peer. - AddKnownPeer { - /// Peer ID. - peer: PeerId, - - /// Addresses of peer. - addresses: Vec, - }, - - /// Send `FIND_NODE` message. - FindNode { - /// Peer ID. - peer: PeerId, - - /// Query ID for the query. - query_id: QueryId, - }, - - /// Store record to DHT. - PutRecord { - /// Record. - record: Record, - - /// Query ID for the query. - query_id: QueryId, - }, - - /// Get record from DHT. - GetRecord { - /// Record key. - key: RecordKey, - - /// [`Quorum`] for the query. - quorum: Quorum, - - /// Query ID for the query. - query_id: QueryId, - }, + /// Add known peer. + AddKnownPeer { + /// Peer ID. + peer: PeerId, + + /// Addresses of peer. + addresses: Vec, + }, + + /// Send `FIND_NODE` message. + FindNode { + /// Peer ID. + peer: PeerId, + + /// Query ID for the query. + query_id: QueryId, + }, + + /// Store record to DHT. + PutRecord { + /// Record. + record: Record, + + /// Query ID for the query. + query_id: QueryId, + }, + + /// Get record from DHT. + GetRecord { + /// Record key. + key: RecordKey, + + /// [`Quorum`] for the query. + quorum: Quorum, + + /// Query ID for the query. + query_id: QueryId, + }, } /// Kademlia events. #[derive(Debug, Clone)] pub enum KademliaEvent { - /// Result for the issued `FIND_NODE` query. - FindNodeSuccess { - /// Query ID. - query_id: QueryId, - - /// Target of the query - target: PeerId, - - /// Found nodes and their addresses. - peers: Vec<(PeerId, Vec)>, - }, - - /// Routing table update. - /// - /// Kademlia has discovered one or more peers that should be added to the routing table. - /// If [`RoutingTableUpdateMode`] is `Automatic`, user can ignore this event unless some - /// upper-level protocols has user for this information. - /// - /// If the mode was set to `Manual`, user should call [`KademliaHandle::add_known_peer()`] - /// in order to add the peers to routing table. - RoutingTableUpdate { - /// Discovered peers. - peers: Vec, - }, - - /// `GET_VALUE` query succeeded. - GetRecordSuccess { - /// Query ID. - query_id: QueryId, - - /// Found record. - record: Record, - }, - - /// `PUT_VALUE` query succeeded. - PutRecordSucess { - /// Query ID. - query_id: QueryId, - - /// Record key. - key: RecordKey, - }, - - /// Query failed. - QueryFailed { - /// Query ID. - query_id: QueryId, - }, + /// Result for the issued `FIND_NODE` query. + FindNodeSuccess { + /// Query ID. + query_id: QueryId, + + /// Target of the query + target: PeerId, + + /// Found nodes and their addresses. + peers: Vec<(PeerId, Vec)>, + }, + + /// Routing table update. + /// + /// Kademlia has discovered one or more peers that should be added to the routing table. + /// If [`RoutingTableUpdateMode`] is `Automatic`, user can ignore this event unless some + /// upper-level protocols has user for this information. + /// + /// If the mode was set to `Manual`, user should call [`KademliaHandle::add_known_peer()`] + /// in order to add the peers to routing table. + RoutingTableUpdate { + /// Discovered peers. + peers: Vec, + }, + + /// `GET_VALUE` query succeeded. + GetRecordSuccess { + /// Query ID. + query_id: QueryId, + + /// Found record. + record: Record, + }, + + /// `PUT_VALUE` query succeeded. + PutRecordSucess { + /// Query ID. + query_id: QueryId, + + /// Record key. + key: RecordKey, + }, + + /// Query failed. + QueryFailed { + /// Query ID. + query_id: QueryId, + }, } /// Handle for communicating with the Kademlia protocol. pub struct KademliaHandle { - /// TX channel for sending commands to `Kademlia`. - cmd_tx: Sender, + /// TX channel for sending commands to `Kademlia`. + cmd_tx: Sender, - /// RX channel for receiving events from `Kademlia`. - event_rx: Receiver, + /// RX channel for receiving events from `Kademlia`. + event_rx: Receiver, - /// Next query ID. - next_query_id: usize, + /// Next query ID. + next_query_id: usize, } impl KademliaHandle { - /// Create new [`KademliaHandle`]. - pub(super) fn new(cmd_tx: Sender, event_rx: Receiver) -> Self { - Self { cmd_tx, event_rx, next_query_id: 0usize } - } - - /// Allocate next query ID. - fn next_query_id(&mut self) -> QueryId { - let query_id = self.next_query_id; - self.next_query_id += 1; - - QueryId(query_id) - } - - /// Add known peer. - pub async fn add_known_peer(&self, peer: PeerId, addresses: Vec) { - let _ = self.cmd_tx.send(KademliaCommand::AddKnownPeer { peer, addresses }).await; - } - - /// Send `FIND_NODE` query to known peers. - pub async fn find_node(&mut self, peer: PeerId) -> QueryId { - let query_id = self.next_query_id(); - let _ = self.cmd_tx.send(KademliaCommand::FindNode { peer, query_id }).await; - - query_id - } - - /// Store record to DHT. - pub async fn put_record(&mut self, record: Record) -> QueryId { - let query_id = self.next_query_id(); - let _ = self.cmd_tx.send(KademliaCommand::PutRecord { record, query_id }).await; - - query_id - } - - /// Get record from DHT. - pub async fn get_record(&mut self, key: RecordKey, quorum: Quorum) -> QueryId { - let query_id = self.next_query_id(); - let _ = self.cmd_tx.send(KademliaCommand::GetRecord { key, quorum, query_id }).await; - - query_id - } - - /// Try to add known peer and if the channel is clogged, return an error. - pub fn try_add_known_peer(&self, peer: PeerId, addresses: Vec) -> Result<(), ()> { - self.cmd_tx - .try_send(KademliaCommand::AddKnownPeer { peer, addresses }) - .map_err(|_| ()) - } - - /// Try to initiate `FIND_NODE` query and if the channel is clogged, return an error. - pub fn try_find_node(&mut self, peer: PeerId) -> Result { - let query_id = self.next_query_id(); - self.cmd_tx - .try_send(KademliaCommand::FindNode { peer, query_id }) - .map(|_| query_id) - .map_err(|_| ()) - } - - /// Try to initiate `PUT_VALUE` query and if the channel is clogged, return an error. - pub fn try_put_record(&mut self, record: Record) -> Result { - let query_id = self.next_query_id(); - self.cmd_tx - .try_send(KademliaCommand::PutRecord { record, query_id }) - .map(|_| query_id) - .map_err(|_| ()) - } - - /// Try to initiate `GET_VALUE` query and if the channel is clogged, return an error. - pub fn try_get_record(&mut self, key: RecordKey, quorum: Quorum) -> Result { - let query_id = self.next_query_id(); - self.cmd_tx - .try_send(KademliaCommand::GetRecord { key, quorum, query_id }) - .map(|_| query_id) - .map_err(|_| ()) - } + /// Create new [`KademliaHandle`]. + pub(super) fn new(cmd_tx: Sender, event_rx: Receiver) -> Self { + Self { + cmd_tx, + event_rx, + next_query_id: 0usize, + } + } + + /// Allocate next query ID. + fn next_query_id(&mut self) -> QueryId { + let query_id = self.next_query_id; + self.next_query_id += 1; + + QueryId(query_id) + } + + /// Add known peer. + pub async fn add_known_peer(&self, peer: PeerId, addresses: Vec) { + let _ = self.cmd_tx.send(KademliaCommand::AddKnownPeer { peer, addresses }).await; + } + + /// Send `FIND_NODE` query to known peers. + pub async fn find_node(&mut self, peer: PeerId) -> QueryId { + let query_id = self.next_query_id(); + let _ = self.cmd_tx.send(KademliaCommand::FindNode { peer, query_id }).await; + + query_id + } + + /// Store record to DHT. + pub async fn put_record(&mut self, record: Record) -> QueryId { + let query_id = self.next_query_id(); + let _ = self.cmd_tx.send(KademliaCommand::PutRecord { record, query_id }).await; + + query_id + } + + /// Get record from DHT. + pub async fn get_record(&mut self, key: RecordKey, quorum: Quorum) -> QueryId { + let query_id = self.next_query_id(); + let _ = self + .cmd_tx + .send(KademliaCommand::GetRecord { + key, + quorum, + query_id, + }) + .await; + + query_id + } + + /// Try to add known peer and if the channel is clogged, return an error. + pub fn try_add_known_peer(&self, peer: PeerId, addresses: Vec) -> Result<(), ()> { + self.cmd_tx + .try_send(KademliaCommand::AddKnownPeer { peer, addresses }) + .map_err(|_| ()) + } + + /// Try to initiate `FIND_NODE` query and if the channel is clogged, return an error. + pub fn try_find_node(&mut self, peer: PeerId) -> Result { + let query_id = self.next_query_id(); + self.cmd_tx + .try_send(KademliaCommand::FindNode { peer, query_id }) + .map(|_| query_id) + .map_err(|_| ()) + } + + /// Try to initiate `PUT_VALUE` query and if the channel is clogged, return an error. + pub fn try_put_record(&mut self, record: Record) -> Result { + let query_id = self.next_query_id(); + self.cmd_tx + .try_send(KademliaCommand::PutRecord { record, query_id }) + .map(|_| query_id) + .map_err(|_| ()) + } + + /// Try to initiate `GET_VALUE` query and if the channel is clogged, return an error. + pub fn try_get_record(&mut self, key: RecordKey, quorum: Quorum) -> Result { + let query_id = self.next_query_id(); + self.cmd_tx + .try_send(KademliaCommand::GetRecord { + key, + quorum, + query_id, + }) + .map(|_| query_id) + .map_err(|_| ()) + } } impl Stream for KademliaHandle { - type Item = KademliaEvent; + type Item = KademliaEvent; - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.event_rx.poll_recv(cx) - } + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.event_rx.poll_recv(cx) + } } diff --git a/src/protocol/libp2p/kademlia/message.rs b/src/protocol/libp2p/kademlia/message.rs index d0a29a8c..776afb29 100644 --- a/src/protocol/libp2p/kademlia/message.rs +++ b/src/protocol/libp2p/kademlia/message.rs @@ -19,9 +19,9 @@ // DEALINGS IN THE SOFTWARE. use crate::protocol::libp2p::kademlia::{ - record::{Key as RecordKey, Record}, - schema, - types::KademliaPeer, + record::{Key as RecordKey, Record}, + schema, + types::KademliaPeer, }; use bytes::{Bytes, BytesMut}; @@ -33,177 +33,182 @@ const LOG_TARGET: &str = "litep2p::ipfs::kademlia::message"; /// Kademlia message. #[derive(Debug, Clone)] pub enum KademliaMessage { - /// `FIND_NODE` message. - FindNode { - /// Query target. - target: Vec, - - /// Found peers. - peers: Vec, - }, - - /// Kademlia `PUT_VALUE` message. - PutValue { - /// Record. - record: Record, - }, - - /// `GET_VALUE` message. - GetRecord { - /// Key. - key: Option, - - /// Record. - record: Option, - - /// Peers closest to key. - peers: Vec, - }, + /// `FIND_NODE` message. + FindNode { + /// Query target. + target: Vec, + + /// Found peers. + peers: Vec, + }, + + /// Kademlia `PUT_VALUE` message. + PutValue { + /// Record. + record: Record, + }, + + /// `GET_VALUE` message. + GetRecord { + /// Key. + key: Option, + + /// Record. + record: Option, + + /// Peers closest to key. + peers: Vec, + }, } impl KademliaMessage { - /// Create `FIND_NODE` message for `peer`. - pub fn find_node>>(key: T) -> Bytes { - let message = schema::kademlia::Message { - key: key.into(), - r#type: schema::kademlia::MessageType::FindNode.into(), - cluster_level_raw: 10, - ..Default::default() - }; - - let mut buf = BytesMut::with_capacity(message.encoded_len()); - message.encode(&mut buf).expect("Vec to provide needed capacity"); - - buf.freeze() - } - - /// Create `PUT_VALUE` message for `record`. - // TODO: set ttl - pub fn put_value(record: Record) -> Bytes { - let message = schema::kademlia::Message { - key: record.key.clone().into(), - r#type: schema::kademlia::MessageType::PutValue.into(), - record: Some(schema::kademlia::Record { - key: record.key.into(), - value: record.value, - ..Default::default() - }), - cluster_level_raw: 10, - ..Default::default() - }; - - let mut buf = BytesMut::with_capacity(message.encoded_len()); - message.encode(&mut buf).expect("BytesMut to provide needed capacity"); - - buf.freeze() - } - - /// Create `GET_VALUE` message for `record`. - pub fn get_record(key: RecordKey) -> Bytes { - let message = schema::kademlia::Message { - key: key.clone().into(), - r#type: schema::kademlia::MessageType::GetValue.into(), - cluster_level_raw: 10, - ..Default::default() - }; - - let mut buf = BytesMut::with_capacity(message.encoded_len()); - message.encode(&mut buf).expect("BytesMut to provide needed capacity"); - - buf.freeze() - } - - /// Create `FIND_NODE` response. - pub fn find_node_response>(key: K, peers: Vec) -> Vec { - let message = schema::kademlia::Message { - key: key.as_ref().to_vec(), - cluster_level_raw: 10, - r#type: schema::kademlia::MessageType::FindNode.into(), - closer_peers: peers.iter().map(|peer| peer.into()).collect(), - ..Default::default() - }; - - let mut buf = Vec::with_capacity(message.encoded_len()); - message.encode(&mut buf).expect("Vec to provide needed capacity"); - - buf - } - - /// Create `PUT_VALUE` response. - pub fn get_value_response( - key: RecordKey, - peers: Vec, - record: Option, - ) -> Vec { - let message = schema::kademlia::Message { - key: key.to_vec(), - cluster_level_raw: 10, - r#type: schema::kademlia::MessageType::GetValue.into(), - closer_peers: peers.iter().map(|peer| peer.into()).collect(), - record: record.map(|record| schema::kademlia::Record { - key: record.key.to_vec(), - value: record.value, - ..Default::default() - }), - ..Default::default() - }; - - let mut buf = Vec::with_capacity(message.encoded_len()); - message.encode(&mut buf).expect("Vec to provide needed capacity"); - - buf - } - - /// Get [`KademliaMessage`] from bytes. - pub fn from_bytes(bytes: BytesMut) -> Option { - match schema::kademlia::Message::decode(bytes) { - Ok(message) => match message.r#type { - 4 => { - let peers = message - .closer_peers - .iter() - .filter_map(|peer| KademliaPeer::try_from(peer).ok()) - .collect(); - - Some(Self::FindNode { target: message.key, peers }) - }, - 0 => { - let record = message.record?; - - Some(Self::PutValue { record: Record::new(record.key, record.value) }) - }, - 1 => { - let key = match message.key.is_empty() { - true => message - .record - .as_ref() - .map(|record| { - (!record.key.is_empty()) - .then_some(RecordKey::from(record.key.clone())) - }) - .flatten(), - false => Some(RecordKey::from(message.key.clone())), - }; - - Some(Self::GetRecord { - key, - record: message.record.map(|record| Record::new(record.key, record.value)), - peers: message - .closer_peers - .iter() - .filter_map(|peer| KademliaPeer::try_from(peer).ok()) - .collect(), - }) - }, - message => { - tracing::warn!(target: LOG_TARGET, ?message, "unhandled message"); - None - }, - }, - Err(error) => { - tracing::debug!(target: LOG_TARGET, ?error, "failed to decode message"); - None - }, - } - } + /// Create `FIND_NODE` message for `peer`. + pub fn find_node>>(key: T) -> Bytes { + let message = schema::kademlia::Message { + key: key.into(), + r#type: schema::kademlia::MessageType::FindNode.into(), + cluster_level_raw: 10, + ..Default::default() + }; + + let mut buf = BytesMut::with_capacity(message.encoded_len()); + message.encode(&mut buf).expect("Vec to provide needed capacity"); + + buf.freeze() + } + + /// Create `PUT_VALUE` message for `record`. + // TODO: set ttl + pub fn put_value(record: Record) -> Bytes { + let message = schema::kademlia::Message { + key: record.key.clone().into(), + r#type: schema::kademlia::MessageType::PutValue.into(), + record: Some(schema::kademlia::Record { + key: record.key.into(), + value: record.value, + ..Default::default() + }), + cluster_level_raw: 10, + ..Default::default() + }; + + let mut buf = BytesMut::with_capacity(message.encoded_len()); + message.encode(&mut buf).expect("BytesMut to provide needed capacity"); + + buf.freeze() + } + + /// Create `GET_VALUE` message for `record`. + pub fn get_record(key: RecordKey) -> Bytes { + let message = schema::kademlia::Message { + key: key.clone().into(), + r#type: schema::kademlia::MessageType::GetValue.into(), + cluster_level_raw: 10, + ..Default::default() + }; + + let mut buf = BytesMut::with_capacity(message.encoded_len()); + message.encode(&mut buf).expect("BytesMut to provide needed capacity"); + + buf.freeze() + } + + /// Create `FIND_NODE` response. + pub fn find_node_response>(key: K, peers: Vec) -> Vec { + let message = schema::kademlia::Message { + key: key.as_ref().to_vec(), + cluster_level_raw: 10, + r#type: schema::kademlia::MessageType::FindNode.into(), + closer_peers: peers.iter().map(|peer| peer.into()).collect(), + ..Default::default() + }; + + let mut buf = Vec::with_capacity(message.encoded_len()); + message.encode(&mut buf).expect("Vec to provide needed capacity"); + + buf + } + + /// Create `PUT_VALUE` response. + pub fn get_value_response( + key: RecordKey, + peers: Vec, + record: Option, + ) -> Vec { + let message = schema::kademlia::Message { + key: key.to_vec(), + cluster_level_raw: 10, + r#type: schema::kademlia::MessageType::GetValue.into(), + closer_peers: peers.iter().map(|peer| peer.into()).collect(), + record: record.map(|record| schema::kademlia::Record { + key: record.key.to_vec(), + value: record.value, + ..Default::default() + }), + ..Default::default() + }; + + let mut buf = Vec::with_capacity(message.encoded_len()); + message.encode(&mut buf).expect("Vec to provide needed capacity"); + + buf + } + + /// Get [`KademliaMessage`] from bytes. + pub fn from_bytes(bytes: BytesMut) -> Option { + match schema::kademlia::Message::decode(bytes) { + Ok(message) => match message.r#type { + 4 => { + let peers = message + .closer_peers + .iter() + .filter_map(|peer| KademliaPeer::try_from(peer).ok()) + .collect(); + + Some(Self::FindNode { + target: message.key, + peers, + }) + } + 0 => { + let record = message.record?; + + Some(Self::PutValue { + record: Record::new(record.key, record.value), + }) + } + 1 => { + let key = match message.key.is_empty() { + true => message + .record + .as_ref() + .map(|record| { + (!record.key.is_empty()) + .then_some(RecordKey::from(record.key.clone())) + }) + .flatten(), + false => Some(RecordKey::from(message.key.clone())), + }; + + Some(Self::GetRecord { + key, + record: message.record.map(|record| Record::new(record.key, record.value)), + peers: message + .closer_peers + .iter() + .filter_map(|peer| KademliaPeer::try_from(peer).ok()) + .collect(), + }) + } + message => { + tracing::warn!(target: LOG_TARGET, ?message, "unhandled message"); + None + } + }, + Err(error) => { + tracing::debug!(target: LOG_TARGET, ?error, "failed to decode message"); + None + } + } + } } diff --git a/src/protocol/libp2p/kademlia/mod.rs b/src/protocol/libp2p/kademlia/mod.rs index 6279392a..662ef0e7 100644 --- a/src/protocol/libp2p/kademlia/mod.rs +++ b/src/protocol/libp2p/kademlia/mod.rs @@ -21,23 +21,23 @@ //! [`/ipfs/kad/1.0.0`](https://github.com/libp2p/specs/blob/master/kad-dht/README.md) implementation. use crate::{ - error::Error, - protocol::{ - libp2p::kademlia::{ - bucket::KBucketEntry, - executor::{QueryContext, QueryExecutor, QueryResult}, - handle::KademliaCommand, - message::KademliaMessage, - query::{QueryAction, QueryEngine}, - routing_table::RoutingTable, - store::MemoryStore, - types::{ConnectionType, KademliaPeer, Key}, - }, - Direction, TransportEvent, TransportService, - }, - substream::Substream, - types::SubstreamId, - PeerId, + error::Error, + protocol::{ + libp2p::kademlia::{ + bucket::KBucketEntry, + executor::{QueryContext, QueryExecutor, QueryResult}, + handle::KademliaCommand, + message::KademliaMessage, + query::{QueryAction, QueryEngine}, + routing_table::RoutingTable, + store::MemoryStore, + types::{ConnectionType, KademliaPeer, Key}, + }, + Direction, TransportEvent, TransportService, + }, + substream::Substream, + types::SubstreamId, + PeerId, }; use bytes::{Bytes, BytesMut}; @@ -70,763 +70,784 @@ mod store; mod types; mod schema { - pub(super) mod kademlia { - include!(concat!(env!("OUT_DIR"), "/kademlia.rs")); - } + pub(super) mod kademlia { + include!(concat!(env!("OUT_DIR"), "/kademlia.rs")); + } } /// Peer action. #[derive(Debug)] enum PeerAction { - /// Send `FIND_NODE` message to peer. - SendFindNode(QueryId), + /// Send `FIND_NODE` message to peer. + SendFindNode(QueryId), - /// Send `PUT_VALUE` message to peer. - SendPutValue(Bytes), + /// Send `PUT_VALUE` message to peer. + SendPutValue(Bytes), } /// Peer context. #[derive(Default)] struct PeerContext { - /// Pending action, if any. - pending_actions: HashMap, + /// Pending action, if any. + pending_actions: HashMap, } impl PeerContext { - /// Create new [`PeerContext`]. - pub fn new() -> Self { - Self { pending_actions: HashMap::new() } - } - - /// Add pending action for peer. - pub fn add_pending_action(&mut self, substream_id: SubstreamId, action: PeerAction) { - self.pending_actions.insert(substream_id, action); - } + /// Create new [`PeerContext`]. + pub fn new() -> Self { + Self { + pending_actions: HashMap::new(), + } + } + + /// Add pending action for peer. + pub fn add_pending_action(&mut self, substream_id: SubstreamId, action: PeerAction) { + self.pending_actions.insert(substream_id, action); + } } /// Main Kademlia object. pub(crate) struct Kademlia { - /// Transport service. - service: TransportService, + /// Transport service. + service: TransportService, - /// Local Kademlia key. - _local_key: Key, + /// Local Kademlia key. + _local_key: Key, - /// Connected peers, - peers: HashMap, + /// Connected peers, + peers: HashMap, - /// TX channel for sending events to `KademliaHandle`. - event_tx: Sender, + /// TX channel for sending events to `KademliaHandle`. + event_tx: Sender, - /// RX channel for receiving commands from `KademliaHandle`. - cmd_rx: Receiver, + /// RX channel for receiving commands from `KademliaHandle`. + cmd_rx: Receiver, - /// Routing table. - routing_table: RoutingTable, + /// Routing table. + routing_table: RoutingTable, - /// Replication factor. - replication_factor: usize, + /// Replication factor. + replication_factor: usize, - /// Record store. - store: MemoryStore, + /// Record store. + store: MemoryStore, - /// Pending outbound substreams. - pending_substreams: HashMap, + /// Pending outbound substreams. + pending_substreams: HashMap, - /// Pending dials. - pending_dials: HashMap>, + /// Pending dials. + pending_dials: HashMap>, - /// Routing table update mode. - update_mode: RoutingTableUpdateMode, + /// Routing table update mode. + update_mode: RoutingTableUpdateMode, - /// Query engine. - engine: QueryEngine, + /// Query engine. + engine: QueryEngine, - /// Query executor. - executor: QueryExecutor, + /// Query executor. + executor: QueryExecutor, } impl Kademlia { - /// Create new [`Kademlia`]. - pub(crate) fn new(mut service: TransportService, config: Config) -> Self { - let local_peer_id = service.local_peer_id; - let local_key = Key::from(service.local_peer_id); - let mut routing_table = RoutingTable::new(local_key.clone()); - - for (peer, addresses) in config.known_peers { - tracing::trace!(target: LOG_TARGET, ?peer, ?addresses, "add bootstrap peer"); - - routing_table.add_known_peer(peer, addresses.clone(), ConnectionType::NotConnected); - service.add_known_address(&peer, addresses.into_iter()); - } - - Self { - service, - routing_table, - peers: HashMap::new(), - cmd_rx: config.cmd_rx, - store: MemoryStore::new(), - event_tx: config.event_tx, - _local_key: local_key, - pending_dials: HashMap::new(), - executor: QueryExecutor::new(), - pending_substreams: HashMap::new(), - update_mode: config.update_mode, - replication_factor: config.replication_factor, - engine: QueryEngine::new(local_peer_id, config.replication_factor, PARALLELISM_FACTOR), - } - } - - /// Connection established to remote peer. - fn on_connection_established(&mut self, peer: PeerId) -> crate::Result<()> { - tracing::trace!(target: LOG_TARGET, ?peer, "connection established"); - - match self.peers.entry(peer) { - Entry::Vacant(entry) => { - if let KBucketEntry::Occupied(entry) = self.routing_table.entry(Key::from(peer)) { - entry.connection = ConnectionType::Connected; - } - - let Some(actions) = self.pending_dials.remove(&peer) else { - entry.insert(PeerContext::new()); - return Ok(()); - }; - - // go over all pending actions, open substreams and save the state to `PeerContext` - // from which it will be later queried when the substream opens - let mut context = PeerContext::new(); - - for action in actions { - match self.service.open_substream(peer) { - Ok(substream_id) => { - context.add_pending_action(substream_id, action); - }, - Err(error) => { - tracing::debug!( - target: LOG_TARGET, - ?peer, - ?action, - ?error, - "connection established to peer but failed to open substream", - ); - - if let PeerAction::SendFindNode(query_id) = action { - self.engine.register_response_failure(query_id, peer); - } - }, - } - } - - entry.insert(context); - Ok(()) - }, - Entry::Occupied(_) => return Err(Error::PeerAlreadyExists(peer)), - } - } - - /// Disconnect peer from `Kademlia`. - /// - /// Peer is disconnected either because the substream was detected closed - /// or because the connection was closed. - /// - /// The peer is kept in the routing table but its connection state is set - /// as `NotConnected`, meaning it can be evicted from a k-bucket if another - /// peer that shares the bucket connects. - async fn disconnect_peer(&mut self, peer: PeerId, query: Option) { - tracing::trace!(target: LOG_TARGET, ?peer, ?query, "disconnect peer"); - - if let Some(query) = query { - self.engine.register_response_failure(query, peer); - } - - if let Some(PeerContext { pending_actions }) = self.peers.remove(&peer) { - pending_actions.into_iter().for_each(|(_, action)| { - if let PeerAction::SendFindNode(query_id) = action { - self.engine.register_response_failure(query_id, peer); - } - }); - } - - if let KBucketEntry::Occupied(entry) = self.routing_table.entry(Key::from(peer)) { - entry.connection = ConnectionType::NotConnected; - } - } - - /// Local node opened a substream to remote node. - async fn on_outbound_substream( - &mut self, - peer: PeerId, - substream_id: SubstreamId, - substream: Substream, - ) -> crate::Result<()> { - tracing::trace!( - target: LOG_TARGET, - ?peer, - ?substream_id, - "outbound substream opened", - ); - let _ = self.pending_substreams.remove(&substream_id); - - let pending_action = &mut self - .peers - .get_mut(&peer) - .ok_or(Error::PeerDoesntExist(peer))? - .pending_actions - .remove(&substream_id); - - match std::mem::replace(pending_action, None) { - None => { - tracing::trace!( - target: LOG_TARGET, - ?peer, - ?substream_id, - "pending action doesn't exist for peer, closing substream", - ); - - let _ = substream.close().await; - return Ok(()); - }, - Some(PeerAction::SendFindNode(query)) => { - match self.engine.next_peer_action(&query, &peer) { - Some(QueryAction::SendMessage { query, peer, message }) => { - tracing::trace!(target: LOG_TARGET, ?peer, ?query, "start sending message to peer"); - - self.executor.send_request_read_response( - peer, - Some(query), - message, - substream, - ); - }, - // query finished while the substream was being opened - None => { - let _ = substream.close().await; - }, - action => { - tracing::warn!(target: LOG_TARGET, ?query, ?peer, ?action, "unexpected action for `FIND_NODE`"); - let _ = substream.close().await; - debug_assert!(false); - }, - } - }, - Some(PeerAction::SendPutValue(message)) => { - tracing::trace!(target: LOG_TARGET, ?peer, "send `PUT_VALUE` response"); - - self.executor.send_message(peer, message, substream); - }, - } - - Ok(()) - } - - /// Remote opened a substream to local node. - async fn on_inbound_substream(&mut self, peer: PeerId, substream: Substream) { - tracing::trace!(target: LOG_TARGET, ?peer, "inbound substream opened"); - - self.executor.read_message(peer, None, substream); - } - - /// Update routing table if the routing table update mode was set to automatic. - /// - /// Inform user about the potential routing table, allowing them to update it manually if - /// the mode was set to manual. - async fn update_routing_table(&mut self, peers: &Vec) { - let peers: Vec<_> = peers - .iter() - .filter_map(|peer| (peer.peer != self.service.local_peer_id).then_some(peer)) - .collect(); - - // inform user about the routing table update, regardless of what the routing table update - // mode is - let _ = self - .event_tx - .send(KademliaEvent::RoutingTableUpdate { - peers: peers.iter().map(|peer| peer.peer).collect::>(), - }) - .await; - - for info in peers { - self.service.add_known_address(&info.peer, info.addresses.iter().cloned()); - - if std::matches!(self.update_mode, RoutingTableUpdateMode::Automatic) { - self.routing_table.add_known_peer( - info.peer, - info.addresses.clone(), - self.peers - .get(&info.peer) - .map_or(ConnectionType::NotConnected, |_| ConnectionType::Connected), - ); - } - } - } - - /// Handle received message. - async fn on_message_received( - &mut self, - peer: PeerId, - query_id: Option, - message: BytesMut, - substream: Substream, - ) -> crate::Result<()> { - tracing::trace!(target: LOG_TARGET, ?peer, ?query_id, "handle message from peer"); - - match KademliaMessage::from_bytes(message).ok_or(Error::InvalidData)? { - ref message @ KademliaMessage::FindNode { ref target, ref peers } => { - match query_id { - Some(query_id) => { - tracing::trace!( - target: LOG_TARGET, - ?peer, - ?target, - "handle `FIND_NODE` response", - ); - - // update routing table and inform user about the update - self.update_routing_table(peers).await; - self.engine.register_response(query_id, peer, message.clone()); - }, - None => { - tracing::trace!( - target: LOG_TARGET, - ?peer, - ?target, - "handle `FIND_NODE` request", - ); - - let message = KademliaMessage::find_node_response( - target, - self.routing_table - .closest(Key::from(target.clone()), self.replication_factor), - ); - self.executor.send_message(peer, message.into(), substream); - }, - } - }, - KademliaMessage::PutValue { record } => { - tracing::trace!( - target: LOG_TARGET, - ?peer, - record_key = ?record.key, - "handle `PUT_VALUE` message", - ); - - self.store.put(record); - }, - ref message @ KademliaMessage::GetRecord { ref key, ref record, ref peers } => { - match (query_id, key) { - (Some(query_id), _) => { - tracing::trace!( - target: LOG_TARGET, - ?peer, - ?query_id, - ?peers, - ?record, - "handle `GET_VALUE` response", - ); - - // update routing table and inform user about the update - self.update_routing_table(peers).await; - self.engine.register_response(query_id, peer, message.clone()); - }, - (None, Some(key)) => { - tracing::trace!( - target: LOG_TARGET, - ?peer, - ?key, - "handle `GET_VALUE` request", - ); - - let value = self.store.get(key).map(|value| value.clone()); - let closest_peers = self - .routing_table - .closest(Key::from(key.to_vec()), self.replication_factor); - - let message = KademliaMessage::get_value_response( - (*key).clone(), - closest_peers, - value, - ); - self.executor.send_message(peer, message.into(), substream); - }, - (None, None) => tracing::debug!( - target: LOG_TARGET, - ?peer, - ?message, - "both query and record key missing, unable to handle message", - ), - } - }, - } - - Ok(()) - } - - /// Failed to open substream to remote peer. - async fn on_substream_open_failure(&mut self, substream_id: SubstreamId, error: Error) { - tracing::trace!( - target: LOG_TARGET, - ?substream_id, - ?error, - "failed to open substream" - ); - - let Some(peer) = self.pending_substreams.remove(&substream_id) else { - tracing::debug!( - target: LOG_TARGET, - ?substream_id, - "outbound substream failed for non-existent peer" - ); - return; - }; - - if let Some(context) = self.peers.get_mut(&peer) { - let query = match context.pending_actions.remove(&substream_id) { - Some(PeerAction::SendFindNode(query)) => Some(query), - _ => None, - }; - - self.disconnect_peer(peer, query).await; - } - } - - /// Handle dial failure. - fn on_dial_failure(&mut self, peer: PeerId, address: Multiaddr) { - tracing::trace!(target: LOG_TARGET, ?peer, ?address, "failed to dial peer"); - - let Some(actions) = self.pending_dials.remove(&peer) else { - return; - }; - - for action in actions { - if let PeerAction::SendFindNode(query_id) = action { - tracing::trace!( - target: LOG_TARGET, - ?peer, - ?query_id, - ?address, - "report failure for pending query", - ); - - self.engine.register_response_failure(query_id, peer); - } - } - } - - /// Handle next query action. - async fn on_query_action(&mut self, action: QueryAction) -> Result<(), (QueryId, PeerId)> { - match action { - QueryAction::SendMessage { query, peer, .. } => match self.service.open_substream(peer) - { - Err(_) => { - tracing::trace!(target: LOG_TARGET, ?query, ?peer, "dial peer"); - - match self.service.dial(&peer) { - Ok(_) => match self.pending_dials.entry(peer) { - Entry::Occupied(entry) => { - entry.into_mut().push(PeerAction::SendFindNode(query)); - }, - Entry::Vacant(entry) => { - entry.insert(vec![PeerAction::SendFindNode(query)]); - }, - }, - Err(error) => { - tracing::trace!(target: LOG_TARGET, ?query, ?peer, ?error, "failed to dial peer"); - self.engine.register_response_failure(query, peer); - }, - } - - Ok(()) - }, - Ok(substream_id) => { - tracing::trace!( - target: LOG_TARGET, - ?query, - ?peer, - ?substream_id, - "open outbound substream for peer" - ); - - self.pending_substreams.insert(substream_id, peer); - self.peers - .entry(peer) - .or_default() - .pending_actions - .insert(substream_id, PeerAction::SendFindNode(query)); - - Ok(()) - }, - }, - QueryAction::FindNodeQuerySucceeded { target, peers, query } => { - tracing::debug!( - target: LOG_TARGET, - ?query, - peer = ?target, - num_peers = ?peers.len(), - "`FIND_NODE` succeeded", - ); - - let _ = self - .event_tx - .send(KademliaEvent::FindNodeSuccess { - target, - query_id: query, - peers: peers.into_iter().map(|info| (info.peer, info.addresses)).collect(), - }) - .await; - Ok(()) - }, - QueryAction::PutRecordToFoundNodes { record, peers } => { - tracing::trace!( - target: LOG_TARGET, - record_key = ?record.key, - num_peers = ?peers.len(), - "store record to found peers", - ); - let key = record.key.clone(); - let message = KademliaMessage::put_value(record); - - for peer in peers { - match self.service.open_substream(peer.peer) { - Ok(substream_id) => { - self.pending_substreams.insert(substream_id, peer.peer); - self.peers - .entry(peer.peer) - .or_default() - .pending_actions - .insert(substream_id, PeerAction::SendPutValue(message.clone())); - }, - Err(_) => match self.service.dial(&peer.peer) { - Ok(_) => match self.pending_dials.entry(peer.peer) { - Entry::Occupied(entry) => { - entry - .into_mut() - .push(PeerAction::SendPutValue(message.clone())); - }, - Entry::Vacant(entry) => { - entry.insert(vec![PeerAction::SendPutValue(message.clone())]); - }, - }, - Err(error) => { - tracing::debug!( - target: LOG_TARGET, - ?peer, - ?key, - ?error, - "failed to dial peer", - ); - }, - }, - } - } - - Ok(()) - }, - QueryAction::GetRecordQueryDone { query_id, record } => { - self.store.put(record.clone()); - - let _ = - self.event_tx.send(KademliaEvent::GetRecordSuccess { query_id, record }).await; - Ok(()) - }, - QueryAction::QueryFailed { query } => { - tracing::debug!(target: LOG_TARGET, ?query, "query failed"); - - let _ = self.event_tx.send(KademliaEvent::QueryFailed { query_id: query }).await; - Ok(()) - }, - QueryAction::QuerySucceeded { .. } => unreachable!(), - } - } - - /// [`Kademlia`] event loop. - pub async fn run(mut self) -> crate::Result<()> { - tracing::debug!(target: LOG_TARGET, "starting kademlia event loop"); - - loop { - // poll `QueryEngine` for next actions. - while let Some(action) = self.engine.next_action() { - if let Err((query, peer)) = self.on_query_action(action).await { - self.disconnect_peer(peer, Some(query)).await; - } - } - - tokio::select! { - event = self.service.next() => match event { - Some(TransportEvent::ConnectionEstablished { peer, .. }) => { - if let Err(error) = self.on_connection_established(peer) { - tracing::debug!(target: LOG_TARGET, ?error, "failed to handle established connection"); - } - } - Some(TransportEvent::ConnectionClosed { peer }) => { - self.disconnect_peer(peer, None).await; - } - Some(TransportEvent::SubstreamOpened { peer, direction, substream, .. }) => { - match direction { - Direction::Inbound => self.on_inbound_substream(peer, substream).await, - Direction::Outbound(substream_id) => { - if let Err(error) = self.on_outbound_substream(peer, substream_id, substream).await { - tracing::debug!( - target: LOG_TARGET, - ?peer, - ?substream_id, - ?error, - "failed to handle outbound substream", - ); - } - } - } - }, - Some(TransportEvent::SubstreamOpenFailure { substream, error }) => { - self.on_substream_open_failure(substream, error).await; - } - Some(TransportEvent::DialFailure { peer, address }) => self.on_dial_failure(peer, address), - None => return Err(Error::EssentialTaskClosed), - }, - context = self.executor.next() => { - let QueryContext { peer, query_id, result } = context.unwrap(); - - match result { - QueryResult::SendSuccess { substream } => { - tracing::trace!(target: LOG_TARGET, ?peer, ?query_id, "message sent to peer"); - let _ = substream.close().await; - } - QueryResult::ReadSuccess { substream, message } => { - tracing::trace!(target: LOG_TARGET, ?peer, ?query_id, "message read from peer"); - - if let Err(error) = self.on_message_received(peer, query_id, message, substream).await { - tracing::debug!(target: LOG_TARGET, ?peer, ?error, "failed to process message"); - } - } - QueryResult::SubstreamClosed | QueryResult::Timeout => { - tracing::debug!( - target: LOG_TARGET, - ?peer, - ?query_id, - ?result, - "failed to read message from substream", - ); - - self.disconnect_peer(peer, query_id).await; - } - } - } - command = self.cmd_rx.recv() => { - match command { - Some(KademliaCommand::FindNode { peer, query_id }) => { - tracing::debug!(target: LOG_TARGET, ?peer, ?query_id, "starting `FIND_NODE` query"); - - self.engine.start_find_node( - query_id, - peer, - self.routing_table.closest(Key::from(peer), self.replication_factor).into() - ); - } - Some(KademliaCommand::PutRecord { record, query_id }) => { - tracing::debug!(target: LOG_TARGET, ?query_id, key = ?record.key, "store record to DHT"); - - self.store.put(record.clone()); - let key = Key::new(record.key.clone()); - - self.engine.start_put_record( - query_id, - record, - self.routing_table.closest(key, self.replication_factor).into(), - ); - } - Some(KademliaCommand::GetRecord { key, quorum, query_id }) => { - tracing::debug!(target: LOG_TARGET, ?key, "get record from DHT"); - - match (self.store.get(&key), quorum) { - (Some(record), Quorum::One) => { - let _ = self - .event_tx - .send(KademliaEvent::GetRecordSuccess { query_id, record: record.clone() }) - .await; - } - (record, _) => { - self.engine.start_get_record( - query_id, - key.clone(), - self.routing_table.closest(Key::new(key.clone()), self.replication_factor).into(), - quorum, - if record.is_some() { 1 } else { 0 }, - ); - } - } - - } - Some(KademliaCommand::AddKnownPeer { peer, addresses }) => { - tracing::trace!( - target: LOG_TARGET, - ?peer, - ?addresses, - "add known peer", - ); - - self.routing_table.add_known_peer( - peer, - addresses.clone(), - self.peers - .get(&peer) - .map_or(ConnectionType::NotConnected, |_| ConnectionType::Connected), - ); - self.service.add_known_address(&peer, addresses.into_iter()); - - } - None => return Err(Error::EssentialTaskClosed), - } - }, - } - } - } + /// Create new [`Kademlia`]. + pub(crate) fn new(mut service: TransportService, config: Config) -> Self { + let local_peer_id = service.local_peer_id; + let local_key = Key::from(service.local_peer_id); + let mut routing_table = RoutingTable::new(local_key.clone()); + + for (peer, addresses) in config.known_peers { + tracing::trace!(target: LOG_TARGET, ?peer, ?addresses, "add bootstrap peer"); + + routing_table.add_known_peer(peer, addresses.clone(), ConnectionType::NotConnected); + service.add_known_address(&peer, addresses.into_iter()); + } + + Self { + service, + routing_table, + peers: HashMap::new(), + cmd_rx: config.cmd_rx, + store: MemoryStore::new(), + event_tx: config.event_tx, + _local_key: local_key, + pending_dials: HashMap::new(), + executor: QueryExecutor::new(), + pending_substreams: HashMap::new(), + update_mode: config.update_mode, + replication_factor: config.replication_factor, + engine: QueryEngine::new(local_peer_id, config.replication_factor, PARALLELISM_FACTOR), + } + } + + /// Connection established to remote peer. + fn on_connection_established(&mut self, peer: PeerId) -> crate::Result<()> { + tracing::trace!(target: LOG_TARGET, ?peer, "connection established"); + + match self.peers.entry(peer) { + Entry::Vacant(entry) => { + if let KBucketEntry::Occupied(entry) = self.routing_table.entry(Key::from(peer)) { + entry.connection = ConnectionType::Connected; + } + + let Some(actions) = self.pending_dials.remove(&peer) else { + entry.insert(PeerContext::new()); + return Ok(()); + }; + + // go over all pending actions, open substreams and save the state to `PeerContext` + // from which it will be later queried when the substream opens + let mut context = PeerContext::new(); + + for action in actions { + match self.service.open_substream(peer) { + Ok(substream_id) => { + context.add_pending_action(substream_id, action); + } + Err(error) => { + tracing::debug!( + target: LOG_TARGET, + ?peer, + ?action, + ?error, + "connection established to peer but failed to open substream", + ); + + if let PeerAction::SendFindNode(query_id) = action { + self.engine.register_response_failure(query_id, peer); + } + } + } + } + + entry.insert(context); + Ok(()) + } + Entry::Occupied(_) => return Err(Error::PeerAlreadyExists(peer)), + } + } + + /// Disconnect peer from `Kademlia`. + /// + /// Peer is disconnected either because the substream was detected closed + /// or because the connection was closed. + /// + /// The peer is kept in the routing table but its connection state is set + /// as `NotConnected`, meaning it can be evicted from a k-bucket if another + /// peer that shares the bucket connects. + async fn disconnect_peer(&mut self, peer: PeerId, query: Option) { + tracing::trace!(target: LOG_TARGET, ?peer, ?query, "disconnect peer"); + + if let Some(query) = query { + self.engine.register_response_failure(query, peer); + } + + if let Some(PeerContext { pending_actions }) = self.peers.remove(&peer) { + pending_actions.into_iter().for_each(|(_, action)| { + if let PeerAction::SendFindNode(query_id) = action { + self.engine.register_response_failure(query_id, peer); + } + }); + } + + if let KBucketEntry::Occupied(entry) = self.routing_table.entry(Key::from(peer)) { + entry.connection = ConnectionType::NotConnected; + } + } + + /// Local node opened a substream to remote node. + async fn on_outbound_substream( + &mut self, + peer: PeerId, + substream_id: SubstreamId, + substream: Substream, + ) -> crate::Result<()> { + tracing::trace!( + target: LOG_TARGET, + ?peer, + ?substream_id, + "outbound substream opened", + ); + let _ = self.pending_substreams.remove(&substream_id); + + let pending_action = &mut self + .peers + .get_mut(&peer) + .ok_or(Error::PeerDoesntExist(peer))? + .pending_actions + .remove(&substream_id); + + match std::mem::replace(pending_action, None) { + None => { + tracing::trace!( + target: LOG_TARGET, + ?peer, + ?substream_id, + "pending action doesn't exist for peer, closing substream", + ); + + let _ = substream.close().await; + return Ok(()); + } + Some(PeerAction::SendFindNode(query)) => { + match self.engine.next_peer_action(&query, &peer) { + Some(QueryAction::SendMessage { + query, + peer, + message, + }) => { + tracing::trace!(target: LOG_TARGET, ?peer, ?query, "start sending message to peer"); + + self.executor.send_request_read_response( + peer, + Some(query), + message, + substream, + ); + } + // query finished while the substream was being opened + None => { + let _ = substream.close().await; + } + action => { + tracing::warn!(target: LOG_TARGET, ?query, ?peer, ?action, "unexpected action for `FIND_NODE`"); + let _ = substream.close().await; + debug_assert!(false); + } + } + } + Some(PeerAction::SendPutValue(message)) => { + tracing::trace!(target: LOG_TARGET, ?peer, "send `PUT_VALUE` response"); + + self.executor.send_message(peer, message, substream); + } + } + + Ok(()) + } + + /// Remote opened a substream to local node. + async fn on_inbound_substream(&mut self, peer: PeerId, substream: Substream) { + tracing::trace!(target: LOG_TARGET, ?peer, "inbound substream opened"); + + self.executor.read_message(peer, None, substream); + } + + /// Update routing table if the routing table update mode was set to automatic. + /// + /// Inform user about the potential routing table, allowing them to update it manually if + /// the mode was set to manual. + async fn update_routing_table(&mut self, peers: &Vec) { + let peers: Vec<_> = peers + .iter() + .filter_map(|peer| (peer.peer != self.service.local_peer_id).then_some(peer)) + .collect(); + + // inform user about the routing table update, regardless of what the routing table update + // mode is + let _ = self + .event_tx + .send(KademliaEvent::RoutingTableUpdate { + peers: peers.iter().map(|peer| peer.peer).collect::>(), + }) + .await; + + for info in peers { + self.service.add_known_address(&info.peer, info.addresses.iter().cloned()); + + if std::matches!(self.update_mode, RoutingTableUpdateMode::Automatic) { + self.routing_table.add_known_peer( + info.peer, + info.addresses.clone(), + self.peers + .get(&info.peer) + .map_or(ConnectionType::NotConnected, |_| ConnectionType::Connected), + ); + } + } + } + + /// Handle received message. + async fn on_message_received( + &mut self, + peer: PeerId, + query_id: Option, + message: BytesMut, + substream: Substream, + ) -> crate::Result<()> { + tracing::trace!(target: LOG_TARGET, ?peer, ?query_id, "handle message from peer"); + + match KademliaMessage::from_bytes(message).ok_or(Error::InvalidData)? { + ref message @ KademliaMessage::FindNode { + ref target, + ref peers, + } => { + match query_id { + Some(query_id) => { + tracing::trace!( + target: LOG_TARGET, + ?peer, + ?target, + "handle `FIND_NODE` response", + ); + + // update routing table and inform user about the update + self.update_routing_table(peers).await; + self.engine.register_response(query_id, peer, message.clone()); + } + None => { + tracing::trace!( + target: LOG_TARGET, + ?peer, + ?target, + "handle `FIND_NODE` request", + ); + + let message = KademliaMessage::find_node_response( + target, + self.routing_table + .closest(Key::from(target.clone()), self.replication_factor), + ); + self.executor.send_message(peer, message.into(), substream); + } + } + } + KademliaMessage::PutValue { record } => { + tracing::trace!( + target: LOG_TARGET, + ?peer, + record_key = ?record.key, + "handle `PUT_VALUE` message", + ); + + self.store.put(record); + } + ref message @ KademliaMessage::GetRecord { + ref key, + ref record, + ref peers, + } => { + match (query_id, key) { + (Some(query_id), _) => { + tracing::trace!( + target: LOG_TARGET, + ?peer, + ?query_id, + ?peers, + ?record, + "handle `GET_VALUE` response", + ); + + // update routing table and inform user about the update + self.update_routing_table(peers).await; + self.engine.register_response(query_id, peer, message.clone()); + } + (None, Some(key)) => { + tracing::trace!( + target: LOG_TARGET, + ?peer, + ?key, + "handle `GET_VALUE` request", + ); + + let value = self.store.get(key).map(|value| value.clone()); + let closest_peers = self + .routing_table + .closest(Key::from(key.to_vec()), self.replication_factor); + + let message = KademliaMessage::get_value_response( + (*key).clone(), + closest_peers, + value, + ); + self.executor.send_message(peer, message.into(), substream); + } + (None, None) => tracing::debug!( + target: LOG_TARGET, + ?peer, + ?message, + "both query and record key missing, unable to handle message", + ), + } + } + } + + Ok(()) + } + + /// Failed to open substream to remote peer. + async fn on_substream_open_failure(&mut self, substream_id: SubstreamId, error: Error) { + tracing::trace!( + target: LOG_TARGET, + ?substream_id, + ?error, + "failed to open substream" + ); + + let Some(peer) = self.pending_substreams.remove(&substream_id) else { + tracing::debug!( + target: LOG_TARGET, + ?substream_id, + "outbound substream failed for non-existent peer" + ); + return; + }; + + if let Some(context) = self.peers.get_mut(&peer) { + let query = match context.pending_actions.remove(&substream_id) { + Some(PeerAction::SendFindNode(query)) => Some(query), + _ => None, + }; + + self.disconnect_peer(peer, query).await; + } + } + + /// Handle dial failure. + fn on_dial_failure(&mut self, peer: PeerId, address: Multiaddr) { + tracing::trace!(target: LOG_TARGET, ?peer, ?address, "failed to dial peer"); + + let Some(actions) = self.pending_dials.remove(&peer) else { + return; + }; + + for action in actions { + if let PeerAction::SendFindNode(query_id) = action { + tracing::trace!( + target: LOG_TARGET, + ?peer, + ?query_id, + ?address, + "report failure for pending query", + ); + + self.engine.register_response_failure(query_id, peer); + } + } + } + + /// Handle next query action. + async fn on_query_action(&mut self, action: QueryAction) -> Result<(), (QueryId, PeerId)> { + match action { + QueryAction::SendMessage { query, peer, .. } => match self.service.open_substream(peer) + { + Err(_) => { + tracing::trace!(target: LOG_TARGET, ?query, ?peer, "dial peer"); + + match self.service.dial(&peer) { + Ok(_) => match self.pending_dials.entry(peer) { + Entry::Occupied(entry) => { + entry.into_mut().push(PeerAction::SendFindNode(query)); + } + Entry::Vacant(entry) => { + entry.insert(vec![PeerAction::SendFindNode(query)]); + } + }, + Err(error) => { + tracing::trace!(target: LOG_TARGET, ?query, ?peer, ?error, "failed to dial peer"); + self.engine.register_response_failure(query, peer); + } + } + + Ok(()) + } + Ok(substream_id) => { + tracing::trace!( + target: LOG_TARGET, + ?query, + ?peer, + ?substream_id, + "open outbound substream for peer" + ); + + self.pending_substreams.insert(substream_id, peer); + self.peers + .entry(peer) + .or_default() + .pending_actions + .insert(substream_id, PeerAction::SendFindNode(query)); + + Ok(()) + } + }, + QueryAction::FindNodeQuerySucceeded { + target, + peers, + query, + } => { + tracing::debug!( + target: LOG_TARGET, + ?query, + peer = ?target, + num_peers = ?peers.len(), + "`FIND_NODE` succeeded", + ); + + let _ = self + .event_tx + .send(KademliaEvent::FindNodeSuccess { + target, + query_id: query, + peers: peers.into_iter().map(|info| (info.peer, info.addresses)).collect(), + }) + .await; + Ok(()) + } + QueryAction::PutRecordToFoundNodes { record, peers } => { + tracing::trace!( + target: LOG_TARGET, + record_key = ?record.key, + num_peers = ?peers.len(), + "store record to found peers", + ); + let key = record.key.clone(); + let message = KademliaMessage::put_value(record); + + for peer in peers { + match self.service.open_substream(peer.peer) { + Ok(substream_id) => { + self.pending_substreams.insert(substream_id, peer.peer); + self.peers + .entry(peer.peer) + .or_default() + .pending_actions + .insert(substream_id, PeerAction::SendPutValue(message.clone())); + } + Err(_) => match self.service.dial(&peer.peer) { + Ok(_) => match self.pending_dials.entry(peer.peer) { + Entry::Occupied(entry) => { + entry + .into_mut() + .push(PeerAction::SendPutValue(message.clone())); + } + Entry::Vacant(entry) => { + entry.insert(vec![PeerAction::SendPutValue(message.clone())]); + } + }, + Err(error) => { + tracing::debug!( + target: LOG_TARGET, + ?peer, + ?key, + ?error, + "failed to dial peer", + ); + } + }, + } + } + + Ok(()) + } + QueryAction::GetRecordQueryDone { query_id, record } => { + self.store.put(record.clone()); + + let _ = + self.event_tx.send(KademliaEvent::GetRecordSuccess { query_id, record }).await; + Ok(()) + } + QueryAction::QueryFailed { query } => { + tracing::debug!(target: LOG_TARGET, ?query, "query failed"); + + let _ = self.event_tx.send(KademliaEvent::QueryFailed { query_id: query }).await; + Ok(()) + } + QueryAction::QuerySucceeded { .. } => unreachable!(), + } + } + + /// [`Kademlia`] event loop. + pub async fn run(mut self) -> crate::Result<()> { + tracing::debug!(target: LOG_TARGET, "starting kademlia event loop"); + + loop { + // poll `QueryEngine` for next actions. + while let Some(action) = self.engine.next_action() { + if let Err((query, peer)) = self.on_query_action(action).await { + self.disconnect_peer(peer, Some(query)).await; + } + } + + tokio::select! { + event = self.service.next() => match event { + Some(TransportEvent::ConnectionEstablished { peer, .. }) => { + if let Err(error) = self.on_connection_established(peer) { + tracing::debug!(target: LOG_TARGET, ?error, "failed to handle established connection"); + } + } + Some(TransportEvent::ConnectionClosed { peer }) => { + self.disconnect_peer(peer, None).await; + } + Some(TransportEvent::SubstreamOpened { peer, direction, substream, .. }) => { + match direction { + Direction::Inbound => self.on_inbound_substream(peer, substream).await, + Direction::Outbound(substream_id) => { + if let Err(error) = self.on_outbound_substream(peer, substream_id, substream).await { + tracing::debug!( + target: LOG_TARGET, + ?peer, + ?substream_id, + ?error, + "failed to handle outbound substream", + ); + } + } + } + }, + Some(TransportEvent::SubstreamOpenFailure { substream, error }) => { + self.on_substream_open_failure(substream, error).await; + } + Some(TransportEvent::DialFailure { peer, address }) => self.on_dial_failure(peer, address), + None => return Err(Error::EssentialTaskClosed), + }, + context = self.executor.next() => { + let QueryContext { peer, query_id, result } = context.unwrap(); + + match result { + QueryResult::SendSuccess { substream } => { + tracing::trace!(target: LOG_TARGET, ?peer, ?query_id, "message sent to peer"); + let _ = substream.close().await; + } + QueryResult::ReadSuccess { substream, message } => { + tracing::trace!(target: LOG_TARGET, ?peer, ?query_id, "message read from peer"); + + if let Err(error) = self.on_message_received(peer, query_id, message, substream).await { + tracing::debug!(target: LOG_TARGET, ?peer, ?error, "failed to process message"); + } + } + QueryResult::SubstreamClosed | QueryResult::Timeout => { + tracing::debug!( + target: LOG_TARGET, + ?peer, + ?query_id, + ?result, + "failed to read message from substream", + ); + + self.disconnect_peer(peer, query_id).await; + } + } + } + command = self.cmd_rx.recv() => { + match command { + Some(KademliaCommand::FindNode { peer, query_id }) => { + tracing::debug!(target: LOG_TARGET, ?peer, ?query_id, "starting `FIND_NODE` query"); + + self.engine.start_find_node( + query_id, + peer, + self.routing_table.closest(Key::from(peer), self.replication_factor).into() + ); + } + Some(KademliaCommand::PutRecord { record, query_id }) => { + tracing::debug!(target: LOG_TARGET, ?query_id, key = ?record.key, "store record to DHT"); + + self.store.put(record.clone()); + let key = Key::new(record.key.clone()); + + self.engine.start_put_record( + query_id, + record, + self.routing_table.closest(key, self.replication_factor).into(), + ); + } + Some(KademliaCommand::GetRecord { key, quorum, query_id }) => { + tracing::debug!(target: LOG_TARGET, ?key, "get record from DHT"); + + match (self.store.get(&key), quorum) { + (Some(record), Quorum::One) => { + let _ = self + .event_tx + .send(KademliaEvent::GetRecordSuccess { query_id, record: record.clone() }) + .await; + } + (record, _) => { + self.engine.start_get_record( + query_id, + key.clone(), + self.routing_table.closest(Key::new(key.clone()), self.replication_factor).into(), + quorum, + if record.is_some() { 1 } else { 0 }, + ); + } + } + + } + Some(KademliaCommand::AddKnownPeer { peer, addresses }) => { + tracing::trace!( + target: LOG_TARGET, + ?peer, + ?addresses, + "add known peer", + ); + + self.routing_table.add_known_peer( + peer, + addresses.clone(), + self.peers + .get(&peer) + .map_or(ConnectionType::NotConnected, |_| ConnectionType::Connected), + ); + self.service.add_known_address(&peer, addresses.into_iter()); + + } + None => return Err(Error::EssentialTaskClosed), + } + }, + } + } + } } #[cfg(test)] mod tests { - use std::collections::HashSet; - - use super::*; - use crate::{ - codec::ProtocolCodec, crypto::ed25519::Keypair, transport::manager::TransportManager, - types::protocol::ProtocolName, BandwidthSink, - }; - use tokio::sync::mpsc::channel; - - #[allow(unused)] - struct Context { - _cmd_tx: Sender, - event_rx: Receiver, - } - - fn _make_kademlia() -> (Kademlia, Context, TransportManager) { - let (manager, handle) = TransportManager::new( - Keypair::generate(), - HashSet::new(), - BandwidthSink::new(), - 8usize, - ); - - let peer = PeerId::random(); - let (transport_service, _tx) = TransportService::new( - peer, - ProtocolName::from("/kad/1"), - Vec::new(), - Default::default(), - handle, - ); - let (event_tx, event_rx) = channel(64); - let (_cmd_tx, cmd_rx) = channel(64); - - let config = Config { - protocol_names: vec![ProtocolName::from("/kad/1")], - known_peers: HashMap::new(), - codec: ProtocolCodec::UnsignedVarint(None), - replication_factor: 20usize, - update_mode: RoutingTableUpdateMode::Automatic, - event_tx, - cmd_rx, - }; - - (Kademlia::new(transport_service, config), Context { _cmd_tx, event_rx }, manager) - } + use std::collections::HashSet; + + use super::*; + use crate::{ + codec::ProtocolCodec, crypto::ed25519::Keypair, transport::manager::TransportManager, + types::protocol::ProtocolName, BandwidthSink, + }; + use tokio::sync::mpsc::channel; + + #[allow(unused)] + struct Context { + _cmd_tx: Sender, + event_rx: Receiver, + } + + fn _make_kademlia() -> (Kademlia, Context, TransportManager) { + let (manager, handle) = TransportManager::new( + Keypair::generate(), + HashSet::new(), + BandwidthSink::new(), + 8usize, + ); + + let peer = PeerId::random(); + let (transport_service, _tx) = TransportService::new( + peer, + ProtocolName::from("/kad/1"), + Vec::new(), + Default::default(), + handle, + ); + let (event_tx, event_rx) = channel(64); + let (_cmd_tx, cmd_rx) = channel(64); + + let config = Config { + protocol_names: vec![ProtocolName::from("/kad/1")], + known_peers: HashMap::new(), + codec: ProtocolCodec::UnsignedVarint(None), + replication_factor: 20usize, + update_mode: RoutingTableUpdateMode::Automatic, + event_tx, + cmd_rx, + }; + + ( + Kademlia::new(transport_service, config), + Context { _cmd_tx, event_rx }, + manager, + ) + } } diff --git a/src/protocol/libp2p/kademlia/query/find_node.rs b/src/protocol/libp2p/kademlia/query/find_node.rs index fb0e00ce..7d0bdb12 100644 --- a/src/protocol/libp2p/kademlia/query/find_node.rs +++ b/src/protocol/libp2p/kademlia/query/find_node.rs @@ -19,12 +19,12 @@ // DEALINGS IN THE SOFTWARE. use crate::{ - protocol::libp2p::kademlia::{ - message::KademliaMessage, - query::{QueryAction, QueryId}, - types::{Distance, KademliaPeer, Key}, - }, - PeerId, + protocol::libp2p::kademlia::{ + message::KademliaMessage, + query::{QueryAction, QueryId}, + types::{Distance, KademliaPeer, Key}, + }, + PeerId, }; use std::collections::{BTreeMap, HashMap, HashSet, VecDeque}; @@ -35,202 +35,202 @@ const LOG_TARGET: &str = "litep2p::ipfs::kademlia::query::find_node"; /// Context for `FIND_NODE` queries. #[derive(Debug)] pub struct FindNodeContext>> { - /// Local peer ID. - local_peer_id: PeerId, + /// Local peer ID. + local_peer_id: PeerId, - /// Query ID. - pub query: QueryId, + /// Query ID. + pub query: QueryId, - /// Target key. - pub target: Key, + /// Target key. + pub target: Key, - /// Peers from whom the `QueryEngine` is waiting to hear a response. - pub pending: HashMap, + /// Peers from whom the `QueryEngine` is waiting to hear a response. + pub pending: HashMap, - /// Queried candidates. - /// - /// These are the peers for whom the query has already been sent - /// and who have either returned their closest peers or failed to answer. - pub queried: HashSet, + /// Queried candidates. + /// + /// These are the peers for whom the query has already been sent + /// and who have either returned their closest peers or failed to answer. + pub queried: HashSet, - /// Candidates. - pub candidates: BTreeMap, + /// Candidates. + pub candidates: BTreeMap, - /// Responses. - pub responses: BTreeMap, + /// Responses. + pub responses: BTreeMap, - /// Replication factor. - pub replication_factor: usize, + /// Replication factor. + pub replication_factor: usize, - /// Parallelism factor. - pub parallelism_factor: usize, + /// Parallelism factor. + pub parallelism_factor: usize, } impl>> FindNodeContext { - /// Create new [`FindNodeContext`]. - pub fn new( - local_peer_id: PeerId, - query: QueryId, - target: Key, - in_peers: VecDeque, - replication_factor: usize, - parallelism_factor: usize, - ) -> Self { - let mut candidates = BTreeMap::new(); - - for candidate in &in_peers { - let distance = target.distance(&candidate.key); - candidates.insert(distance, candidate.clone()); - } - - Self { - query, - target, - candidates, - local_peer_id, - pending: HashMap::new(), - queried: HashSet::new(), - responses: BTreeMap::new(), - replication_factor, - parallelism_factor, - } - } - - /// Register response failure for `peer`. - pub fn register_response_failure(&mut self, peer: PeerId) { - let Some(peer) = self.pending.remove(&peer) else { - tracing::debug!(target: LOG_TARGET, ?peer, "pending peer doesn't exist"); - return; - }; - - self.queried.insert(peer.peer); - } - - /// Register `FIND_NODE` response from `peer`. - pub fn register_response(&mut self, peer: PeerId, peers: Vec) { - let Some(peer) = self.pending.remove(&peer) else { - tracing::warn!(target: LOG_TARGET, ?peer, "received response from peer but didn't expect it"); - debug_assert!(false); - return; - }; - - // calculate distance for `peer` from target and insert it if - // a) the map doesn't have 20 responses - // b) it can replace some other peer that has a higher distance - let distance = self.target.distance(&peer.key); - - // always mark the peer as queried to prevent it getting queried again - self.queried.insert(peer.peer); - - // TODO: could this be written in another way? - // TODO: only insert nodes from whom a response was received - match self.responses.len() < self.replication_factor { - true => { - self.responses.insert(distance, peer); - }, - false => { - let mut entry = self.responses.last_entry().expect("entry to exist"); - if entry.key() > &distance { - entry.insert(peer); - } - }, - } - - // filter already queried peers and extend the set of candidates - for candidate in peers { - if !self.queried.contains(&candidate.peer) && - !self.pending.contains_key(&candidate.peer) - { - if self.local_peer_id == candidate.peer { - continue; - } - - let distance = self.target.distance(&candidate.key); - self.candidates.insert(distance, candidate); - } - } - } - - /// Get next action for `peer`. - pub fn next_peer_action(&mut self, peer: &PeerId) -> Option { - self.pending.contains_key(peer).then_some(QueryAction::SendMessage { - query: self.query, - peer: *peer, - message: KademliaMessage::find_node(self.target.clone().into_preimage()), - }) - } - - /// Schedule next peer for outbound `FIND_NODE` query. - pub fn schedule_next_peer(&mut self) -> QueryAction { - tracing::trace!(target: LOG_TARGET, query = ?self.query, "get next peer"); - - let (_, candidate) = self.candidates.pop_first().expect("entry to exist"); - self.pending.insert(candidate.peer, candidate.clone()); - - QueryAction::SendMessage { - query: self.query, - peer: candidate.peer, - message: KademliaMessage::find_node(self.target.clone().into_preimage()), - } - } - - /// Get next action for a `FIND_NODE` query. - // TODO: refactor this function - pub fn next_action(&mut self) -> Option { - // we didn't receive any responses and there are no candidates or pending queries left. - if self.responses.is_empty() && self.pending.is_empty() && self.candidates.is_empty() { - return Some(QueryAction::QueryFailed { query: self.query }); - } - - // there are still possible peers to query or peers who are being queried - if self.responses.len() < self.replication_factor && - (!self.pending.is_empty() || !self.candidates.is_empty()) - { - if self.pending.len() == self.parallelism_factor || self.candidates.is_empty() { - return None; - } - - return Some(self.schedule_next_peer()); - } - - // query succeeded with one or more results - if self.pending.is_empty() && self.candidates.is_empty() { - return Some(QueryAction::QuerySucceeded { query: self.query }); - } - - // check if any candidate has lower distance thant the current worst - // `expect()` is ok because both `candidates` and `responses` have been confirmed to contain - // entries - if !self.candidates.is_empty() { - let first_candidate_distance = self - .target - .distance(&self.candidates.first_key_value().expect("candidate to exist").1.key); - let worst_response_candidate = - self.responses.last_entry().expect("response to exist").key().clone(); - - if first_candidate_distance < worst_response_candidate && - self.pending.len() < self.parallelism_factor - { - return Some(self.schedule_next_peer()); - } - - return Some(QueryAction::QuerySucceeded { query: self.query }); - } - - if self.responses.len() == self.replication_factor { - return Some(QueryAction::QuerySucceeded { query: self.query }); - } - - tracing::error!( - target: LOG_TARGET, - candidates_len = ?self.candidates.len(), - pending_len = ?self.pending.len(), - responses_len = ?self.responses.len(), - "unhandled state" - ); - - unreachable!(); - } + /// Create new [`FindNodeContext`]. + pub fn new( + local_peer_id: PeerId, + query: QueryId, + target: Key, + in_peers: VecDeque, + replication_factor: usize, + parallelism_factor: usize, + ) -> Self { + let mut candidates = BTreeMap::new(); + + for candidate in &in_peers { + let distance = target.distance(&candidate.key); + candidates.insert(distance, candidate.clone()); + } + + Self { + query, + target, + candidates, + local_peer_id, + pending: HashMap::new(), + queried: HashSet::new(), + responses: BTreeMap::new(), + replication_factor, + parallelism_factor, + } + } + + /// Register response failure for `peer`. + pub fn register_response_failure(&mut self, peer: PeerId) { + let Some(peer) = self.pending.remove(&peer) else { + tracing::debug!(target: LOG_TARGET, ?peer, "pending peer doesn't exist"); + return; + }; + + self.queried.insert(peer.peer); + } + + /// Register `FIND_NODE` response from `peer`. + pub fn register_response(&mut self, peer: PeerId, peers: Vec) { + let Some(peer) = self.pending.remove(&peer) else { + tracing::warn!(target: LOG_TARGET, ?peer, "received response from peer but didn't expect it"); + debug_assert!(false); + return; + }; + + // calculate distance for `peer` from target and insert it if + // a) the map doesn't have 20 responses + // b) it can replace some other peer that has a higher distance + let distance = self.target.distance(&peer.key); + + // always mark the peer as queried to prevent it getting queried again + self.queried.insert(peer.peer); + + // TODO: could this be written in another way? + // TODO: only insert nodes from whom a response was received + match self.responses.len() < self.replication_factor { + true => { + self.responses.insert(distance, peer); + } + false => { + let mut entry = self.responses.last_entry().expect("entry to exist"); + if entry.key() > &distance { + entry.insert(peer); + } + } + } + + // filter already queried peers and extend the set of candidates + for candidate in peers { + if !self.queried.contains(&candidate.peer) + && !self.pending.contains_key(&candidate.peer) + { + if self.local_peer_id == candidate.peer { + continue; + } + + let distance = self.target.distance(&candidate.key); + self.candidates.insert(distance, candidate); + } + } + } + + /// Get next action for `peer`. + pub fn next_peer_action(&mut self, peer: &PeerId) -> Option { + self.pending.contains_key(peer).then_some(QueryAction::SendMessage { + query: self.query, + peer: *peer, + message: KademliaMessage::find_node(self.target.clone().into_preimage()), + }) + } + + /// Schedule next peer for outbound `FIND_NODE` query. + pub fn schedule_next_peer(&mut self) -> QueryAction { + tracing::trace!(target: LOG_TARGET, query = ?self.query, "get next peer"); + + let (_, candidate) = self.candidates.pop_first().expect("entry to exist"); + self.pending.insert(candidate.peer, candidate.clone()); + + QueryAction::SendMessage { + query: self.query, + peer: candidate.peer, + message: KademliaMessage::find_node(self.target.clone().into_preimage()), + } + } + + /// Get next action for a `FIND_NODE` query. + // TODO: refactor this function + pub fn next_action(&mut self) -> Option { + // we didn't receive any responses and there are no candidates or pending queries left. + if self.responses.is_empty() && self.pending.is_empty() && self.candidates.is_empty() { + return Some(QueryAction::QueryFailed { query: self.query }); + } + + // there are still possible peers to query or peers who are being queried + if self.responses.len() < self.replication_factor + && (!self.pending.is_empty() || !self.candidates.is_empty()) + { + if self.pending.len() == self.parallelism_factor || self.candidates.is_empty() { + return None; + } + + return Some(self.schedule_next_peer()); + } + + // query succeeded with one or more results + if self.pending.is_empty() && self.candidates.is_empty() { + return Some(QueryAction::QuerySucceeded { query: self.query }); + } + + // check if any candidate has lower distance thant the current worst + // `expect()` is ok because both `candidates` and `responses` have been confirmed to contain + // entries + if !self.candidates.is_empty() { + let first_candidate_distance = self + .target + .distance(&self.candidates.first_key_value().expect("candidate to exist").1.key); + let worst_response_candidate = + self.responses.last_entry().expect("response to exist").key().clone(); + + if first_candidate_distance < worst_response_candidate + && self.pending.len() < self.parallelism_factor + { + return Some(self.schedule_next_peer()); + } + + return Some(QueryAction::QuerySucceeded { query: self.query }); + } + + if self.responses.len() == self.replication_factor { + return Some(QueryAction::QuerySucceeded { query: self.query }); + } + + tracing::error!( + target: LOG_TARGET, + candidates_len = ?self.candidates.len(), + pending_len = ?self.pending.len(), + responses_len = ?self.responses.len(), + "unhandled state" + ); + + unreachable!(); + } } // TODO: tests diff --git a/src/protocol/libp2p/kademlia/query/get_record.rs b/src/protocol/libp2p/kademlia/query/get_record.rs index deaacc9c..eb9d8235 100644 --- a/src/protocol/libp2p/kademlia/query/get_record.rs +++ b/src/protocol/libp2p/kademlia/query/get_record.rs @@ -21,14 +21,14 @@ #![allow(unused)] use crate::{ - protocol::libp2p::kademlia::{ - message::KademliaMessage, - query::{QueryAction, QueryId}, - record::{Key as RecordKey, Record}, - types::{Distance, KademliaPeer, Key}, - Quorum, - }, - PeerId, + protocol::libp2p::kademlia::{ + message::KademliaMessage, + query::{QueryAction, QueryId}, + record::{Key as RecordKey, Record}, + types::{Distance, KademliaPeer, Key}, + Quorum, + }, + PeerId, }; use std::collections::{BTreeMap, HashMap, HashSet, VecDeque}; @@ -38,199 +38,199 @@ const LOG_TARGET: &str = "litep2p::ipfs::kademlia::query::get_record"; #[derive(Debug)] pub struct GetRecordContext { - /// Local peer ID. - local_peer_id: PeerId, + /// Local peer ID. + local_peer_id: PeerId, - /// How many records have been successfully found. - pub record_count: usize, + /// How many records have been successfully found. + pub record_count: usize, - /// Quorum for the query. - pub quorum: Quorum, + /// Quorum for the query. + pub quorum: Quorum, - /// Query ID. - pub query: QueryId, + /// Query ID. + pub query: QueryId, - /// Target key. - pub target: Key, + /// Target key. + pub target: Key, - /// Peers from whom the `QueryEngine` is waiting to hear a response. - pub pending: HashMap, + /// Peers from whom the `QueryEngine` is waiting to hear a response. + pub pending: HashMap, - /// Queried candidates. - /// - /// These are the peers for whom the query has already been sent - /// and who have either returned their closest peers or failed to answer. - pub queried: HashSet, + /// Queried candidates. + /// + /// These are the peers for whom the query has already been sent + /// and who have either returned their closest peers or failed to answer. + pub queried: HashSet, - /// Candidates. - pub candidates: BTreeMap, + /// Candidates. + pub candidates: BTreeMap, - /// Found records. - pub found_records: Vec, + /// Found records. + pub found_records: Vec, - /// Replication factor. - pub replication_factor: usize, + /// Replication factor. + pub replication_factor: usize, - /// Parallelism factor. - pub parallelism_factor: usize, + /// Parallelism factor. + pub parallelism_factor: usize, } impl GetRecordContext { - /// Create new [`GetRecordContext`]. - pub fn new( - local_peer_id: PeerId, - query: QueryId, - target: Key, - in_peers: VecDeque, - replication_factor: usize, - parallelism_factor: usize, - quorum: Quorum, - record_count: usize, - ) -> Self { - let mut candidates = BTreeMap::new(); - - for candidate in &in_peers { - let distance = target.distance(&candidate.key); - candidates.insert(distance, candidate.clone()); - } - - Self { - query, - target, - quorum, - candidates, - record_count, - local_peer_id, - replication_factor, - parallelism_factor, - pending: HashMap::new(), - queried: HashSet::new(), - found_records: Vec::new(), - } - } - - /// Get the found record. - pub fn found_record(mut self) -> Record { - self.found_records.pop().expect("record to exist since query succeeded") - } - - /// Register response failure for `peer`. - pub fn register_response_failure(&mut self, peer: PeerId) { - let Some(peer) = self.pending.remove(&peer) else { - tracing::trace!(target: LOG_TARGET, ?peer, "pending peer doesn't exist"); - return; - }; - - self.queried.insert(peer.peer); - } - - /// Register `GET_VALUE` response from `peer`. - pub fn register_response( - &mut self, - peer: PeerId, - record: Option, - peers: Vec, - ) { - let Some(peer) = self.pending.remove(&peer) else { - tracing::trace!(target: LOG_TARGET, ?peer, "received response from peer but didn't expect it"); - return; - }; - - // TODO: validate record - if let Some(record) = record { - self.found_records.push(record); - } - - // add the queried peer to `queried` and all new peers which haven't been - // queried to `candidates` - self.queried.insert(peer.peer); - - for candidate in peers { - if !self.queried.contains(&candidate.peer) && - !self.pending.contains_key(&candidate.peer) - { - if self.local_peer_id == candidate.peer { - continue; - } - - let distance = self.target.distance(&candidate.key); - self.candidates.insert(distance, candidate); - } - } - } - - /// Get next action for `peer`. - // TODO: remove this and store the next action to `PeerAction` - pub fn next_peer_action(&mut self, peer: &PeerId) -> Option { - self.pending.contains_key(peer).then_some(QueryAction::SendMessage { - query: self.query, - peer: *peer, - message: KademliaMessage::get_record(self.target.clone().into_preimage()), - }) - } - - /// Schedule next peer for outbound `GET_VALUE` query. - pub fn schedule_next_peer(&mut self) -> QueryAction { - tracing::trace!(target: LOG_TARGET, query = ?self.query, "get next peer"); - - let (_, candidate) = self.candidates.pop_first().expect("entry to exist"); - let peer = candidate.peer; - - tracing::trace!(target: LOG_TARGET, ?peer, "current candidate"); - self.pending.insert(candidate.peer, candidate); - - QueryAction::SendMessage { - query: self.query, - peer, - message: KademliaMessage::get_record(self.target.clone().into_preimage()), - } - } - - /// Get next action for a `GET_VALUE` query. - pub fn next_action(&mut self) -> Option { - // if there are no more peers to query, check if the query succeeded or failed - // the status is determined by whether a record was found - if self.pending.is_empty() && self.candidates.is_empty() { - match self.record_count + self.found_records.len() { - 0 => return Some(QueryAction::QueryFailed { query: self.query }), - _ => return Some(QueryAction::QuerySucceeded { query: self.query }), - } - } - - // check if enough records have been found - let continue_search = match self.quorum { - Quorum::All => (self.record_count + self.found_records.len() < self.replication_factor), - Quorum::One => (self.record_count + self.found_records.len() < 1), - Quorum::N(num_responses) => - (self.record_count + self.found_records.len() < num_responses.into()), - }; - - // if enough replicas for the record have been received (defined by the quorum size), - /// mark the query as succeeded - if !continue_search { - return Some(QueryAction::QuerySucceeded { query: self.query }); - } - - // if the search must continue, try to schedule next outbound message if possible - if !self.pending.is_empty() || !self.candidates.is_empty() { - if self.pending.len() == self.parallelism_factor || self.candidates.is_empty() { - return None; - } - - return Some(self.schedule_next_peer()); - } - - // TODO: probably not correct - tracing::warn!( - target: LOG_TARGET, - num_pending = ?self.pending.len(), - num_candidates = ?self.candidates.len(), - num_records = ?(self.record_count + self.found_records.len()), - quorum = ?self.quorum, - ?continue_search, - "unreachable condition for `GET_VALUE` search" - ); - - unreachable!(); - } + /// Create new [`GetRecordContext`]. + pub fn new( + local_peer_id: PeerId, + query: QueryId, + target: Key, + in_peers: VecDeque, + replication_factor: usize, + parallelism_factor: usize, + quorum: Quorum, + record_count: usize, + ) -> Self { + let mut candidates = BTreeMap::new(); + + for candidate in &in_peers { + let distance = target.distance(&candidate.key); + candidates.insert(distance, candidate.clone()); + } + + Self { + query, + target, + quorum, + candidates, + record_count, + local_peer_id, + replication_factor, + parallelism_factor, + pending: HashMap::new(), + queried: HashSet::new(), + found_records: Vec::new(), + } + } + + /// Get the found record. + pub fn found_record(mut self) -> Record { + self.found_records.pop().expect("record to exist since query succeeded") + } + + /// Register response failure for `peer`. + pub fn register_response_failure(&mut self, peer: PeerId) { + let Some(peer) = self.pending.remove(&peer) else { + tracing::trace!(target: LOG_TARGET, ?peer, "pending peer doesn't exist"); + return; + }; + + self.queried.insert(peer.peer); + } + + /// Register `GET_VALUE` response from `peer`. + pub fn register_response( + &mut self, + peer: PeerId, + record: Option, + peers: Vec, + ) { + let Some(peer) = self.pending.remove(&peer) else { + tracing::trace!(target: LOG_TARGET, ?peer, "received response from peer but didn't expect it"); + return; + }; + + // TODO: validate record + if let Some(record) = record { + self.found_records.push(record); + } + + // add the queried peer to `queried` and all new peers which haven't been + // queried to `candidates` + self.queried.insert(peer.peer); + + for candidate in peers { + if !self.queried.contains(&candidate.peer) + && !self.pending.contains_key(&candidate.peer) + { + if self.local_peer_id == candidate.peer { + continue; + } + + let distance = self.target.distance(&candidate.key); + self.candidates.insert(distance, candidate); + } + } + } + + /// Get next action for `peer`. + // TODO: remove this and store the next action to `PeerAction` + pub fn next_peer_action(&mut self, peer: &PeerId) -> Option { + self.pending.contains_key(peer).then_some(QueryAction::SendMessage { + query: self.query, + peer: *peer, + message: KademliaMessage::get_record(self.target.clone().into_preimage()), + }) + } + + /// Schedule next peer for outbound `GET_VALUE` query. + pub fn schedule_next_peer(&mut self) -> QueryAction { + tracing::trace!(target: LOG_TARGET, query = ?self.query, "get next peer"); + + let (_, candidate) = self.candidates.pop_first().expect("entry to exist"); + let peer = candidate.peer; + + tracing::trace!(target: LOG_TARGET, ?peer, "current candidate"); + self.pending.insert(candidate.peer, candidate); + + QueryAction::SendMessage { + query: self.query, + peer, + message: KademliaMessage::get_record(self.target.clone().into_preimage()), + } + } + + /// Get next action for a `GET_VALUE` query. + pub fn next_action(&mut self) -> Option { + // if there are no more peers to query, check if the query succeeded or failed + // the status is determined by whether a record was found + if self.pending.is_empty() && self.candidates.is_empty() { + match self.record_count + self.found_records.len() { + 0 => return Some(QueryAction::QueryFailed { query: self.query }), + _ => return Some(QueryAction::QuerySucceeded { query: self.query }), + } + } + + // check if enough records have been found + let continue_search = match self.quorum { + Quorum::All => (self.record_count + self.found_records.len() < self.replication_factor), + Quorum::One => (self.record_count + self.found_records.len() < 1), + Quorum::N(num_responses) => + (self.record_count + self.found_records.len() < num_responses.into()), + }; + + // if enough replicas for the record have been received (defined by the quorum size), + /// mark the query as succeeded + if !continue_search { + return Some(QueryAction::QuerySucceeded { query: self.query }); + } + + // if the search must continue, try to schedule next outbound message if possible + if !self.pending.is_empty() || !self.candidates.is_empty() { + if self.pending.len() == self.parallelism_factor || self.candidates.is_empty() { + return None; + } + + return Some(self.schedule_next_peer()); + } + + // TODO: probably not correct + tracing::warn!( + target: LOG_TARGET, + num_pending = ?self.pending.len(), + num_candidates = ?self.candidates.len(), + num_records = ?(self.record_count + self.found_records.len()), + quorum = ?self.quorum, + ?continue_search, + "unreachable condition for `GET_VALUE` search" + ); + + unreachable!(); + } } diff --git a/src/protocol/libp2p/kademlia/query/mod.rs b/src/protocol/libp2p/kademlia/query/mod.rs index dfb042b9..690ead05 100644 --- a/src/protocol/libp2p/kademlia/query/mod.rs +++ b/src/protocol/libp2p/kademlia/query/mod.rs @@ -19,14 +19,14 @@ // DEALINGS IN THE SOFTWARE. use crate::{ - protocol::libp2p::kademlia::{ - message::KademliaMessage, - query::{find_node::FindNodeContext, get_record::GetRecordContext}, - record::{Key as RecordKey, Record}, - types::{KademliaPeer, Key}, - Quorum, - }, - PeerId, + protocol::libp2p::kademlia::{ + message::KademliaMessage, + query::{find_node::FindNodeContext, get_record::GetRecordContext}, + record::{Key as RecordKey, Record}, + types::{KademliaPeer, Key}, + Quorum, + }, + PeerId, }; use bytes::Bytes; @@ -48,599 +48,618 @@ pub struct QueryId(pub usize); /// Query type. #[derive(Debug)] enum QueryType { - /// `FIND_NODE` query. - FindNode { - /// Context for the `FIND_NODE` query - context: FindNodeContext, - }, - - /// `PUT_VALUE` query. - PutRecord { - /// Record that needs to be stored. - record: Record, - - /// Context for the `FIND_NODE` query - context: FindNodeContext, - }, - - /// `GET_VALUE` query. - GetRecord { - /// Context for the `GET_VALUE` query. - context: GetRecordContext, - }, + /// `FIND_NODE` query. + FindNode { + /// Context for the `FIND_NODE` query + context: FindNodeContext, + }, + + /// `PUT_VALUE` query. + PutRecord { + /// Record that needs to be stored. + record: Record, + + /// Context for the `FIND_NODE` query + context: FindNodeContext, + }, + + /// `GET_VALUE` query. + GetRecord { + /// Context for the `GET_VALUE` query. + context: GetRecordContext, + }, } /// Query action. #[derive(Debug)] pub enum QueryAction { - /// Send message to peer. - SendMessage { - /// Query ID. - query: QueryId, - - /// Peer. - peer: PeerId, - - /// Message. - message: Bytes, - }, - - /// `FIND_NODE` query succeeded. - FindNodeQuerySucceeded { - /// ID of the query that succeeded. - query: QueryId, - - /// Target peer. - target: PeerId, - - /// Peers that were found. - peers: Vec, - }, - - /// Store the record to nodest closest to target key. - // TODO: horrible name - PutRecordToFoundNodes { - /// Target peer. - record: Record, - - /// Peers for whom the `PUT_VALUE` must be sent to. - peers: Vec, - }, - - /// `GET_VALUE` query succeeded. - GetRecordQueryDone { - /// Query ID. - query_id: QueryId, - - /// Found record. - record: Record, - }, - - // TODO: remove - /// Query succeeded. - QuerySucceeded { - /// ID of the query that succeeded. - query: QueryId, - }, - - /// Query failed. - QueryFailed { - /// ID of the query that failed. - query: QueryId, - }, + /// Send message to peer. + SendMessage { + /// Query ID. + query: QueryId, + + /// Peer. + peer: PeerId, + + /// Message. + message: Bytes, + }, + + /// `FIND_NODE` query succeeded. + FindNodeQuerySucceeded { + /// ID of the query that succeeded. + query: QueryId, + + /// Target peer. + target: PeerId, + + /// Peers that were found. + peers: Vec, + }, + + /// Store the record to nodest closest to target key. + // TODO: horrible name + PutRecordToFoundNodes { + /// Target peer. + record: Record, + + /// Peers for whom the `PUT_VALUE` must be sent to. + peers: Vec, + }, + + /// `GET_VALUE` query succeeded. + GetRecordQueryDone { + /// Query ID. + query_id: QueryId, + + /// Found record. + record: Record, + }, + + // TODO: remove + /// Query succeeded. + QuerySucceeded { + /// ID of the query that succeeded. + query: QueryId, + }, + + /// Query failed. + QueryFailed { + /// ID of the query that failed. + query: QueryId, + }, } /// Kademlia query engine. pub struct QueryEngine { - /// Local peer ID. - local_peer_id: PeerId, + /// Local peer ID. + local_peer_id: PeerId, - /// Replication factor. - replication_factor: usize, + /// Replication factor. + replication_factor: usize, - /// Parallelism factor. - parallelism_factor: usize, + /// Parallelism factor. + parallelism_factor: usize, - /// Active queries. - queries: HashMap, + /// Active queries. + queries: HashMap, } impl QueryEngine { - /// Create new [`QueryEngine`]. - pub fn new( - local_peer_id: PeerId, - replication_factor: usize, - parallelism_factor: usize, - ) -> Self { - Self { local_peer_id, replication_factor, parallelism_factor, queries: HashMap::new() } - } - - /// Start `FIND_NODE` query. - pub fn start_find_node( - &mut self, - query_id: QueryId, - target: PeerId, - candidates: VecDeque, - ) -> QueryId { - tracing::debug!( - target: LOG_TARGET, - ?query_id, - ?target, - num_peers = ?candidates.len(), - "start `FIND_NODE` query" - ); - - self.queries.insert( - query_id, - QueryType::FindNode { - context: FindNodeContext::new( - self.local_peer_id, - query_id, - Key::from(target), - candidates, - self.replication_factor, - self.parallelism_factor, - ), - }, - ); - - query_id - } - - /// Start `PUT_VALUE` query. - pub fn start_put_record( - &mut self, - query_id: QueryId, - record: Record, - candidates: VecDeque, - ) -> QueryId { - tracing::debug!( - target: LOG_TARGET, - ?query_id, - target = ?record.key, - num_peers = ?candidates.len(), - "start `PUT_VALUE` query" - ); - - let target = Key::new(record.key.clone()); - - self.queries.insert( - query_id, - QueryType::PutRecord { - record, - context: FindNodeContext::new( - self.local_peer_id, - query_id, - target, - candidates, - self.replication_factor, - self.parallelism_factor, - ), - }, - ); - - query_id - } - - /// Start `GET_VALUE` query. - pub fn start_get_record( - &mut self, - query_id: QueryId, - target: RecordKey, - candidates: VecDeque, - quorum: Quorum, - count: usize, - ) -> QueryId { - tracing::debug!( - target: LOG_TARGET, - ?query_id, - ?target, - num_peers = ?candidates.len(), - "start `GET_VALUE` query" - ); - - let target = Key::new(target); - - self.queries.insert( - query_id, - QueryType::GetRecord { - context: GetRecordContext::new( - self.local_peer_id, - query_id, - target, - candidates, - self.replication_factor, - self.parallelism_factor, - quorum, - count, - ), - }, - ); - - query_id - } - - /// Register response failure from a queried peer. - pub fn register_response_failure(&mut self, query: QueryId, peer: PeerId) { - tracing::trace!(target: LOG_TARGET, ?query, ?peer, "register response failure"); - - match self.queries.get_mut(&query) { - None => { - tracing::trace!(target: LOG_TARGET, ?query, ?peer, "response failure for a stale query"); - return; - }, - Some(QueryType::FindNode { context }) => { - context.register_response_failure(peer); - }, - Some(QueryType::PutRecord { context, .. }) => { - context.register_response_failure(peer); - }, - Some(QueryType::GetRecord { context }) => { - context.register_response_failure(peer); - }, - } - } - - /// Register that `response` received from `peer`. - pub fn register_response(&mut self, query: QueryId, peer: PeerId, message: KademliaMessage) { - tracing::trace!(target: LOG_TARGET, ?query, ?peer, "register response"); - - match self.queries.get_mut(&query) { - None => { - tracing::trace!(target: LOG_TARGET, ?query, ?peer, "response failure for a stale query"); - return; - }, - Some(QueryType::FindNode { context }) => match message { - KademliaMessage::FindNode { peers, .. } => { - context.register_response(peer, peers); - }, - _ => unreachable!(), - }, - Some(QueryType::PutRecord { context, .. }) => match message { - KademliaMessage::FindNode { peers, .. } => { - context.register_response(peer, peers); - }, - _ => unreachable!(), - }, - Some(QueryType::GetRecord { context }) => match message { - KademliaMessage::GetRecord { record, peers, .. } => { - context.register_response(peer, record, peers); - }, - _ => unreachable!(), - }, - } - } - - /// Get next action for `peer` from the [`QueryEngine`]. - pub fn next_peer_action(&mut self, query: &QueryId, peer: &PeerId) -> Option { - tracing::trace!(target: LOG_TARGET, ?query, ?peer, "get next peer action"); - - match self.queries.get_mut(query) { - None => { - tracing::trace!(target: LOG_TARGET, ?query, ?peer, "response failure for a stale query"); - return None; - }, - Some(QueryType::FindNode { context }) => return context.next_peer_action(peer), - Some(QueryType::PutRecord { context, .. }) => return context.next_peer_action(peer), - Some(QueryType::GetRecord { context }) => return context.next_peer_action(peer), - } - } - - /// Handle query success by returning the queried value(s) - /// and removing the query from [`QueryEngine`]. - fn on_query_succeeded(&mut self, query: QueryId) -> QueryAction { - match self.queries.remove(&query).expect("query to exist") { - QueryType::FindNode { context } => QueryAction::FindNodeQuerySucceeded { - query, - target: context.target.into_preimage(), - peers: context.responses.into_iter().map(|(_, peer)| peer).collect::>(), - }, - QueryType::PutRecord { record, context } => QueryAction::PutRecordToFoundNodes { - record, - peers: context.responses.into_iter().map(|(_, peer)| peer).collect::>(), - }, - QueryType::GetRecord { context } => QueryAction::GetRecordQueryDone { - query_id: context.query, - record: context.found_record(), - }, - } - } - - /// Handle query failure by removing the query from [`QueryEngine`] and - /// returning the appropriate [`QueryAction`] to user. - fn on_query_failed(&mut self, query: QueryId) -> QueryAction { - let _ = self.queries.remove(&query).expect("query to exist"); - - QueryAction::QueryFailed { query } - } - - /// Get next action from the [`QueryEngine`]. - pub fn next_action(&mut self) -> Option { - for (_, state) in self.queries.iter_mut() { - let action = match state { - QueryType::FindNode { context } => context.next_action(), - QueryType::PutRecord { context, .. } => context.next_action(), - QueryType::GetRecord { context } => context.next_action(), - }; - - match action { - Some(QueryAction::QuerySucceeded { query }) => { - return Some(self.on_query_succeeded(query)); - }, - Some(QueryAction::QueryFailed { query }) => - return Some(self.on_query_failed(query)), - Some(_) => return action, - _ => continue, - } - } - - None - } + /// Create new [`QueryEngine`]. + pub fn new( + local_peer_id: PeerId, + replication_factor: usize, + parallelism_factor: usize, + ) -> Self { + Self { + local_peer_id, + replication_factor, + parallelism_factor, + queries: HashMap::new(), + } + } + + /// Start `FIND_NODE` query. + pub fn start_find_node( + &mut self, + query_id: QueryId, + target: PeerId, + candidates: VecDeque, + ) -> QueryId { + tracing::debug!( + target: LOG_TARGET, + ?query_id, + ?target, + num_peers = ?candidates.len(), + "start `FIND_NODE` query" + ); + + self.queries.insert( + query_id, + QueryType::FindNode { + context: FindNodeContext::new( + self.local_peer_id, + query_id, + Key::from(target), + candidates, + self.replication_factor, + self.parallelism_factor, + ), + }, + ); + + query_id + } + + /// Start `PUT_VALUE` query. + pub fn start_put_record( + &mut self, + query_id: QueryId, + record: Record, + candidates: VecDeque, + ) -> QueryId { + tracing::debug!( + target: LOG_TARGET, + ?query_id, + target = ?record.key, + num_peers = ?candidates.len(), + "start `PUT_VALUE` query" + ); + + let target = Key::new(record.key.clone()); + + self.queries.insert( + query_id, + QueryType::PutRecord { + record, + context: FindNodeContext::new( + self.local_peer_id, + query_id, + target, + candidates, + self.replication_factor, + self.parallelism_factor, + ), + }, + ); + + query_id + } + + /// Start `GET_VALUE` query. + pub fn start_get_record( + &mut self, + query_id: QueryId, + target: RecordKey, + candidates: VecDeque, + quorum: Quorum, + count: usize, + ) -> QueryId { + tracing::debug!( + target: LOG_TARGET, + ?query_id, + ?target, + num_peers = ?candidates.len(), + "start `GET_VALUE` query" + ); + + let target = Key::new(target); + + self.queries.insert( + query_id, + QueryType::GetRecord { + context: GetRecordContext::new( + self.local_peer_id, + query_id, + target, + candidates, + self.replication_factor, + self.parallelism_factor, + quorum, + count, + ), + }, + ); + + query_id + } + + /// Register response failure from a queried peer. + pub fn register_response_failure(&mut self, query: QueryId, peer: PeerId) { + tracing::trace!(target: LOG_TARGET, ?query, ?peer, "register response failure"); + + match self.queries.get_mut(&query) { + None => { + tracing::trace!(target: LOG_TARGET, ?query, ?peer, "response failure for a stale query"); + return; + } + Some(QueryType::FindNode { context }) => { + context.register_response_failure(peer); + } + Some(QueryType::PutRecord { context, .. }) => { + context.register_response_failure(peer); + } + Some(QueryType::GetRecord { context }) => { + context.register_response_failure(peer); + } + } + } + + /// Register that `response` received from `peer`. + pub fn register_response(&mut self, query: QueryId, peer: PeerId, message: KademliaMessage) { + tracing::trace!(target: LOG_TARGET, ?query, ?peer, "register response"); + + match self.queries.get_mut(&query) { + None => { + tracing::trace!(target: LOG_TARGET, ?query, ?peer, "response failure for a stale query"); + return; + } + Some(QueryType::FindNode { context }) => match message { + KademliaMessage::FindNode { peers, .. } => { + context.register_response(peer, peers); + } + _ => unreachable!(), + }, + Some(QueryType::PutRecord { context, .. }) => match message { + KademliaMessage::FindNode { peers, .. } => { + context.register_response(peer, peers); + } + _ => unreachable!(), + }, + Some(QueryType::GetRecord { context }) => match message { + KademliaMessage::GetRecord { record, peers, .. } => { + context.register_response(peer, record, peers); + } + _ => unreachable!(), + }, + } + } + + /// Get next action for `peer` from the [`QueryEngine`]. + pub fn next_peer_action(&mut self, query: &QueryId, peer: &PeerId) -> Option { + tracing::trace!(target: LOG_TARGET, ?query, ?peer, "get next peer action"); + + match self.queries.get_mut(query) { + None => { + tracing::trace!(target: LOG_TARGET, ?query, ?peer, "response failure for a stale query"); + return None; + } + Some(QueryType::FindNode { context }) => return context.next_peer_action(peer), + Some(QueryType::PutRecord { context, .. }) => return context.next_peer_action(peer), + Some(QueryType::GetRecord { context }) => return context.next_peer_action(peer), + } + } + + /// Handle query success by returning the queried value(s) + /// and removing the query from [`QueryEngine`]. + fn on_query_succeeded(&mut self, query: QueryId) -> QueryAction { + match self.queries.remove(&query).expect("query to exist") { + QueryType::FindNode { context } => QueryAction::FindNodeQuerySucceeded { + query, + target: context.target.into_preimage(), + peers: context.responses.into_iter().map(|(_, peer)| peer).collect::>(), + }, + QueryType::PutRecord { record, context } => QueryAction::PutRecordToFoundNodes { + record, + peers: context.responses.into_iter().map(|(_, peer)| peer).collect::>(), + }, + QueryType::GetRecord { context } => QueryAction::GetRecordQueryDone { + query_id: context.query, + record: context.found_record(), + }, + } + } + + /// Handle query failure by removing the query from [`QueryEngine`] and + /// returning the appropriate [`QueryAction`] to user. + fn on_query_failed(&mut self, query: QueryId) -> QueryAction { + let _ = self.queries.remove(&query).expect("query to exist"); + + QueryAction::QueryFailed { query } + } + + /// Get next action from the [`QueryEngine`]. + pub fn next_action(&mut self) -> Option { + for (_, state) in self.queries.iter_mut() { + let action = match state { + QueryType::FindNode { context } => context.next_action(), + QueryType::PutRecord { context, .. } => context.next_action(), + QueryType::GetRecord { context } => context.next_action(), + }; + + match action { + Some(QueryAction::QuerySucceeded { query }) => { + return Some(self.on_query_succeeded(query)); + } + Some(QueryAction::QueryFailed { query }) => + return Some(self.on_query_failed(query)), + Some(_) => return action, + _ => continue, + } + } + + None + } } #[cfg(test)] mod tests { - use multihash::{Code, Multihash}; - - use super::*; - use crate::protocol::libp2p::kademlia::types::ConnectionType; - - // make fixed peer id - fn make_peer_id(first: u8, second: u8) -> PeerId { - let mut peer_id = vec![0u8; 32]; - peer_id[0] = first; - peer_id[1] = second; - - PeerId::from_bytes( - &Multihash::wrap(Code::Identity.into(), &peer_id) - .expect("The digest size is never too large") - .to_bytes(), - ) - .unwrap() - } - - #[test] - fn query_fails() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let mut engine = QueryEngine::new(PeerId::random(), 20usize, 3usize); - let target_peer = PeerId::random(); - let _target_key = Key::from(target_peer); - - let query = engine.start_find_node( - QueryId(1337), - target_peer, - vec![ - KademliaPeer::new(PeerId::random(), vec![], ConnectionType::NotConnected), - KademliaPeer::new(PeerId::random(), vec![], ConnectionType::NotConnected), - KademliaPeer::new(PeerId::random(), vec![], ConnectionType::NotConnected), - KademliaPeer::new(PeerId::random(), vec![], ConnectionType::NotConnected), - ] - .into(), - ); - - for _ in 0..4 { - if let Some(QueryAction::SendMessage { query, peer, .. }) = engine.next_action() { - engine.register_response_failure(query, peer); - } - } - - if let Some(QueryAction::QueryFailed { query: failed }) = engine.next_action() { - assert_eq!(failed, query); - } - - assert!(engine.next_action().is_none()); - } - - #[test] - fn lookup_paused() { - let mut engine = QueryEngine::new(PeerId::random(), 20usize, 3usize); - let target_peer = PeerId::random(); - let _target_key = Key::from(target_peer); - - let _ = engine.start_find_node( - QueryId(1338), - target_peer, - vec![ - KademliaPeer::new(PeerId::random(), vec![], ConnectionType::NotConnected), - KademliaPeer::new(PeerId::random(), vec![], ConnectionType::NotConnected), - KademliaPeer::new(PeerId::random(), vec![], ConnectionType::NotConnected), - KademliaPeer::new(PeerId::random(), vec![], ConnectionType::NotConnected), - ] - .into(), - ); - - for _ in 0..3 { - let _ = engine.next_action(); - } - - assert!(engine.next_action().is_none()); - } - - #[test] - fn find_node_query_succeeds() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let mut engine = QueryEngine::new(PeerId::random(), 20usize, 3usize); - let target_peer = make_peer_id(0, 0); - let target_key = Key::from(target_peer); - - let distances = { - let mut distances = std::collections::BTreeMap::new(); - - for i in 1..64 { - let peer = make_peer_id(i, 0); - let key = Key::from(peer); - - distances.insert(target_key.distance(&key), peer); - } - - distances - }; - let mut iter = distances.iter(); - - // start find node with one known peer - let _query = engine.start_find_node( - QueryId(1339), - target_peer, - vec![KademliaPeer::new(*iter.next().unwrap().1, vec![], ConnectionType::NotConnected)] - .into(), - ); - - let action = engine.next_action(); - assert!(engine.next_action().is_none()); - - // the one known peer responds with 3 other peers it knows - match action { - Some(QueryAction::SendMessage { query, peer, .. }) => { - engine.register_response( - query, - peer, - KademliaMessage::FindNode { - target: Vec::new(), - peers: vec![ - KademliaPeer::new( - *iter.next().unwrap().1, - vec![], - ConnectionType::NotConnected, - ), - KademliaPeer::new( - *iter.next().unwrap().1, - vec![], - ConnectionType::NotConnected, - ), - KademliaPeer::new( - *iter.next().unwrap().1, - vec![], - ConnectionType::NotConnected, - ), - ], - }, - ); - }, - _ => panic!("invalid event received"), - } - - // send empty response for the last three nodes - for _ in 0..3 { - match engine.next_action() { - Some(QueryAction::SendMessage { query, peer, .. }) => { - println!("next send message to {peer:?}"); - engine.register_response( - query, - peer, - KademliaMessage::FindNode { target: Vec::new(), peers: vec![] }, - ); - }, - _ => panic!("invalid event received"), - } - } - - match engine.next_action() { - Some(QueryAction::FindNodeQuerySucceeded { peers, .. }) => { - assert_eq!(peers.len(), 4); - }, - _ => panic!("invalid event received"), - } - - assert!(engine.next_action().is_none()); - } - - #[test] - fn put_record_succeeds() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let mut engine = QueryEngine::new(PeerId::random(), 20usize, 3usize); - let record_key = RecordKey::new(&vec![1, 2, 3, 4]); - let target_key = Key::new(record_key.clone()); - let original_record = Record::new(record_key, vec![1, 3, 3, 7, 1, 3, 3, 8]); - - let distances = { - let mut distances = std::collections::BTreeMap::new(); - - for i in 1..64 { - let peer = make_peer_id(i, 0); - let key = Key::from(peer); - - distances.insert(target_key.distance(&key), peer); - } - - distances - }; - let mut iter = distances.iter(); - - // start find node with one known peer - let _query = engine.start_put_record( - QueryId(1340), - original_record.clone(), - vec![KademliaPeer::new(*iter.next().unwrap().1, vec![], ConnectionType::NotConnected)] - .into(), - ); - - let action = engine.next_action(); - assert!(engine.next_action().is_none()); - - // the one known peer responds with 3 other peers it knows - match action { - Some(QueryAction::SendMessage { query, peer, .. }) => { - engine.register_response( - query, - peer, - KademliaMessage::FindNode { - target: Vec::new(), - peers: vec![ - KademliaPeer::new( - *iter.next().unwrap().1, - vec![], - ConnectionType::NotConnected, - ), - KademliaPeer::new( - *iter.next().unwrap().1, - vec![], - ConnectionType::NotConnected, - ), - KademliaPeer::new( - *iter.next().unwrap().1, - vec![], - ConnectionType::NotConnected, - ), - ], - }, - ); - }, - _ => panic!("invalid event received"), - } - - // send empty response for the last three nodes - for _ in 0..3 { - match engine.next_action() { - Some(QueryAction::SendMessage { query, peer, .. }) => { - println!("next send message to {peer:?}"); - engine.register_response( - query, - peer, - KademliaMessage::FindNode { target: Vec::new(), peers: vec![] }, - ); - }, - _ => panic!("invalid event received"), - } - } - - match engine.next_action() { - Some(QueryAction::PutRecordToFoundNodes { peers, record }) => { - assert_eq!(peers.len(), 4); - assert_eq!(record.key, original_record.key); - assert_eq!(record.value, original_record.value); - }, - _ => panic!("invalid event received"), - } - - assert!(engine.next_action().is_none()); - } + use multihash::{Code, Multihash}; + + use super::*; + use crate::protocol::libp2p::kademlia::types::ConnectionType; + + // make fixed peer id + fn make_peer_id(first: u8, second: u8) -> PeerId { + let mut peer_id = vec![0u8; 32]; + peer_id[0] = first; + peer_id[1] = second; + + PeerId::from_bytes( + &Multihash::wrap(Code::Identity.into(), &peer_id) + .expect("The digest size is never too large") + .to_bytes(), + ) + .unwrap() + } + + #[test] + fn query_fails() { + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let mut engine = QueryEngine::new(PeerId::random(), 20usize, 3usize); + let target_peer = PeerId::random(); + let _target_key = Key::from(target_peer); + + let query = engine.start_find_node( + QueryId(1337), + target_peer, + vec![ + KademliaPeer::new(PeerId::random(), vec![], ConnectionType::NotConnected), + KademliaPeer::new(PeerId::random(), vec![], ConnectionType::NotConnected), + KademliaPeer::new(PeerId::random(), vec![], ConnectionType::NotConnected), + KademliaPeer::new(PeerId::random(), vec![], ConnectionType::NotConnected), + ] + .into(), + ); + + for _ in 0..4 { + if let Some(QueryAction::SendMessage { query, peer, .. }) = engine.next_action() { + engine.register_response_failure(query, peer); + } + } + + if let Some(QueryAction::QueryFailed { query: failed }) = engine.next_action() { + assert_eq!(failed, query); + } + + assert!(engine.next_action().is_none()); + } + + #[test] + fn lookup_paused() { + let mut engine = QueryEngine::new(PeerId::random(), 20usize, 3usize); + let target_peer = PeerId::random(); + let _target_key = Key::from(target_peer); + + let _ = engine.start_find_node( + QueryId(1338), + target_peer, + vec![ + KademliaPeer::new(PeerId::random(), vec![], ConnectionType::NotConnected), + KademliaPeer::new(PeerId::random(), vec![], ConnectionType::NotConnected), + KademliaPeer::new(PeerId::random(), vec![], ConnectionType::NotConnected), + KademliaPeer::new(PeerId::random(), vec![], ConnectionType::NotConnected), + ] + .into(), + ); + + for _ in 0..3 { + let _ = engine.next_action(); + } + + assert!(engine.next_action().is_none()); + } + + #[test] + fn find_node_query_succeeds() { + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let mut engine = QueryEngine::new(PeerId::random(), 20usize, 3usize); + let target_peer = make_peer_id(0, 0); + let target_key = Key::from(target_peer); + + let distances = { + let mut distances = std::collections::BTreeMap::new(); + + for i in 1..64 { + let peer = make_peer_id(i, 0); + let key = Key::from(peer); + + distances.insert(target_key.distance(&key), peer); + } + + distances + }; + let mut iter = distances.iter(); + + // start find node with one known peer + let _query = engine.start_find_node( + QueryId(1339), + target_peer, + vec![KademliaPeer::new( + *iter.next().unwrap().1, + vec![], + ConnectionType::NotConnected, + )] + .into(), + ); + + let action = engine.next_action(); + assert!(engine.next_action().is_none()); + + // the one known peer responds with 3 other peers it knows + match action { + Some(QueryAction::SendMessage { query, peer, .. }) => { + engine.register_response( + query, + peer, + KademliaMessage::FindNode { + target: Vec::new(), + peers: vec![ + KademliaPeer::new( + *iter.next().unwrap().1, + vec![], + ConnectionType::NotConnected, + ), + KademliaPeer::new( + *iter.next().unwrap().1, + vec![], + ConnectionType::NotConnected, + ), + KademliaPeer::new( + *iter.next().unwrap().1, + vec![], + ConnectionType::NotConnected, + ), + ], + }, + ); + } + _ => panic!("invalid event received"), + } + + // send empty response for the last three nodes + for _ in 0..3 { + match engine.next_action() { + Some(QueryAction::SendMessage { query, peer, .. }) => { + println!("next send message to {peer:?}"); + engine.register_response( + query, + peer, + KademliaMessage::FindNode { + target: Vec::new(), + peers: vec![], + }, + ); + } + _ => panic!("invalid event received"), + } + } + + match engine.next_action() { + Some(QueryAction::FindNodeQuerySucceeded { peers, .. }) => { + assert_eq!(peers.len(), 4); + } + _ => panic!("invalid event received"), + } + + assert!(engine.next_action().is_none()); + } + + #[test] + fn put_record_succeeds() { + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let mut engine = QueryEngine::new(PeerId::random(), 20usize, 3usize); + let record_key = RecordKey::new(&vec![1, 2, 3, 4]); + let target_key = Key::new(record_key.clone()); + let original_record = Record::new(record_key, vec![1, 3, 3, 7, 1, 3, 3, 8]); + + let distances = { + let mut distances = std::collections::BTreeMap::new(); + + for i in 1..64 { + let peer = make_peer_id(i, 0); + let key = Key::from(peer); + + distances.insert(target_key.distance(&key), peer); + } + + distances + }; + let mut iter = distances.iter(); + + // start find node with one known peer + let _query = engine.start_put_record( + QueryId(1340), + original_record.clone(), + vec![KademliaPeer::new( + *iter.next().unwrap().1, + vec![], + ConnectionType::NotConnected, + )] + .into(), + ); + + let action = engine.next_action(); + assert!(engine.next_action().is_none()); + + // the one known peer responds with 3 other peers it knows + match action { + Some(QueryAction::SendMessage { query, peer, .. }) => { + engine.register_response( + query, + peer, + KademliaMessage::FindNode { + target: Vec::new(), + peers: vec![ + KademliaPeer::new( + *iter.next().unwrap().1, + vec![], + ConnectionType::NotConnected, + ), + KademliaPeer::new( + *iter.next().unwrap().1, + vec![], + ConnectionType::NotConnected, + ), + KademliaPeer::new( + *iter.next().unwrap().1, + vec![], + ConnectionType::NotConnected, + ), + ], + }, + ); + } + _ => panic!("invalid event received"), + } + + // send empty response for the last three nodes + for _ in 0..3 { + match engine.next_action() { + Some(QueryAction::SendMessage { query, peer, .. }) => { + println!("next send message to {peer:?}"); + engine.register_response( + query, + peer, + KademliaMessage::FindNode { + target: Vec::new(), + peers: vec![], + }, + ); + } + _ => panic!("invalid event received"), + } + } + + match engine.next_action() { + Some(QueryAction::PutRecordToFoundNodes { peers, record }) => { + assert_eq!(peers.len(), 4); + assert_eq!(record.key, original_record.key); + assert_eq!(record.value, original_record.value); + } + _ => panic!("invalid event received"), + } + + assert!(engine.next_action().is_none()); + } } diff --git a/src/protocol/libp2p/kademlia/record.rs b/src/protocol/libp2p/kademlia/record.rs index 490b38d3..619f6ebf 100644 --- a/src/protocol/libp2p/kademlia/record.rs +++ b/src/protocol/libp2p/kademlia/record.rs @@ -32,74 +32,79 @@ use std::{borrow::Borrow, time::Instant}; pub struct Key(Bytes); impl Key { - /// Creates a new key from the bytes of the input. - pub fn new>(key: &K) -> Self { - Key(Bytes::copy_from_slice(key.as_ref())) - } - - /// Copies the bytes of the key into a new vector. - pub fn to_vec(&self) -> Vec { - Vec::from(&self.0[..]) - } + /// Creates a new key from the bytes of the input. + pub fn new>(key: &K) -> Self { + Key(Bytes::copy_from_slice(key.as_ref())) + } + + /// Copies the bytes of the key into a new vector. + pub fn to_vec(&self) -> Vec { + Vec::from(&self.0[..]) + } } impl Into> for Key { - fn into(self) -> Vec { - Vec::from(&self.0[..]) - } + fn into(self) -> Vec { + Vec::from(&self.0[..]) + } } impl Borrow<[u8]> for Key { - fn borrow(&self) -> &[u8] { - &self.0[..] - } + fn borrow(&self) -> &[u8] { + &self.0[..] + } } impl AsRef<[u8]> for Key { - fn as_ref(&self) -> &[u8] { - &self.0[..] - } + fn as_ref(&self) -> &[u8] { + &self.0[..] + } } impl From> for Key { - fn from(v: Vec) -> Key { - Key(Bytes::from(v)) - } + fn from(v: Vec) -> Key { + Key(Bytes::from(v)) + } } impl From for Key { - fn from(m: Multihash) -> Key { - Key::from(m.to_bytes()) - } + fn from(m: Multihash) -> Key { + Key::from(m.to_bytes()) + } } /// A record stored in the DHT. #[derive(Clone, Debug, Eq, PartialEq)] pub struct Record { - /// Key of the record. - pub key: Key, + /// Key of the record. + pub key: Key, - /// Value of the record. - pub value: Vec, + /// Value of the record. + pub value: Vec, - /// The (original) publisher of the record. - pub publisher: Option, + /// The (original) publisher of the record. + pub publisher: Option, - /// The expiration time as measured by a local, monotonic clock. - pub expires: Option, + /// The expiration time as measured by a local, monotonic clock. + pub expires: Option, } impl Record { - /// Creates a new record for insertion into the DHT. - pub fn new(key: K, value: Vec) -> Self - where - K: Into, - { - Record { key: key.into(), value, publisher: None, expires: None } - } - - /// Checks whether the record is expired w.r.t. the given `Instant`. - pub fn _is_expired(&self, now: Instant) -> bool { - self.expires.map_or(false, |t| now >= t) - } + /// Creates a new record for insertion into the DHT. + pub fn new(key: K, value: Vec) -> Self + where + K: Into, + { + Record { + key: key.into(), + value, + publisher: None, + expires: None, + } + } + + /// Checks whether the record is expired w.r.t. the given `Instant`. + pub fn _is_expired(&self, now: Instant) -> bool { + self.expires.map_or(false, |t| now >= t) + } } diff --git a/src/protocol/libp2p/kademlia/routing_table.rs b/src/protocol/libp2p/kademlia/routing_table.rs index 95d69f18..28dd251e 100644 --- a/src/protocol/libp2p/kademlia/routing_table.rs +++ b/src/protocol/libp2p/kademlia/routing_table.rs @@ -22,11 +22,11 @@ //! Kademlia routing table implementation. use crate::{ - protocol::libp2p::kademlia::{ - bucket::{KBucket, KBucketEntry}, - types::{ConnectionType, Distance, KademliaPeer, Key, U256}, - }, - PeerId, + protocol::libp2p::kademlia::{ + bucket::{KBucket, KBucketEntry}, + types::{ConnectionType, Distance, KademliaPeer, Key, U256}, + }, + PeerId, }; use multiaddr::{Multiaddr, Protocol}; @@ -39,11 +39,11 @@ const NUM_BUCKETS: usize = 256; const LOG_TARGET: &str = "litep2p::ipfs::kademlia::routing_table"; pub struct RoutingTable { - /// Local key. - local_key: Key, + /// Local key. + local_key: Key, - /// K-buckets. - buckets: Vec, + /// K-buckets. + buckets: Vec, } /// A (type-safe) index into a `KBucketsTable`, i.e. a non-negative integer in the @@ -52,135 +52,138 @@ pub struct RoutingTable { struct BucketIndex(usize); impl BucketIndex { - /// Creates a new `BucketIndex` for a `Distance`. - /// - /// The given distance is interpreted as the distance from a `local_key` of - /// a `KBucketsTable`. If the distance is zero, `None` is returned, in - /// recognition of the fact that the only key with distance `0` to a - /// `local_key` is the `local_key` itself, which does not belong in any - /// bucket. - fn new(d: &Distance) -> Option { - d.ilog2().map(|i| BucketIndex(i as usize)) - } - - /// Gets the index value as an unsigned integer. - fn get(&self) -> usize { - self.0 - } - - /// Returns the minimum inclusive and maximum inclusive [`Distance`] - /// included in the bucket for this index. - fn _range(&self) -> (Distance, Distance) { - let min = Distance(U256::pow(U256::from(2), U256::from(self.0))); - if self.0 == usize::from(u8::MAX) { - (min, Distance(U256::MAX)) - } else { - let max = Distance(U256::pow(U256::from(2), U256::from(self.0 + 1)) - 1); - (min, max) - } - } - - /// Generates a random distance that falls into the bucket for this index. - #[cfg(test)] - fn rand_distance(&self, rng: &mut impl rand::Rng) -> Distance { - let mut bytes = [0u8; 32]; - let quot = self.0 / 8; - for i in 0..quot { - bytes[31 - i] = rng.gen(); - } - let rem = (self.0 % 8) as u32; - let lower = usize::pow(2, rem); - let upper = usize::pow(2, rem + 1); - // bytes[31 - quot] = rng.gen_range(lower, upper) as u8; - bytes[31 - quot] = rng.gen_range(lower..upper) as u8; - Distance(U256::from(bytes)) - } + /// Creates a new `BucketIndex` for a `Distance`. + /// + /// The given distance is interpreted as the distance from a `local_key` of + /// a `KBucketsTable`. If the distance is zero, `None` is returned, in + /// recognition of the fact that the only key with distance `0` to a + /// `local_key` is the `local_key` itself, which does not belong in any + /// bucket. + fn new(d: &Distance) -> Option { + d.ilog2().map(|i| BucketIndex(i as usize)) + } + + /// Gets the index value as an unsigned integer. + fn get(&self) -> usize { + self.0 + } + + /// Returns the minimum inclusive and maximum inclusive [`Distance`] + /// included in the bucket for this index. + fn _range(&self) -> (Distance, Distance) { + let min = Distance(U256::pow(U256::from(2), U256::from(self.0))); + if self.0 == usize::from(u8::MAX) { + (min, Distance(U256::MAX)) + } else { + let max = Distance(U256::pow(U256::from(2), U256::from(self.0 + 1)) - 1); + (min, max) + } + } + + /// Generates a random distance that falls into the bucket for this index. + #[cfg(test)] + fn rand_distance(&self, rng: &mut impl rand::Rng) -> Distance { + let mut bytes = [0u8; 32]; + let quot = self.0 / 8; + for i in 0..quot { + bytes[31 - i] = rng.gen(); + } + let rem = (self.0 % 8) as u32; + let lower = usize::pow(2, rem); + let upper = usize::pow(2, rem + 1); + // bytes[31 - quot] = rng.gen_range(lower, upper) as u8; + bytes[31 - quot] = rng.gen_range(lower..upper) as u8; + Distance(U256::from(bytes)) + } } impl RoutingTable { - /// Create new [`RoutingTable`]. - pub fn new(local_key: Key) -> Self { - RoutingTable { local_key, buckets: (0..NUM_BUCKETS).map(|_| KBucket::new()).collect() } - } - - /// Returns the local key. - pub fn _local_key(&self) -> &Key { - &self.local_key - } - - /// Get an entry for `peer` into a k-bucket. - pub fn entry<'a>(&'a mut self, key: Key) -> KBucketEntry<'a> { - let Some(index) = BucketIndex::new(&self.local_key.distance(&key)) else { - return KBucketEntry::LocalNode; - }; - - self.buckets[index.get()].entry(key) - } - - /// Add known peer to [`RoutingTable`]. - /// - /// In order to bootstrap the lookup process, the routing table must be aware of at least one - /// node and of its addresses. The insert operation is ignored - pub fn add_known_peer( - &mut self, - peer: PeerId, - addresses: Vec, - connection: ConnectionType, - ) { - tracing::trace!( - target: LOG_TARGET, - ?peer, - ?addresses, - ?connection, - "add known peer" - ); - - // TODO: this has to be moved elsewhere at some point - let addresses: Vec = addresses - .into_iter() - .filter_map(|address| { - let last = address.iter().last(); - if std::matches!(last, Some(Protocol::P2p(_))) { - Some(address) - } else { - Some(address.with(Protocol::P2p(Multihash::from_bytes(&peer.to_bytes()).ok()?))) - } - }) - .collect(); - - match (self.entry(Key::from(peer)), addresses.is_empty()) { - (KBucketEntry::Occupied(entry), false) => { - entry.addresses = addresses; - }, - (mut entry @ KBucketEntry::Vacant(_), false) => { - entry.insert(KademliaPeer::new(peer, addresses, connection)); - }, - (KBucketEntry::LocalNode, _) => tracing::warn!( - target: LOG_TARGET, - ?peer, - "tried to add local node to routing table", - ), - (KBucketEntry::NoSlot, _) => tracing::trace!( - target: LOG_TARGET, - ?peer, - "routing table full, cannot add new entry", - ), - (_, true) => tracing::debug!( - target: LOG_TARGET, - ?peer, - "tried to add zero addresses to the routing table", - ), - } - } - - /// Get `limit` closests peers to `target` from the k-buckets. - pub fn closest(&mut self, target: Key, limit: usize) -> Vec { - ClosestBucketsIter::new(self.local_key.distance(&target)) - .map(|index| self.buckets[index.get()].closest_iter(&target)) - .flatten() - .take(limit) - .collect() - } + /// Create new [`RoutingTable`]. + pub fn new(local_key: Key) -> Self { + RoutingTable { + local_key, + buckets: (0..NUM_BUCKETS).map(|_| KBucket::new()).collect(), + } + } + + /// Returns the local key. + pub fn _local_key(&self) -> &Key { + &self.local_key + } + + /// Get an entry for `peer` into a k-bucket. + pub fn entry<'a>(&'a mut self, key: Key) -> KBucketEntry<'a> { + let Some(index) = BucketIndex::new(&self.local_key.distance(&key)) else { + return KBucketEntry::LocalNode; + }; + + self.buckets[index.get()].entry(key) + } + + /// Add known peer to [`RoutingTable`]. + /// + /// In order to bootstrap the lookup process, the routing table must be aware of at least one + /// node and of its addresses. The insert operation is ignored + pub fn add_known_peer( + &mut self, + peer: PeerId, + addresses: Vec, + connection: ConnectionType, + ) { + tracing::trace!( + target: LOG_TARGET, + ?peer, + ?addresses, + ?connection, + "add known peer" + ); + + // TODO: this has to be moved elsewhere at some point + let addresses: Vec = addresses + .into_iter() + .filter_map(|address| { + let last = address.iter().last(); + if std::matches!(last, Some(Protocol::P2p(_))) { + Some(address) + } else { + Some(address.with(Protocol::P2p(Multihash::from_bytes(&peer.to_bytes()).ok()?))) + } + }) + .collect(); + + match (self.entry(Key::from(peer)), addresses.is_empty()) { + (KBucketEntry::Occupied(entry), false) => { + entry.addresses = addresses; + } + (mut entry @ KBucketEntry::Vacant(_), false) => { + entry.insert(KademliaPeer::new(peer, addresses, connection)); + } + (KBucketEntry::LocalNode, _) => tracing::warn!( + target: LOG_TARGET, + ?peer, + "tried to add local node to routing table", + ), + (KBucketEntry::NoSlot, _) => tracing::trace!( + target: LOG_TARGET, + ?peer, + "routing table full, cannot add new entry", + ), + (_, true) => tracing::debug!( + target: LOG_TARGET, + ?peer, + "tried to add zero addresses to the routing table", + ), + } + } + + /// Get `limit` closests peers to `target` from the k-buckets. + pub fn closest(&mut self, target: Key, limit: usize) -> Vec { + ClosestBucketsIter::new(self.local_key.distance(&target)) + .map(|index| self.buckets[index.get()].closest_iter(&target)) + .flatten() + .take(limit) + .collect() + } } /// An iterator over the bucket indices, in the order determined by the `Distance` of a target from @@ -191,328 +194,336 @@ impl RoutingTable { /// /// [1]: https://github.com/libp2p/rust-libp2p/pull/1117#issuecomment-494694635 struct ClosestBucketsIter { - /// The distance to the `local_key`. - distance: Distance, - /// The current state of the iterator. - state: ClosestBucketsIterState, + /// The distance to the `local_key`. + distance: Distance, + /// The current state of the iterator. + state: ClosestBucketsIterState, } /// Operating states of a `ClosestBucketsIter`. enum ClosestBucketsIterState { - /// The starting state of the iterator yields the first bucket index and - /// then transitions to `ZoomIn`. - Start(BucketIndex), - /// The iterator "zooms in" to to yield the next bucket cotaining nodes that - /// are incrementally closer to the local node but further from the `target`. - /// These buckets are identified by a `1` in the corresponding bit position - /// of the distance bit string. When bucket `0` is reached, the iterator - /// transitions to `ZoomOut`. - ZoomIn(BucketIndex), - /// Once bucket `0` has been reached, the iterator starts "zooming out" - /// to buckets containing nodes that are incrementally further away from - /// both the local key and the target. These are identified by a `0` in - /// the corresponding bit position of the distance bit string. When bucket - /// `255` is reached, the iterator transitions to state `Done`. - ZoomOut(BucketIndex), - /// The iterator is in this state once it has visited all buckets. - Done, + /// The starting state of the iterator yields the first bucket index and + /// then transitions to `ZoomIn`. + Start(BucketIndex), + /// The iterator "zooms in" to to yield the next bucket cotaining nodes that + /// are incrementally closer to the local node but further from the `target`. + /// These buckets are identified by a `1` in the corresponding bit position + /// of the distance bit string. When bucket `0` is reached, the iterator + /// transitions to `ZoomOut`. + ZoomIn(BucketIndex), + /// Once bucket `0` has been reached, the iterator starts "zooming out" + /// to buckets containing nodes that are incrementally further away from + /// both the local key and the target. These are identified by a `0` in + /// the corresponding bit position of the distance bit string. When bucket + /// `255` is reached, the iterator transitions to state `Done`. + ZoomOut(BucketIndex), + /// The iterator is in this state once it has visited all buckets. + Done, } impl ClosestBucketsIter { - fn new(distance: Distance) -> Self { - let state = match BucketIndex::new(&distance) { - Some(i) => ClosestBucketsIterState::Start(i), - None => ClosestBucketsIterState::Start(BucketIndex(0)), - }; - Self { distance, state } - } - - fn next_in(&self, i: BucketIndex) -> Option { - (0..i.get()) - .rev() - .find_map(|i| self.distance.0.bit(i).then_some(BucketIndex(i))) - } - - fn next_out(&self, i: BucketIndex) -> Option { - (i.get() + 1..NUM_BUCKETS).find_map(|i| (!self.distance.0.bit(i)).then_some(BucketIndex(i))) - } + fn new(distance: Distance) -> Self { + let state = match BucketIndex::new(&distance) { + Some(i) => ClosestBucketsIterState::Start(i), + None => ClosestBucketsIterState::Start(BucketIndex(0)), + }; + Self { distance, state } + } + + fn next_in(&self, i: BucketIndex) -> Option { + (0..i.get()) + .rev() + .find_map(|i| self.distance.0.bit(i).then_some(BucketIndex(i))) + } + + fn next_out(&self, i: BucketIndex) -> Option { + (i.get() + 1..NUM_BUCKETS).find_map(|i| (!self.distance.0.bit(i)).then_some(BucketIndex(i))) + } } impl Iterator for ClosestBucketsIter { - type Item = BucketIndex; - - fn next(&mut self) -> Option { - match self.state { - ClosestBucketsIterState::Start(i) => { - self.state = ClosestBucketsIterState::ZoomIn(i); - Some(i) - }, - ClosestBucketsIterState::ZoomIn(i) => - if let Some(i) = self.next_in(i) { - self.state = ClosestBucketsIterState::ZoomIn(i); - Some(i) - } else { - let i = BucketIndex(0); - self.state = ClosestBucketsIterState::ZoomOut(i); - Some(i) - }, - ClosestBucketsIterState::ZoomOut(i) => - if let Some(i) = self.next_out(i) { - self.state = ClosestBucketsIterState::ZoomOut(i); - Some(i) - } else { - self.state = ClosestBucketsIterState::Done; - None - }, - ClosestBucketsIterState::Done => None, - } - } + type Item = BucketIndex; + + fn next(&mut self) -> Option { + match self.state { + ClosestBucketsIterState::Start(i) => { + self.state = ClosestBucketsIterState::ZoomIn(i); + Some(i) + } + ClosestBucketsIterState::ZoomIn(i) => + if let Some(i) = self.next_in(i) { + self.state = ClosestBucketsIterState::ZoomIn(i); + Some(i) + } else { + let i = BucketIndex(0); + self.state = ClosestBucketsIterState::ZoomOut(i); + Some(i) + }, + ClosestBucketsIterState::ZoomOut(i) => + if let Some(i) = self.next_out(i) { + self.state = ClosestBucketsIterState::ZoomOut(i); + Some(i) + } else { + self.state = ClosestBucketsIterState::Done; + None + }, + ClosestBucketsIterState::Done => None, + } + } } #[cfg(test)] mod tests { - use super::*; - use crate::protocol::libp2p::kademlia::types::ConnectionType; - - #[test] - fn closest_peers() { - let own_peer_id = PeerId::random(); - let own_key = Key::from(own_peer_id); - let mut table = RoutingTable::new(own_key.clone()); - - for _ in 0..60 { - let peer = PeerId::random(); - let key = Key::from(peer); - let mut entry = table.entry(key.clone()); - entry.insert(KademliaPeer::new(peer, vec![], ConnectionType::Connected)); - } - - let target = Key::from(PeerId::random()); - let closest = table.closest(target.clone(), 60usize); - let mut prev = None; - - for peer in &closest { - if let Some(value) = prev { - assert!(value < target.distance(&peer.key)); - } - - prev = Some(target.distance(&peer.key)); - } - } - - // generate random peer that falls in to specified k-bucket. - // - // NOTE: the preimage of the generated `Key` doesn't match the `Key` itself - fn random_peer( - rng: &mut impl rand::Rng, - own_key: Key, - bucket_index: usize, - ) -> (Key, PeerId) { - let peer = PeerId::random(); - let distance = BucketIndex(bucket_index).rand_distance(rng); - let key_bytes = own_key.for_distance(distance); - - (Key::from_bytes(key_bytes, peer), peer) - } - - #[test] - fn add_peer_to_empty_table() { - let own_peer_id = PeerId::random(); - let own_key = Key::from(own_peer_id); - let mut table = RoutingTable::new(own_key.clone()); - - // verify that local peer id resolves to special entry - assert_eq!(table.entry(own_key), KBucketEntry::LocalNode); - - let peer = PeerId::random(); - let key = Key::from(peer); - let mut test = table.entry(key.clone()); - let addresses = vec![]; - - assert!(std::matches!(test, KBucketEntry::Vacant(_))); - test.insert(KademliaPeer::new(peer, addresses.clone(), ConnectionType::Connected)); - - assert_eq!( - table.entry(key.clone()), - KBucketEntry::Occupied(&mut KademliaPeer::new( - peer, - addresses.clone(), - ConnectionType::Connected, - )) - ); - - match table.entry(key.clone()) { - KBucketEntry::Occupied(entry) => { - entry.connection = ConnectionType::NotConnected; - }, - state => panic!("invalid state for `KBucketEntry`: {state:?}"), - } - - assert_eq!( - table.entry(key.clone()), - KBucketEntry::Occupied(&mut KademliaPeer::new( - peer, - addresses, - ConnectionType::NotConnected, - )) - ); - } - - #[test] - fn full_k_bucket() { - let mut rng = rand::thread_rng(); - let own_peer_id = PeerId::random(); - let own_key = Key::from(own_peer_id); - let mut table = RoutingTable::new(own_key.clone()); - - // add 20 nodes to the same k-bucket - for _ in 0..20 { - let (key, peer) = random_peer(&mut rng, own_key.clone(), 254); - let mut entry = table.entry(key.clone()); - - assert!(std::matches!(entry, KBucketEntry::Vacant(_))); - entry.insert(KademliaPeer::new(peer, vec![], ConnectionType::Connected)); - } - - // try to add another peer and verify the peer is rejected - // because the k-bucket is full of connected nodes - let peer = PeerId::random(); - let distance = BucketIndex(254).rand_distance(&mut rng); - let key_bytes = own_key.for_distance(distance); - let key = Key::from_bytes(key_bytes, peer); - - let entry = table.entry(key.clone()); - assert!(std::matches!(entry, KBucketEntry::NoSlot)); - } - - #[test] - #[ignore] - fn peer_disconnects_and_is_evicted() { - let mut rng = rand::thread_rng(); - let own_peer_id = PeerId::random(); - let own_key = Key::from(own_peer_id); - let mut table = RoutingTable::new(own_key.clone()); - - // add 20 nodes to the same k-bucket - let peers = (0..20) - .map(|_| { - let (key, peer) = random_peer(&mut rng, own_key.clone(), 253); - let mut entry = table.entry(key.clone()); - - assert!(std::matches!(entry, KBucketEntry::Vacant(_))); - entry.insert(KademliaPeer::new(peer, vec![], ConnectionType::Connected)); - - (peer, key) - }) - .collect::>(); - - // try to add another peer and verify the peer is rejected - // because the k-bucket is full of connected nodes - let peer = PeerId::random(); - let distance = BucketIndex(253).rand_distance(&mut rng); - let key_bytes = own_key.for_distance(distance); - let key = Key::from_bytes(key_bytes, peer); - - let entry = table.entry(key.clone()); - assert!(std::matches!(entry, KBucketEntry::NoSlot)); - - // disconnect random peer - match table.entry(peers[3].1.clone()) { - KBucketEntry::Occupied(entry) => { - entry.connection = ConnectionType::NotConnected; - }, - _ => panic!("invalid state for node"), - } - - // try to add the previously rejected peer again and verify it's added - let mut entry = table.entry(key.clone()); - assert!(std::matches!(entry, KBucketEntry::Vacant(_))); - entry.insert(KademliaPeer::new( - peer, - vec!["/ip6/::1/tcp/8888".parse().unwrap()], - ConnectionType::CanConnect, - )); - - // verify the node is still there - let entry = table.entry(key.clone()); - let addresses = vec!["/ip6/::1/tcp/8888".parse().unwrap()]; - assert_eq!( - entry, - KBucketEntry::Occupied(&mut KademliaPeer::new( - peer, - addresses, - ConnectionType::CanConnect, - )) - ); - } - - #[test] - fn disconnected_peers_are_not_evicted_if_there_is_capacity() { - let mut rng = rand::thread_rng(); - let own_peer_id = PeerId::random(); - let own_key = Key::from(own_peer_id); - let mut table = RoutingTable::new(own_key.clone()); - - // add 19 disconnected nodes to the same k-bucket - let _peers = (0..19) - .map(|_| { - let (key, peer) = random_peer(&mut rng, own_key.clone(), 252); - let mut entry = table.entry(key.clone()); - - assert!(std::matches!(entry, KBucketEntry::Vacant(_))); - entry.insert(KademliaPeer::new(peer, vec![], ConnectionType::NotConnected)); - - (peer, key) - }) - .collect::>(); - - // try to add another peer and verify it's accepted as there is - // still room in the k-bucket for the node - let peer = PeerId::random(); - let distance = BucketIndex(252).rand_distance(&mut rng); - let key_bytes = own_key.for_distance(distance); - let key = Key::from_bytes(key_bytes, peer); - - let mut entry = table.entry(key.clone()); - assert!(std::matches!(entry, KBucketEntry::Vacant(_))); - entry.insert(KademliaPeer::new( - peer, - vec!["/ip6/::1/tcp/8888".parse().unwrap()], - ConnectionType::CanConnect, - )); - } - - #[test] - fn closest_buckets_iterator_set_lsb() { - // Test zooming-in & zooming-out of the iterator using a toy example with set LSB. - let d = Distance(U256::from(0b10011011)); - let mut iter = ClosestBucketsIter::new(d); - // Note that bucket 0 is visited twice. This is, technically, a bug, but to not complicate - // the implementation and keep it consistent with `libp2p` it's kept as is. There are - // virtually no practical consequences of this, because to have bucket 0 populated we have - // to encounter two sha256 hash values differing only in one least significant bit. - let expected_buckets = vec![7, 4, 3, 1, 0, 0, 2, 5, 6] - .into_iter() - .chain(8..=255) - .map(|i| BucketIndex(i)); - for expected in expected_buckets { - let got = iter.next().unwrap(); - assert_eq!(got, expected); - } - assert!(iter.next().is_none()); - } - - #[test] - fn closest_buckets_iterator_unset_lsb() { - // Test zooming-in & zooming-out of the iterator using a toy example with unset LSB. - let d = Distance(U256::from(0b01011010)); - let mut iter = ClosestBucketsIter::new(d); - let expected_buckets = - vec![6, 4, 3, 1, 0, 2, 5, 7].into_iter().chain(8..=255).map(|i| BucketIndex(i)); - for expected in expected_buckets { - let got = iter.next().unwrap(); - assert_eq!(got, expected); - } - assert!(iter.next().is_none()); - } + use super::*; + use crate::protocol::libp2p::kademlia::types::ConnectionType; + + #[test] + fn closest_peers() { + let own_peer_id = PeerId::random(); + let own_key = Key::from(own_peer_id); + let mut table = RoutingTable::new(own_key.clone()); + + for _ in 0..60 { + let peer = PeerId::random(); + let key = Key::from(peer); + let mut entry = table.entry(key.clone()); + entry.insert(KademliaPeer::new(peer, vec![], ConnectionType::Connected)); + } + + let target = Key::from(PeerId::random()); + let closest = table.closest(target.clone(), 60usize); + let mut prev = None; + + for peer in &closest { + if let Some(value) = prev { + assert!(value < target.distance(&peer.key)); + } + + prev = Some(target.distance(&peer.key)); + } + } + + // generate random peer that falls in to specified k-bucket. + // + // NOTE: the preimage of the generated `Key` doesn't match the `Key` itself + fn random_peer( + rng: &mut impl rand::Rng, + own_key: Key, + bucket_index: usize, + ) -> (Key, PeerId) { + let peer = PeerId::random(); + let distance = BucketIndex(bucket_index).rand_distance(rng); + let key_bytes = own_key.for_distance(distance); + + (Key::from_bytes(key_bytes, peer), peer) + } + + #[test] + fn add_peer_to_empty_table() { + let own_peer_id = PeerId::random(); + let own_key = Key::from(own_peer_id); + let mut table = RoutingTable::new(own_key.clone()); + + // verify that local peer id resolves to special entry + assert_eq!(table.entry(own_key), KBucketEntry::LocalNode); + + let peer = PeerId::random(); + let key = Key::from(peer); + let mut test = table.entry(key.clone()); + let addresses = vec![]; + + assert!(std::matches!(test, KBucketEntry::Vacant(_))); + test.insert(KademliaPeer::new( + peer, + addresses.clone(), + ConnectionType::Connected, + )); + + assert_eq!( + table.entry(key.clone()), + KBucketEntry::Occupied(&mut KademliaPeer::new( + peer, + addresses.clone(), + ConnectionType::Connected, + )) + ); + + match table.entry(key.clone()) { + KBucketEntry::Occupied(entry) => { + entry.connection = ConnectionType::NotConnected; + } + state => panic!("invalid state for `KBucketEntry`: {state:?}"), + } + + assert_eq!( + table.entry(key.clone()), + KBucketEntry::Occupied(&mut KademliaPeer::new( + peer, + addresses, + ConnectionType::NotConnected, + )) + ); + } + + #[test] + fn full_k_bucket() { + let mut rng = rand::thread_rng(); + let own_peer_id = PeerId::random(); + let own_key = Key::from(own_peer_id); + let mut table = RoutingTable::new(own_key.clone()); + + // add 20 nodes to the same k-bucket + for _ in 0..20 { + let (key, peer) = random_peer(&mut rng, own_key.clone(), 254); + let mut entry = table.entry(key.clone()); + + assert!(std::matches!(entry, KBucketEntry::Vacant(_))); + entry.insert(KademliaPeer::new(peer, vec![], ConnectionType::Connected)); + } + + // try to add another peer and verify the peer is rejected + // because the k-bucket is full of connected nodes + let peer = PeerId::random(); + let distance = BucketIndex(254).rand_distance(&mut rng); + let key_bytes = own_key.for_distance(distance); + let key = Key::from_bytes(key_bytes, peer); + + let entry = table.entry(key.clone()); + assert!(std::matches!(entry, KBucketEntry::NoSlot)); + } + + #[test] + #[ignore] + fn peer_disconnects_and_is_evicted() { + let mut rng = rand::thread_rng(); + let own_peer_id = PeerId::random(); + let own_key = Key::from(own_peer_id); + let mut table = RoutingTable::new(own_key.clone()); + + // add 20 nodes to the same k-bucket + let peers = (0..20) + .map(|_| { + let (key, peer) = random_peer(&mut rng, own_key.clone(), 253); + let mut entry = table.entry(key.clone()); + + assert!(std::matches!(entry, KBucketEntry::Vacant(_))); + entry.insert(KademliaPeer::new(peer, vec![], ConnectionType::Connected)); + + (peer, key) + }) + .collect::>(); + + // try to add another peer and verify the peer is rejected + // because the k-bucket is full of connected nodes + let peer = PeerId::random(); + let distance = BucketIndex(253).rand_distance(&mut rng); + let key_bytes = own_key.for_distance(distance); + let key = Key::from_bytes(key_bytes, peer); + + let entry = table.entry(key.clone()); + assert!(std::matches!(entry, KBucketEntry::NoSlot)); + + // disconnect random peer + match table.entry(peers[3].1.clone()) { + KBucketEntry::Occupied(entry) => { + entry.connection = ConnectionType::NotConnected; + } + _ => panic!("invalid state for node"), + } + + // try to add the previously rejected peer again and verify it's added + let mut entry = table.entry(key.clone()); + assert!(std::matches!(entry, KBucketEntry::Vacant(_))); + entry.insert(KademliaPeer::new( + peer, + vec!["/ip6/::1/tcp/8888".parse().unwrap()], + ConnectionType::CanConnect, + )); + + // verify the node is still there + let entry = table.entry(key.clone()); + let addresses = vec!["/ip6/::1/tcp/8888".parse().unwrap()]; + assert_eq!( + entry, + KBucketEntry::Occupied(&mut KademliaPeer::new( + peer, + addresses, + ConnectionType::CanConnect, + )) + ); + } + + #[test] + fn disconnected_peers_are_not_evicted_if_there_is_capacity() { + let mut rng = rand::thread_rng(); + let own_peer_id = PeerId::random(); + let own_key = Key::from(own_peer_id); + let mut table = RoutingTable::new(own_key.clone()); + + // add 19 disconnected nodes to the same k-bucket + let _peers = (0..19) + .map(|_| { + let (key, peer) = random_peer(&mut rng, own_key.clone(), 252); + let mut entry = table.entry(key.clone()); + + assert!(std::matches!(entry, KBucketEntry::Vacant(_))); + entry.insert(KademliaPeer::new( + peer, + vec![], + ConnectionType::NotConnected, + )); + + (peer, key) + }) + .collect::>(); + + // try to add another peer and verify it's accepted as there is + // still room in the k-bucket for the node + let peer = PeerId::random(); + let distance = BucketIndex(252).rand_distance(&mut rng); + let key_bytes = own_key.for_distance(distance); + let key = Key::from_bytes(key_bytes, peer); + + let mut entry = table.entry(key.clone()); + assert!(std::matches!(entry, KBucketEntry::Vacant(_))); + entry.insert(KademliaPeer::new( + peer, + vec!["/ip6/::1/tcp/8888".parse().unwrap()], + ConnectionType::CanConnect, + )); + } + + #[test] + fn closest_buckets_iterator_set_lsb() { + // Test zooming-in & zooming-out of the iterator using a toy example with set LSB. + let d = Distance(U256::from(0b10011011)); + let mut iter = ClosestBucketsIter::new(d); + // Note that bucket 0 is visited twice. This is, technically, a bug, but to not complicate + // the implementation and keep it consistent with `libp2p` it's kept as is. There are + // virtually no practical consequences of this, because to have bucket 0 populated we have + // to encounter two sha256 hash values differing only in one least significant bit. + let expected_buckets = vec![7, 4, 3, 1, 0, 0, 2, 5, 6] + .into_iter() + .chain(8..=255) + .map(|i| BucketIndex(i)); + for expected in expected_buckets { + let got = iter.next().unwrap(); + assert_eq!(got, expected); + } + assert!(iter.next().is_none()); + } + + #[test] + fn closest_buckets_iterator_unset_lsb() { + // Test zooming-in & zooming-out of the iterator using a toy example with unset LSB. + let d = Distance(U256::from(0b01011010)); + let mut iter = ClosestBucketsIter::new(d); + let expected_buckets = + vec![6, 4, 3, 1, 0, 2, 5, 7].into_iter().chain(8..=255).map(|i| BucketIndex(i)); + for expected in expected_buckets { + let got = iter.next().unwrap(); + assert_eq!(got, expected); + } + assert!(iter.next().is_none()); + } } diff --git a/src/protocol/libp2p/kademlia/store.rs b/src/protocol/libp2p/kademlia/store.rs index 9cfa47c8..6fd158e1 100644 --- a/src/protocol/libp2p/kademlia/store.rs +++ b/src/protocol/libp2p/kademlia/store.rs @@ -30,28 +30,30 @@ pub enum MemoryStoreEvent {} /// Memory store. pub struct MemoryStore { - /// Records. - records: HashMap, + /// Records. + records: HashMap, } impl MemoryStore { - /// Create new [`MemoryStore`]. - pub fn new() -> Self { - Self { records: HashMap::new() } - } - - /// Try to get record from local store for `key`. - pub fn get(&self, key: &Key) -> Option<&Record> { - self.records.get(key) - } - - /// Store record. - pub fn put(&mut self, record: Record) { - self.records.insert(record.key.clone(), record); - } - - /// Poll next event from the store. - async fn next_event() -> Option { - None - } + /// Create new [`MemoryStore`]. + pub fn new() -> Self { + Self { + records: HashMap::new(), + } + } + + /// Try to get record from local store for `key`. + pub fn get(&self, key: &Key) -> Option<&Record> { + self.records.get(key) + } + + /// Store record. + pub fn put(&mut self, record: Record) { + self.records.insert(record.key.clone(), record); + } + + /// Poll next event from the store. + async fn next_event() -> Option { + None + } } diff --git a/src/protocol/libp2p/kademlia/types.rs b/src/protocol/libp2p/kademlia/types.rs index 58cc3284..fe9d04eb 100644 --- a/src/protocol/libp2p/kademlia/types.rs +++ b/src/protocol/libp2p/kademlia/types.rs @@ -25,19 +25,19 @@ use crate::{protocol::libp2p::kademlia::schema, PeerId}; use multiaddr::Multiaddr; use sha2::{ - digest::generic_array::{typenum::U32, GenericArray}, - Digest, Sha256, + digest::generic_array::{typenum::U32, GenericArray}, + Digest, Sha256, }; use uint::*; use std::{ - borrow::Borrow, - hash::{Hash, Hasher}, + borrow::Borrow, + hash::{Hash, Hasher}, }; construct_uint! { - /// 256-bit unsigned integer. - pub(super) struct U256(4); + /// 256-bit unsigned integer. + pub(super) struct U256(4); } /// A `Key` in the DHT keyspace with preserved preimage. @@ -49,93 +49,96 @@ construct_uint! { /// the hash digests, interpreted as an integer. See [`Key::distance`]. #[derive(Clone, Debug)] pub struct Key { - _preimage: T, - bytes: KeyBytes, + _preimage: T, + bytes: KeyBytes, } impl Key { - /// Constructs a new `Key` by running the given value through a random - /// oracle. - /// - /// The preimage of type `T` is preserved. - /// See [`Key::into_preimage`] for more details. - pub fn new(_preimage: T) -> Key - where - T: Borrow<[u8]>, - { - let bytes = KeyBytes::new(_preimage.borrow()); - Key { _preimage, bytes } - } - - /// Convert [`Key`] into its preimage. - pub fn into_preimage(self) -> T { - self._preimage - } - - /// Computes the distance of the keys according to the XOR metric. - pub fn distance(&self, other: &U) -> Distance - where - U: AsRef, - { - self.bytes.distance(other) - } - - /// Returns the uniquely determined key with the given distance to `self`. - /// - /// This implements the following equivalence: - /// - /// `self xor other = distance <==> other = self xor distance` - #[cfg(test)] - pub fn for_distance(&self, d: Distance) -> KeyBytes { - self.bytes.for_distance(d) - } - - /// Generate key from `KeyBytes` with a random preimage. - /// - /// Only used for testing - #[cfg(test)] - pub fn from_bytes(bytes: KeyBytes, _preimage: T) -> Key { - Self { bytes, _preimage } - } + /// Constructs a new `Key` by running the given value through a random + /// oracle. + /// + /// The preimage of type `T` is preserved. + /// See [`Key::into_preimage`] for more details. + pub fn new(_preimage: T) -> Key + where + T: Borrow<[u8]>, + { + let bytes = KeyBytes::new(_preimage.borrow()); + Key { _preimage, bytes } + } + + /// Convert [`Key`] into its preimage. + pub fn into_preimage(self) -> T { + self._preimage + } + + /// Computes the distance of the keys according to the XOR metric. + pub fn distance(&self, other: &U) -> Distance + where + U: AsRef, + { + self.bytes.distance(other) + } + + /// Returns the uniquely determined key with the given distance to `self`. + /// + /// This implements the following equivalence: + /// + /// `self xor other = distance <==> other = self xor distance` + #[cfg(test)] + pub fn for_distance(&self, d: Distance) -> KeyBytes { + self.bytes.for_distance(d) + } + + /// Generate key from `KeyBytes` with a random preimage. + /// + /// Only used for testing + #[cfg(test)] + pub fn from_bytes(bytes: KeyBytes, _preimage: T) -> Key { + Self { bytes, _preimage } + } } impl From> for KeyBytes { - fn from(key: Key) -> KeyBytes { - key.bytes - } + fn from(key: Key) -> KeyBytes { + key.bytes + } } impl From for Key { - fn from(p: PeerId) -> Self { - let bytes = KeyBytes(Sha256::digest(p.to_bytes())); - Key { _preimage: p, bytes } - } + fn from(p: PeerId) -> Self { + let bytes = KeyBytes(Sha256::digest(p.to_bytes())); + Key { + _preimage: p, + bytes, + } + } } impl From> for Key> { - fn from(b: Vec) -> Self { - Key::new(b) - } + fn from(b: Vec) -> Self { + Key::new(b) + } } impl AsRef for Key { - fn as_ref(&self) -> &KeyBytes { - &self.bytes - } + fn as_ref(&self) -> &KeyBytes { + &self.bytes + } } impl PartialEq> for Key { - fn eq(&self, other: &Key) -> bool { - self.bytes == other.bytes - } + fn eq(&self, other: &Key) -> bool { + self.bytes == other.bytes + } } impl Eq for Key {} impl Hash for Key { - fn hash(&self, state: &mut H) { - self.bytes.0.hash(state); - } + fn hash(&self, state: &mut H) { + self.bytes.0.hash(state); + } } /// The raw bytes of a key in the DHT keyspace. @@ -143,41 +146,41 @@ impl Hash for Key { pub struct KeyBytes(GenericArray); impl KeyBytes { - /// Creates a new key in the DHT keyspace by running the given - /// value through a random oracle. - pub fn new(value: T) -> Self - where - T: Borrow<[u8]>, - { - KeyBytes(Sha256::digest(value.borrow())) - } - - /// Computes the distance of the keys according to the XOR metric. - pub fn distance(&self, other: &U) -> Distance - where - U: AsRef, - { - let a = U256::from(self.0.as_slice()); - let b = U256::from(other.as_ref().0.as_slice()); - Distance(a ^ b) - } - - /// Returns the uniquely determined key with the given distance to `self`. - /// - /// This implements the following equivalence: - /// - /// `self xor other = distance <==> other = self xor distance` - #[cfg(test)] - pub fn for_distance(&self, d: Distance) -> KeyBytes { - let key_int = U256::from(self.0.as_slice()) ^ d.0; - KeyBytes(GenericArray::from(<[u8; 32]>::from(key_int))) - } + /// Creates a new key in the DHT keyspace by running the given + /// value through a random oracle. + pub fn new(value: T) -> Self + where + T: Borrow<[u8]>, + { + KeyBytes(Sha256::digest(value.borrow())) + } + + /// Computes the distance of the keys according to the XOR metric. + pub fn distance(&self, other: &U) -> Distance + where + U: AsRef, + { + let a = U256::from(self.0.as_slice()); + let b = U256::from(other.as_ref().0.as_slice()); + Distance(a ^ b) + } + + /// Returns the uniquely determined key with the given distance to `self`. + /// + /// This implements the following equivalence: + /// + /// `self xor other = distance <==> other = self xor distance` + #[cfg(test)] + pub fn for_distance(&self, d: Distance) -> KeyBytes { + let key_int = U256::from(self.0.as_slice()) ^ d.0; + KeyBytes(GenericArray::from(<[u8; 32]>::from(key_int))) + } } impl AsRef for KeyBytes { - fn as_ref(&self) -> &KeyBytes { - self - } + fn as_ref(&self) -> &KeyBytes { + self + } } /// A distance between two keys in the DHT keyspace. @@ -185,103 +188,108 @@ impl AsRef for KeyBytes { pub struct Distance(pub(super) U256); impl Distance { - /// Returns the integer part of the base 2 logarithm of the [`Distance`]. - /// - /// Returns `None` if the distance is zero. - pub fn ilog2(&self) -> Option { - (256 - self.0.leading_zeros()).checked_sub(1) - } + /// Returns the integer part of the base 2 logarithm of the [`Distance`]. + /// + /// Returns `None` if the distance is zero. + pub fn ilog2(&self) -> Option { + (256 - self.0.leading_zeros()).checked_sub(1) + } } /// Connection type to peer. #[derive(Debug, Copy, Clone, PartialEq, Eq)] pub enum ConnectionType { - /// Sender does not have a connection to peer. - NotConnected, + /// Sender does not have a connection to peer. + NotConnected, - /// Sender is connected to the peer. - Connected, + /// Sender is connected to the peer. + Connected, - /// Sender has recently been connected to the peer. - CanConnect, + /// Sender has recently been connected to the peer. + CanConnect, - /// Sender is unable to connect to the peer. - CannotConnect, + /// Sender is unable to connect to the peer. + CannotConnect, } impl TryFrom for ConnectionType { - type Error = (); - - fn try_from(value: i32) -> Result { - match value { - 0 => Ok(ConnectionType::NotConnected), - 1 => Ok(ConnectionType::Connected), - 2 => Ok(ConnectionType::CanConnect), - 3 => Ok(ConnectionType::CannotConnect), - _ => Err(()), - } - } + type Error = (); + + fn try_from(value: i32) -> Result { + match value { + 0 => Ok(ConnectionType::NotConnected), + 1 => Ok(ConnectionType::Connected), + 2 => Ok(ConnectionType::CanConnect), + 3 => Ok(ConnectionType::CannotConnect), + _ => Err(()), + } + } } impl From for i32 { - fn from(connection: ConnectionType) -> Self { - match connection { - ConnectionType::NotConnected => 0, - ConnectionType::Connected => 1, - ConnectionType::CanConnect => 2, - ConnectionType::CannotConnect => 3, - } - } + fn from(connection: ConnectionType) -> Self { + match connection { + ConnectionType::NotConnected => 0, + ConnectionType::Connected => 1, + ConnectionType::CanConnect => 2, + ConnectionType::CannotConnect => 3, + } + } } /// Kademlia peer. #[derive(Debug, Clone, PartialEq, Eq)] pub struct KademliaPeer { - /// Peer key. - pub(super) key: Key, + /// Peer key. + pub(super) key: Key, - /// Peer ID. - pub(super) peer: PeerId, + /// Peer ID. + pub(super) peer: PeerId, - /// Known addresses of peer. - pub(super) addresses: Vec, + /// Known addresses of peer. + pub(super) addresses: Vec, - /// Connection type. - pub(super) connection: ConnectionType, + /// Connection type. + pub(super) connection: ConnectionType, } impl KademliaPeer { - /// Create new [`KademliaPeer`]. - pub fn new(peer: PeerId, addresses: Vec, connection: ConnectionType) -> Self { - Self { peer, addresses, connection, key: Key::from(peer) } - } + /// Create new [`KademliaPeer`]. + pub fn new(peer: PeerId, addresses: Vec, connection: ConnectionType) -> Self { + Self { + peer, + addresses, + connection, + key: Key::from(peer), + } + } } impl TryFrom<&schema::kademlia::Peer> for KademliaPeer { - type Error = (); - - fn try_from(record: &schema::kademlia::Peer) -> Result { - let peer = PeerId::from_bytes(&record.id).map_err(|_| ())?; - - Ok(KademliaPeer { - key: Key::from(peer), - peer, - addresses: record - .addrs - .iter() - .filter_map(|address| Multiaddr::try_from(address.clone()).ok()) - .collect(), - connection: ConnectionType::try_from(record.connection)?, - }) - } + type Error = (); + + fn try_from(record: &schema::kademlia::Peer) -> Result { + let peer = PeerId::from_bytes(&record.id).map_err(|_| ())?; + + Ok(KademliaPeer { + key: Key::from(peer), + peer, + addresses: record + .addrs + .iter() + .filter_map(|address| Multiaddr::try_from(address.clone()).ok()) + .collect(), + connection: ConnectionType::try_from(record.connection)?, + }) + } } impl From<&KademliaPeer> for schema::kademlia::Peer { - fn from(peer: &KademliaPeer) -> Self { - schema::kademlia::Peer { - id: peer.peer.to_bytes(), - addrs: peer.addresses.iter().map(|address| address.to_vec()).collect(), - connection: peer.connection.into(), - } - } + fn from(peer: &KademliaPeer) -> Self { + schema::kademlia::Peer { + id: peer.peer.to_bytes(), + addrs: peer.addresses.iter().map(|address| address.to_vec()).collect(), + connection: peer.connection.into(), + } + } } diff --git a/src/protocol/libp2p/ping/config.rs b/src/protocol/libp2p/ping/config.rs index e35c58e5..ba507b20 100644 --- a/src/protocol/libp2p/ping/config.rs +++ b/src/protocol/libp2p/ping/config.rs @@ -19,8 +19,8 @@ // DEALINGS IN THE SOFTWARE. use crate::{ - codec::ProtocolCodec, protocol::libp2p::ping::PingEvent, types::protocol::ProtocolName, - DEFAULT_CHANNEL_SIZE, + codec::ProtocolCodec, protocol::libp2p::ping::PingEvent, types::protocol::ProtocolName, + DEFAULT_CHANNEL_SIZE, }; use futures::Stream; @@ -38,78 +38,78 @@ const MAX_FAILURES: usize = 3; /// Ping configuration. pub struct Config { - /// Protocol name. - pub(crate) protocol: ProtocolName, + /// Protocol name. + pub(crate) protocol: ProtocolName, - /// Codec used by the protocol. - pub(crate) codec: ProtocolCodec, + /// Codec used by the protocol. + pub(crate) codec: ProtocolCodec, - /// Maximum failures before the peer is considered unreachable. - pub(crate) max_failures: usize, + /// Maximum failures before the peer is considered unreachable. + pub(crate) max_failures: usize, - /// TX channel for sending events to the user protocol. - pub(crate) tx_event: Sender, + /// TX channel for sending events to the user protocol. + pub(crate) tx_event: Sender, } impl Config { - /// Create new [`Config`] with default values. - /// - /// Returns a config that is given to `Litep2pConfig` and an event stream for [`PingEvent`]s. - pub fn default() -> (Self, Box + Send + Unpin>) { - let (tx_event, rx_event) = channel(DEFAULT_CHANNEL_SIZE); - - ( - Self { - tx_event, - max_failures: MAX_FAILURES, - protocol: ProtocolName::from(PROTOCOL_NAME), - codec: ProtocolCodec::Identity(PING_PAYLOAD_SIZE), - }, - Box::new(ReceiverStream::new(rx_event)), - ) - } + /// Create new [`Config`] with default values. + /// + /// Returns a config that is given to `Litep2pConfig` and an event stream for [`PingEvent`]s. + pub fn default() -> (Self, Box + Send + Unpin>) { + let (tx_event, rx_event) = channel(DEFAULT_CHANNEL_SIZE); + + ( + Self { + tx_event, + max_failures: MAX_FAILURES, + protocol: ProtocolName::from(PROTOCOL_NAME), + codec: ProtocolCodec::Identity(PING_PAYLOAD_SIZE), + }, + Box::new(ReceiverStream::new(rx_event)), + ) + } } /// Ping configuration builder. pub struct ConfigBuilder { - /// Protocol name. - protocol: ProtocolName, + /// Protocol name. + protocol: ProtocolName, - /// Codec used by the protocol. - codec: ProtocolCodec, + /// Codec used by the protocol. + codec: ProtocolCodec, - /// Maximum failures before the peer is considered unreachable. - max_failures: usize, + /// Maximum failures before the peer is considered unreachable. + max_failures: usize, } impl ConfigBuilder { - /// Create new default [`Config`] which can be modified by the user. - pub fn new() -> Self { - Self { - max_failures: MAX_FAILURES, - protocol: ProtocolName::from(PROTOCOL_NAME), - codec: ProtocolCodec::Identity(PING_PAYLOAD_SIZE), - } - } - - /// Set maximum failures the protocol. - pub fn with_max_failure(mut self, max_failures: usize) -> Self { - self.max_failures = max_failures; - self - } - - /// Build [`Config`]. - pub fn build(self) -> (Config, Box + Send + Unpin>) { - let (tx_event, rx_event) = channel(DEFAULT_CHANNEL_SIZE); - - ( - Config { - tx_event, - max_failures: self.max_failures, - protocol: self.protocol, - codec: self.codec, - }, - Box::new(ReceiverStream::new(rx_event)), - ) - } + /// Create new default [`Config`] which can be modified by the user. + pub fn new() -> Self { + Self { + max_failures: MAX_FAILURES, + protocol: ProtocolName::from(PROTOCOL_NAME), + codec: ProtocolCodec::Identity(PING_PAYLOAD_SIZE), + } + } + + /// Set maximum failures the protocol. + pub fn with_max_failure(mut self, max_failures: usize) -> Self { + self.max_failures = max_failures; + self + } + + /// Build [`Config`]. + pub fn build(self) -> (Config, Box + Send + Unpin>) { + let (tx_event, rx_event) = channel(DEFAULT_CHANNEL_SIZE); + + ( + Config { + tx_event, + max_failures: self.max_failures, + protocol: self.protocol, + codec: self.codec, + }, + Box::new(ReceiverStream::new(rx_event)), + ) + } } diff --git a/src/protocol/libp2p/ping/mod.rs b/src/protocol/libp2p/ping/mod.rs index f12dedf3..4700030d 100644 --- a/src/protocol/libp2p/ping/mod.rs +++ b/src/protocol/libp2p/ping/mod.rs @@ -21,19 +21,19 @@ //! [`/ipfs/ping/1.0.0`](https://github.com/libp2p/specs/blob/master/ping/ping.md) implementation. use crate::{ - error::{Error, SubstreamError}, - protocol::{Direction, TransportEvent, TransportService}, - substream::Substream, - types::SubstreamId, - PeerId, + error::{Error, SubstreamError}, + protocol::{Direction, TransportEvent, TransportService}, + substream::Substream, + types::SubstreamId, + PeerId, }; use futures::{future::BoxFuture, stream::FuturesUnordered, StreamExt}; use tokio::sync::mpsc::Sender; use std::{ - collections::{HashMap, HashSet}, - time::{Duration, Instant}, + collections::{HashMap, HashSet}, + time::{Duration, Instant}, }; pub use config::{Config, ConfigBuilder}; @@ -48,179 +48,179 @@ const LOG_TARGET: &str = "litep2p::ipfs::ping"; /// Events emitted by the ping protocol. #[derive(Debug)] pub enum PingEvent { - /// Ping time with remote peer. - Ping { - /// Peer ID. - peer: PeerId, - - /// Measured ping time with the peer. - ping: Duration, - }, + /// Ping time with remote peer. + Ping { + /// Peer ID. + peer: PeerId, + + /// Measured ping time with the peer. + ping: Duration, + }, } /// Ping protocol. pub(crate) struct Ping { - /// Maximum failures before the peer is considered unreachable. - _max_failures: usize, + /// Maximum failures before the peer is considered unreachable. + _max_failures: usize, - // Connection service. - service: TransportService, + // Connection service. + service: TransportService, - /// TX channel for sending events to the user protocol. - tx: Sender, + /// TX channel for sending events to the user protocol. + tx: Sender, - /// Connected peers. - peers: HashSet, + /// Connected peers. + peers: HashSet, - /// Pending outbound substreams. - pending_opens: HashMap, + /// Pending outbound substreams. + pending_opens: HashMap, - /// Pending outbound substreams. - pending_outbound: FuturesUnordered>>, + /// Pending outbound substreams. + pending_outbound: FuturesUnordered>>, - /// Pending inbound substreams. - pending_inbound: FuturesUnordered>>, + /// Pending inbound substreams. + pending_inbound: FuturesUnordered>>, } impl Ping { - /// Create new [`Ping`] protocol. - pub fn new(service: TransportService, config: Config) -> Self { - Self { - service, - tx: config.tx_event, - peers: HashSet::new(), - pending_opens: HashMap::new(), - pending_outbound: FuturesUnordered::new(), - pending_inbound: FuturesUnordered::new(), - _max_failures: config.max_failures, - } - } - - /// Connection established to remote peer. - fn on_connection_established(&mut self, peer: PeerId) -> crate::Result<()> { - tracing::trace!(target: LOG_TARGET, ?peer, "connection established"); - - let substream_id = self.service.open_substream(peer)?; - self.pending_opens.insert(substream_id, peer); - self.peers.insert(peer); - - Ok(()) - } - - /// Connection closed to remote peer. - fn on_connection_closed(&mut self, peer: PeerId) { - tracing::trace!(target: LOG_TARGET, ?peer, "connection closed"); - - self.peers.remove(&peer); - } - - /// Handle outbound substream. - fn on_outbound_substream( - &mut self, - peer: PeerId, - substream_id: SubstreamId, - mut substream: Substream, - ) { - tracing::trace!(target: LOG_TARGET, ?peer, "handle outbound substream"); - - self.pending_outbound.push(Box::pin(async move { - let future = async move { - // TODO: generate random payload and verify it - let _ = substream.send_framed(vec![0u8; 32].into()).await?; - let now = Instant::now(); - let _ = substream.next().await.ok_or(Error::SubstreamError( - SubstreamError::ReadFailure(Some(substream_id)), - ))?; - let _ = substream.close().await; - - Ok(now.elapsed()) - }; - - match tokio::time::timeout(Duration::from_secs(10), future).await { - Err(_) => return Err(Error::Timeout), - Ok(Err(error)) => return Err(error), - Ok(Ok(elapsed)) => Ok((peer, elapsed)), - } - })); - } - - /// Substream opened to remote peer. - fn on_inbound_substream(&mut self, peer: PeerId, mut substream: Substream) { - tracing::trace!(target: LOG_TARGET, ?peer, "handle inbound substream"); - - self.pending_inbound.push(Box::pin(async move { - let future = async move { - let payload = substream - .next() - .await - .ok_or(Error::SubstreamError(SubstreamError::ReadFailure(None)))??; - substream.send_framed(payload.freeze()).await?; - let _ = substream.next().await.map(|_| ()); - - Ok(()) - }; - - match tokio::time::timeout(Duration::from_secs(10), future).await { - Err(_) => return Err(Error::Timeout), - Ok(Err(error)) => return Err(error), - Ok(Ok(())) => Ok(()), - } - })); - } - - /// Start [`Ping`] event loop. - pub async fn run(mut self) { - tracing::debug!(target: LOG_TARGET, "starting ping event loop"); - - loop { - tokio::select! { - event = self.service.next() => match event { - Some(TransportEvent::ConnectionEstablished { peer, .. }) => { - let _ = self.on_connection_established(peer); - } - Some(TransportEvent::ConnectionClosed { peer }) => { - self.on_connection_closed(peer); - } - Some(TransportEvent::SubstreamOpened { - peer, - substream, - direction, - .. - }) => match direction { - Direction::Inbound => { - self.on_inbound_substream(peer, substream); - } - Direction::Outbound(substream_id) => { - match self.pending_opens.remove(&substream_id) { - Some(stored_peer) => { - debug_assert!(peer == stored_peer); - self.on_outbound_substream(peer, substream_id, substream); - } - None => { - todo!("substream {substream_id:?} does not exist"); - } - } - } - }, - Some(_) => {} - None => return, - }, - _event = self.pending_inbound.next(), if !self.pending_inbound.is_empty() => {} - event = self.pending_outbound.next(), if !self.pending_outbound.is_empty() => { - match event { - Some(Ok((peer, elapsed))) => { - let _ = self - .tx - .send(PingEvent::Ping { - peer, - ping: elapsed, - }) - .await; - } - event => tracing::debug!(target: LOG_TARGET, "failed to handle ping for an outbound peer: {event:?}"), - } - } - } - } - } + /// Create new [`Ping`] protocol. + pub fn new(service: TransportService, config: Config) -> Self { + Self { + service, + tx: config.tx_event, + peers: HashSet::new(), + pending_opens: HashMap::new(), + pending_outbound: FuturesUnordered::new(), + pending_inbound: FuturesUnordered::new(), + _max_failures: config.max_failures, + } + } + + /// Connection established to remote peer. + fn on_connection_established(&mut self, peer: PeerId) -> crate::Result<()> { + tracing::trace!(target: LOG_TARGET, ?peer, "connection established"); + + let substream_id = self.service.open_substream(peer)?; + self.pending_opens.insert(substream_id, peer); + self.peers.insert(peer); + + Ok(()) + } + + /// Connection closed to remote peer. + fn on_connection_closed(&mut self, peer: PeerId) { + tracing::trace!(target: LOG_TARGET, ?peer, "connection closed"); + + self.peers.remove(&peer); + } + + /// Handle outbound substream. + fn on_outbound_substream( + &mut self, + peer: PeerId, + substream_id: SubstreamId, + mut substream: Substream, + ) { + tracing::trace!(target: LOG_TARGET, ?peer, "handle outbound substream"); + + self.pending_outbound.push(Box::pin(async move { + let future = async move { + // TODO: generate random payload and verify it + let _ = substream.send_framed(vec![0u8; 32].into()).await?; + let now = Instant::now(); + let _ = substream.next().await.ok_or(Error::SubstreamError( + SubstreamError::ReadFailure(Some(substream_id)), + ))?; + let _ = substream.close().await; + + Ok(now.elapsed()) + }; + + match tokio::time::timeout(Duration::from_secs(10), future).await { + Err(_) => return Err(Error::Timeout), + Ok(Err(error)) => return Err(error), + Ok(Ok(elapsed)) => Ok((peer, elapsed)), + } + })); + } + + /// Substream opened to remote peer. + fn on_inbound_substream(&mut self, peer: PeerId, mut substream: Substream) { + tracing::trace!(target: LOG_TARGET, ?peer, "handle inbound substream"); + + self.pending_inbound.push(Box::pin(async move { + let future = async move { + let payload = substream + .next() + .await + .ok_or(Error::SubstreamError(SubstreamError::ReadFailure(None)))??; + substream.send_framed(payload.freeze()).await?; + let _ = substream.next().await.map(|_| ()); + + Ok(()) + }; + + match tokio::time::timeout(Duration::from_secs(10), future).await { + Err(_) => return Err(Error::Timeout), + Ok(Err(error)) => return Err(error), + Ok(Ok(())) => Ok(()), + } + })); + } + + /// Start [`Ping`] event loop. + pub async fn run(mut self) { + tracing::debug!(target: LOG_TARGET, "starting ping event loop"); + + loop { + tokio::select! { + event = self.service.next() => match event { + Some(TransportEvent::ConnectionEstablished { peer, .. }) => { + let _ = self.on_connection_established(peer); + } + Some(TransportEvent::ConnectionClosed { peer }) => { + self.on_connection_closed(peer); + } + Some(TransportEvent::SubstreamOpened { + peer, + substream, + direction, + .. + }) => match direction { + Direction::Inbound => { + self.on_inbound_substream(peer, substream); + } + Direction::Outbound(substream_id) => { + match self.pending_opens.remove(&substream_id) { + Some(stored_peer) => { + debug_assert!(peer == stored_peer); + self.on_outbound_substream(peer, substream_id, substream); + } + None => { + todo!("substream {substream_id:?} does not exist"); + } + } + } + }, + Some(_) => {} + None => return, + }, + _event = self.pending_inbound.next(), if !self.pending_inbound.is_empty() => {} + event = self.pending_outbound.next(), if !self.pending_outbound.is_empty() => { + match event { + Some(Ok((peer, elapsed))) => { + let _ = self + .tx + .send(PingEvent::Ping { + peer, + ping: elapsed, + }) + .await; + } + event => tracing::debug!(target: LOG_TARGET, "failed to handle ping for an outbound peer: {event:?}"), + } + } + } + } + } } diff --git a/src/protocol/mdns.rs b/src/protocol/mdns.rs index 39157e34..5e7270ec 100644 --- a/src/protocol/mdns.rs +++ b/src/protocol/mdns.rs @@ -27,22 +27,22 @@ use futures::Stream; use multiaddr::Multiaddr; use rand::{distributions::Alphanumeric, Rng}; use simple_dns::{ - rdata::{RData, PTR, TXT}, - Name, Packet, PacketFlag, Question, ResourceRecord, CLASS, QCLASS, QTYPE, TYPE, + rdata::{RData, PTR, TXT}, + Name, Packet, PacketFlag, Question, ResourceRecord, CLASS, QCLASS, QTYPE, TYPE, }; use socket2::{Domain, Protocol, Socket, Type}; use tokio::{ - net::UdpSocket, - sync::mpsc::{channel, Sender}, + net::UdpSocket, + sync::mpsc::{channel, Sender}, }; use tokio_stream::wrappers::ReceiverStream; use std::{ - collections::HashSet, - net, - net::{IpAddr, Ipv4Addr, SocketAddr}, - sync::Arc, - time::Duration, + collections::HashSet, + net, + net::{IpAddr, Ipv4Addr, SocketAddr}, + sync::Arc, + time::Duration, }; /// Logging target for the file. @@ -60,384 +60,387 @@ const SERVICE_NAME: &str = "_p2p._udp.local"; /// Events emitted by mDNS. // #[derive(Debug, Clone)] pub enum MdnsEvent { - /// One or more addresses discovered. - Discovered(Vec), + /// One or more addresses discovered. + Discovered(Vec), } /// mDNS configuration. // #[derive(Debug)] pub struct Config { - /// How often the network should be queried for new peers. - query_interval: Duration, + /// How often the network should be queried for new peers. + query_interval: Duration, - /// TX channel for sending mDNS events to user. - tx: Sender, + /// TX channel for sending mDNS events to user. + tx: Sender, } impl Config { - /// Create new [`Config`]. - /// - /// Return the configuration and an event stream for receiving [`MdnsEvent`]s. - pub fn new( - query_interval: Duration, - ) -> (Self, Box + Send + Unpin>) { - let (tx, rx) = channel(DEFAULT_CHANNEL_SIZE); - (Self { query_interval, tx }, Box::new(ReceiverStream::new(rx))) - } + /// Create new [`Config`]. + /// + /// Return the configuration and an event stream for receiving [`MdnsEvent`]s. + pub fn new( + query_interval: Duration, + ) -> (Self, Box + Send + Unpin>) { + let (tx, rx) = channel(DEFAULT_CHANNEL_SIZE); + ( + Self { query_interval, tx }, + Box::new(ReceiverStream::new(rx)), + ) + } } /// Main mDNS object. pub(crate) struct Mdns { - /// UDP socket for multicast requests/responses. - socket: UdpSocket, + /// UDP socket for multicast requests/responses. + socket: UdpSocket, - /// Query interval. - query_interval: Duration, + /// Query interval. + query_interval: Duration, - /// TX channel for sending events to user. - event_tx: Sender, + /// TX channel for sending events to user. + event_tx: Sender, - /// Handle to `TransportManager`. - _transport_handle: TransportManagerHandle, + /// Handle to `TransportManager`. + _transport_handle: TransportManagerHandle, - // Username. - username: String, + // Username. + username: String, - /// Next query ID. - next_query_id: u16, + /// Next query ID. + next_query_id: u16, - /// Buffer for incoming messages. - receive_buffer: Vec, + /// Buffer for incoming messages. + receive_buffer: Vec, - /// Listen addresses. - listen_addresses: Vec>, + /// Listen addresses. + listen_addresses: Vec>, - /// Discovered addresses. - discovered: HashSet, + /// Discovered addresses. + discovered: HashSet, } impl Mdns { - /// Create new [`Mdns`]. - pub(crate) fn new( - _transport_handle: TransportManagerHandle, - config: Config, - listen_addresses: Vec, - ) -> crate::Result { - let socket = Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP))?; - socket.set_reuse_address(true)?; - #[cfg(unix)] - socket.set_reuse_port(true)?; - socket.bind( - &SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), IPV4_MULTICAST_PORT).into(), - )?; - socket.set_multicast_loop_v4(true)?; - socket.set_multicast_ttl_v4(255)?; - socket.join_multicast_v4(&IPV4_MULTICAST_ADDRESS, &Ipv4Addr::UNSPECIFIED)?; - socket.set_nonblocking(true)?; - - Ok(Self { - _transport_handle, - event_tx: config.tx, - next_query_id: 1337u16, - discovered: HashSet::new(), - query_interval: config.query_interval, - receive_buffer: vec![0u8; 4096], - username: rand::thread_rng() - .sample_iter(&Alphanumeric) - .take(32) - .map(char::from) - .collect(), - socket: UdpSocket::from_std(net::UdpSocket::from(socket))?, - listen_addresses: listen_addresses - .into_iter() - .map(|address| format!("dnsaddr={address}").into()) - .collect(), - }) - } - - /// Get next query ID. - fn next_query_id(&mut self) -> u16 { - let query_id = self.next_query_id; - self.next_query_id += 1; - - query_id - } - - /// Send mDNS query on the network. - async fn on_outbound_request(&mut self) -> crate::Result<()> { - tracing::debug!(target: LOG_TARGET, "send outbound query"); - - let mut packet = Packet::new_query(self.next_query_id()); - - packet.questions.push(Question { - qname: Name::new_unchecked(SERVICE_NAME), - qtype: QTYPE::TYPE(TYPE::PTR), - qclass: QCLASS::CLASS(CLASS::IN), - unicast_response: false, - }); - - self.socket - .send_to( - &packet.build_bytes_vec().expect("valid packet"), - (IPV4_MULTICAST_ADDRESS, IPV4_MULTICAST_PORT), - ) - .await - .map(|_| ()) - .map_err(From::from) - } - - /// Handle inbound query. - fn on_inbound_request(&self, packet: Packet) -> Option> { - tracing::debug!(target: LOG_TARGET, ?packet, "handle inbound request"); - - let mut packet = Packet::new_reply(packet.id()); - let srv_name = Name::new_unchecked(SERVICE_NAME); - - packet.answers.push(ResourceRecord::new( - srv_name.clone(), - CLASS::IN, - 360, - RData::PTR(PTR(Name::new_unchecked(&self.username))), - )); - - for address in &self.listen_addresses { - let mut record = TXT::new(); - record.add_string(address).expect("valid string"); - - packet.additional_records.push(ResourceRecord { - name: Name::new_unchecked(&self.username), - class: CLASS::IN, - ttl: 360, - rdata: RData::TXT(record), - cache_flush: false, - }); - } - - Some(packet.build_bytes_vec().expect("valid packet")) - } - - /// Handle inbound response. - fn on_inbound_response(&self, packet: Packet) -> Vec { - tracing::debug!(target: LOG_TARGET, "handle inbound response"); - - let names = packet - .answers - .iter() - .filter_map(|answer| { - if answer.name != Name::new_unchecked(SERVICE_NAME) { - return None; - } - - match answer.rdata { - RData::PTR(PTR(ref name)) if name != &Name::new_unchecked(&self.username) => - Some(name), - _ => None, - } - }) - .collect::>(); - - let name = match names.len() { - 0 => return Vec::new(), - _ => { - tracing::debug!( - target: LOG_TARGET, - ?names, - "response name" - ); - - names[0] - }, - }; - - packet - .additional_records - .iter() - .flat_map(|record| { - if &record.name != name { - return vec![]; - } - - // TODO: `filter_map` is not necessary as there's at most one entry - match &record.rdata { - RData::TXT(text) => text - .attributes() - .iter() - .filter_map(|(_, address)| { - address.as_ref().map_or(None, |inner| inner.parse().ok()) - }) - .collect(), - _ => vec![], - } - }) - .collect() - } - - /// Event loop for [`Mdns`]. - pub(crate) async fn start(mut self) -> crate::Result<()> { - tracing::debug!(target: LOG_TARGET, "starting mdns event loop"); - - // before starting the loop, make an initial query to the network - // - // bail early if the socket is not working - self.on_outbound_request().await?; - - loop { - tokio::select! { - _ = tokio::time::sleep(self.query_interval) => { - tracing::trace!(target: LOG_TARGET, "timeout expired"); - - if let Err(error) = self.on_outbound_request().await { - tracing::error!(target: LOG_TARGET, ?error, "failed to send mdns query"); - return Err(error); - } - } - result = self.socket.recv_from(&mut self.receive_buffer) => match result { - Ok((nread, address)) => match Packet::parse(&self.receive_buffer[..nread]) { - Ok(packet) => match packet.has_flags(PacketFlag::RESPONSE) { - true => { - let to_forward = self.on_inbound_response(packet).into_iter().filter_map(|address| { - self.discovered.insert(address.clone()).then_some(address) - }) - .collect::>(); - - if !to_forward.is_empty() { - let _ = self.event_tx.send(MdnsEvent::Discovered(to_forward)).await; - } - } - false => if let Some(response) = self.on_inbound_request(packet) { - self.socket - .send_to(&response, (IPV4_MULTICAST_ADDRESS, IPV4_MULTICAST_PORT)) - .await?; - } - } - Err(error) => tracing::debug!( - target: LOG_TARGET, - ?address, - ?error, - ?nread, - "failed to parse mdns packet" - ), - } - Err(error) => { - tracing::error!(target: LOG_TARGET, ?error, "failed to read from socket"); - return Err(Error::from(error)); - } - }, - } - } - } + /// Create new [`Mdns`]. + pub(crate) fn new( + _transport_handle: TransportManagerHandle, + config: Config, + listen_addresses: Vec, + ) -> crate::Result { + let socket = Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP))?; + socket.set_reuse_address(true)?; + #[cfg(unix)] + socket.set_reuse_port(true)?; + socket.bind( + &SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), IPV4_MULTICAST_PORT).into(), + )?; + socket.set_multicast_loop_v4(true)?; + socket.set_multicast_ttl_v4(255)?; + socket.join_multicast_v4(&IPV4_MULTICAST_ADDRESS, &Ipv4Addr::UNSPECIFIED)?; + socket.set_nonblocking(true)?; + + Ok(Self { + _transport_handle, + event_tx: config.tx, + next_query_id: 1337u16, + discovered: HashSet::new(), + query_interval: config.query_interval, + receive_buffer: vec![0u8; 4096], + username: rand::thread_rng() + .sample_iter(&Alphanumeric) + .take(32) + .map(char::from) + .collect(), + socket: UdpSocket::from_std(net::UdpSocket::from(socket))?, + listen_addresses: listen_addresses + .into_iter() + .map(|address| format!("dnsaddr={address}").into()) + .collect(), + }) + } + + /// Get next query ID. + fn next_query_id(&mut self) -> u16 { + let query_id = self.next_query_id; + self.next_query_id += 1; + + query_id + } + + /// Send mDNS query on the network. + async fn on_outbound_request(&mut self) -> crate::Result<()> { + tracing::debug!(target: LOG_TARGET, "send outbound query"); + + let mut packet = Packet::new_query(self.next_query_id()); + + packet.questions.push(Question { + qname: Name::new_unchecked(SERVICE_NAME), + qtype: QTYPE::TYPE(TYPE::PTR), + qclass: QCLASS::CLASS(CLASS::IN), + unicast_response: false, + }); + + self.socket + .send_to( + &packet.build_bytes_vec().expect("valid packet"), + (IPV4_MULTICAST_ADDRESS, IPV4_MULTICAST_PORT), + ) + .await + .map(|_| ()) + .map_err(From::from) + } + + /// Handle inbound query. + fn on_inbound_request(&self, packet: Packet) -> Option> { + tracing::debug!(target: LOG_TARGET, ?packet, "handle inbound request"); + + let mut packet = Packet::new_reply(packet.id()); + let srv_name = Name::new_unchecked(SERVICE_NAME); + + packet.answers.push(ResourceRecord::new( + srv_name.clone(), + CLASS::IN, + 360, + RData::PTR(PTR(Name::new_unchecked(&self.username))), + )); + + for address in &self.listen_addresses { + let mut record = TXT::new(); + record.add_string(address).expect("valid string"); + + packet.additional_records.push(ResourceRecord { + name: Name::new_unchecked(&self.username), + class: CLASS::IN, + ttl: 360, + rdata: RData::TXT(record), + cache_flush: false, + }); + } + + Some(packet.build_bytes_vec().expect("valid packet")) + } + + /// Handle inbound response. + fn on_inbound_response(&self, packet: Packet) -> Vec { + tracing::debug!(target: LOG_TARGET, "handle inbound response"); + + let names = packet + .answers + .iter() + .filter_map(|answer| { + if answer.name != Name::new_unchecked(SERVICE_NAME) { + return None; + } + + match answer.rdata { + RData::PTR(PTR(ref name)) if name != &Name::new_unchecked(&self.username) => + Some(name), + _ => None, + } + }) + .collect::>(); + + let name = match names.len() { + 0 => return Vec::new(), + _ => { + tracing::debug!( + target: LOG_TARGET, + ?names, + "response name" + ); + + names[0] + } + }; + + packet + .additional_records + .iter() + .flat_map(|record| { + if &record.name != name { + return vec![]; + } + + // TODO: `filter_map` is not necessary as there's at most one entry + match &record.rdata { + RData::TXT(text) => text + .attributes() + .iter() + .filter_map(|(_, address)| { + address.as_ref().map_or(None, |inner| inner.parse().ok()) + }) + .collect(), + _ => vec![], + } + }) + .collect() + } + + /// Event loop for [`Mdns`]. + pub(crate) async fn start(mut self) -> crate::Result<()> { + tracing::debug!(target: LOG_TARGET, "starting mdns event loop"); + + // before starting the loop, make an initial query to the network + // + // bail early if the socket is not working + self.on_outbound_request().await?; + + loop { + tokio::select! { + _ = tokio::time::sleep(self.query_interval) => { + tracing::trace!(target: LOG_TARGET, "timeout expired"); + + if let Err(error) = self.on_outbound_request().await { + tracing::error!(target: LOG_TARGET, ?error, "failed to send mdns query"); + return Err(error); + } + } + result = self.socket.recv_from(&mut self.receive_buffer) => match result { + Ok((nread, address)) => match Packet::parse(&self.receive_buffer[..nread]) { + Ok(packet) => match packet.has_flags(PacketFlag::RESPONSE) { + true => { + let to_forward = self.on_inbound_response(packet).into_iter().filter_map(|address| { + self.discovered.insert(address.clone()).then_some(address) + }) + .collect::>(); + + if !to_forward.is_empty() { + let _ = self.event_tx.send(MdnsEvent::Discovered(to_forward)).await; + } + } + false => if let Some(response) = self.on_inbound_request(packet) { + self.socket + .send_to(&response, (IPV4_MULTICAST_ADDRESS, IPV4_MULTICAST_PORT)) + .await?; + } + } + Err(error) => tracing::debug!( + target: LOG_TARGET, + ?address, + ?error, + ?nread, + "failed to parse mdns packet" + ), + } + Err(error) => { + tracing::error!(target: LOG_TARGET, ?error, "failed to read from socket"); + return Err(Error::from(error)); + } + }, + } + } + } } #[cfg(test)] mod tests { - use super::*; - use crate::{crypto::ed25519::Keypair, transport::manager::TransportManager, BandwidthSink}; - use futures::StreamExt; - use multiaddr::Protocol; - - #[tokio::test] - async fn mdns_works() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (config1, mut stream1) = Config::new(Duration::from_secs(5)); - let (_manager1, handle1) = TransportManager::new( - Keypair::generate(), - HashSet::new(), - BandwidthSink::new(), - 8usize, - ); - - let mdns1 = Mdns::new( - handle1, - config1, - vec![ - "/ip6/::1/tcp/8888/p2p/12D3KooWNP463TyS3vUpmekjjZ2dg7xy1WHNMM7MqfsMevMTaaaa" - .parse() - .unwrap(), - "/ip4/127.0.0.1/tcp/8888/p2p/12D3KooWNP463TyS3vUpmekjjZ2dg7xy1WHNMM7MqfsMevMTaaaa" - .parse() - .unwrap(), - ], - ) - .unwrap(); - - let (config2, mut stream2) = Config::new(Duration::from_secs(5)); - let (_manager1, handle2) = TransportManager::new( - Keypair::generate(), - HashSet::new(), - BandwidthSink::new(), - 8usize, - ); - - let mdns2 = Mdns::new( - handle2, - config2, - vec![ - "/ip6/::1/tcp/9999/p2p/12D3KooWNP463TyS3vUpmekjjZ2dg7xy1WHNMM7MqfsMevMTbbbb" - .parse() - .unwrap(), - "/ip4/127.0.0.1/tcp/9999/p2p/12D3KooWNP463TyS3vUpmekjjZ2dg7xy1WHNMM7MqfsMevMTbbbb" - .parse() - .unwrap(), - ], - ) - .unwrap(); - - tokio::spawn(mdns1.start()); - tokio::spawn(mdns2.start()); - - let mut peer1_discovered = false; - let mut peer2_discovered = false; - - while !peer1_discovered && !peer2_discovered { - tokio::select! { - event = stream1.next() => match event.unwrap() { - MdnsEvent::Discovered(addrs) => { - if addrs.len() == 2 { - let mut iter = addrs[0].iter(); - - if !std::matches!(iter.next(), Some(Protocol::Ip4(_) | Protocol::Ip6(_))) { - continue - } - - match iter.next() { - Some(Protocol::Tcp(port)) => { - if port != 9999 { - continue - } - } - _ => continue, - } - - peer1_discovered = true; - } - } - }, - event = stream2.next() => match event.unwrap() { - MdnsEvent::Discovered(addrs) => { - if addrs.len() == 2 { - let mut iter = addrs[0].iter(); - - if !std::matches!(iter.next(), Some(Protocol::Ip4(_) | Protocol::Ip6(_))) { - continue - } - - match iter.next() { - Some(Protocol::Tcp(port)) => { - if port != 8888 { - continue - } - } - _ => continue, - } - - peer2_discovered = true; - } - } - } - } - } - } + use super::*; + use crate::{crypto::ed25519::Keypair, transport::manager::TransportManager, BandwidthSink}; + use futures::StreamExt; + use multiaddr::Protocol; + + #[tokio::test] + async fn mdns_works() { + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (config1, mut stream1) = Config::new(Duration::from_secs(5)); + let (_manager1, handle1) = TransportManager::new( + Keypair::generate(), + HashSet::new(), + BandwidthSink::new(), + 8usize, + ); + + let mdns1 = Mdns::new( + handle1, + config1, + vec![ + "/ip6/::1/tcp/8888/p2p/12D3KooWNP463TyS3vUpmekjjZ2dg7xy1WHNMM7MqfsMevMTaaaa" + .parse() + .unwrap(), + "/ip4/127.0.0.1/tcp/8888/p2p/12D3KooWNP463TyS3vUpmekjjZ2dg7xy1WHNMM7MqfsMevMTaaaa" + .parse() + .unwrap(), + ], + ) + .unwrap(); + + let (config2, mut stream2) = Config::new(Duration::from_secs(5)); + let (_manager1, handle2) = TransportManager::new( + Keypair::generate(), + HashSet::new(), + BandwidthSink::new(), + 8usize, + ); + + let mdns2 = Mdns::new( + handle2, + config2, + vec![ + "/ip6/::1/tcp/9999/p2p/12D3KooWNP463TyS3vUpmekjjZ2dg7xy1WHNMM7MqfsMevMTbbbb" + .parse() + .unwrap(), + "/ip4/127.0.0.1/tcp/9999/p2p/12D3KooWNP463TyS3vUpmekjjZ2dg7xy1WHNMM7MqfsMevMTbbbb" + .parse() + .unwrap(), + ], + ) + .unwrap(); + + tokio::spawn(mdns1.start()); + tokio::spawn(mdns2.start()); + + let mut peer1_discovered = false; + let mut peer2_discovered = false; + + while !peer1_discovered && !peer2_discovered { + tokio::select! { + event = stream1.next() => match event.unwrap() { + MdnsEvent::Discovered(addrs) => { + if addrs.len() == 2 { + let mut iter = addrs[0].iter(); + + if !std::matches!(iter.next(), Some(Protocol::Ip4(_) | Protocol::Ip6(_))) { + continue + } + + match iter.next() { + Some(Protocol::Tcp(port)) => { + if port != 9999 { + continue + } + } + _ => continue, + } + + peer1_discovered = true; + } + } + }, + event = stream2.next() => match event.unwrap() { + MdnsEvent::Discovered(addrs) => { + if addrs.len() == 2 { + let mut iter = addrs[0].iter(); + + if !std::matches!(iter.next(), Some(Protocol::Ip4(_) | Protocol::Ip6(_))) { + continue + } + + match iter.next() { + Some(Protocol::Tcp(port)) => { + if port != 8888 { + continue + } + } + _ => continue, + } + + peer2_discovered = true; + } + } + } + } + } + } } diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index 08f83521..94043904 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -21,12 +21,12 @@ //! Protocol-related defines. use crate::{ - codec::ProtocolCodec, - error::Error, - substream::Substream, - transport::Endpoint, - types::{protocol::ProtocolName, SubstreamId}, - PeerId, + codec::ProtocolCodec, + error::Error, + substream::Substream, + transport::Endpoint, + types::{protocol::ProtocolName, SubstreamId}, + PeerId, }; use multiaddr::Multiaddr; @@ -50,94 +50,94 @@ mod transport_service; /// Substream direction. #[derive(Debug, Copy, Clone)] pub enum Direction { - /// Substream was opened by the remote peer. - Inbound, + /// Substream was opened by the remote peer. + Inbound, - /// Substream was opened by the local peer. - Outbound(SubstreamId), + /// Substream was opened by the local peer. + Outbound(SubstreamId), } /// Events emitted by one of the installed transports to protocol(s). #[derive(Debug)] pub enum TransportEvent { - /// Connection established to `peer`. - ConnectionEstablished { - /// Peer ID. - peer: PeerId, - - /// Endpoint. - endpoint: Endpoint, - }, - - /// Connection closed to peer. - ConnectionClosed { - /// Peer ID. - peer: PeerId, - }, - - /// Failed to dial peer. - /// - /// This is reported to that protocol which initiated the connection. - DialFailure { - /// Peer ID. - peer: PeerId, - - /// Dialed address. - address: Multiaddr, - }, - - /// Substream opened for `peer`. - SubstreamOpened { - /// Peer ID. - peer: PeerId, - - /// Protocol name. - /// - /// One protocol handler may handle multiple sub-protocols (such as `/ipfs/identify/1.0.0` - /// and `/ipfs/identify/push/1.0.0`) or it may have aliases which should be handled by - /// the same protocol handler. When the substream is sent from transport to the protocol - /// handler, the protocol name that was used to negotiate the substream is also sent so - /// the protocol can handle the substream appropriately. - protocol: ProtocolName, - - /// Fallback protocol. - fallback: Option, - - /// Substream direction. - /// - /// Informs the protocol whether the substream is inbound (opened by the remote node) - /// or outbound (opened by the local node). This allows the protocol to distinguish - /// between the two types of substreams and execute correct code for the substream. - /// - /// Outbound substreams also contain the substream ID which allows the protocol to - /// distinguish between different outbound substreams. - direction: Direction, - - /// Substream. - substream: Substream, - }, - - /// Failed to open substream. - /// - /// Substream open failures are reported only for outbound substreams. - SubstreamOpenFailure { - /// Substream ID. - substream: SubstreamId, - - /// Error that occurred when the substream was being opened. - error: Error, - }, + /// Connection established to `peer`. + ConnectionEstablished { + /// Peer ID. + peer: PeerId, + + /// Endpoint. + endpoint: Endpoint, + }, + + /// Connection closed to peer. + ConnectionClosed { + /// Peer ID. + peer: PeerId, + }, + + /// Failed to dial peer. + /// + /// This is reported to that protocol which initiated the connection. + DialFailure { + /// Peer ID. + peer: PeerId, + + /// Dialed address. + address: Multiaddr, + }, + + /// Substream opened for `peer`. + SubstreamOpened { + /// Peer ID. + peer: PeerId, + + /// Protocol name. + /// + /// One protocol handler may handle multiple sub-protocols (such as `/ipfs/identify/1.0.0` + /// and `/ipfs/identify/push/1.0.0`) or it may have aliases which should be handled by + /// the same protocol handler. When the substream is sent from transport to the protocol + /// handler, the protocol name that was used to negotiate the substream is also sent so + /// the protocol can handle the substream appropriately. + protocol: ProtocolName, + + /// Fallback protocol. + fallback: Option, + + /// Substream direction. + /// + /// Informs the protocol whether the substream is inbound (opened by the remote node) + /// or outbound (opened by the local node). This allows the protocol to distinguish + /// between the two types of substreams and execute correct code for the substream. + /// + /// Outbound substreams also contain the substream ID which allows the protocol to + /// distinguish between different outbound substreams. + direction: Direction, + + /// Substream. + substream: Substream, + }, + + /// Failed to open substream. + /// + /// Substream open failures are reported only for outbound substreams. + SubstreamOpenFailure { + /// Substream ID. + substream: SubstreamId, + + /// Error that occurred when the substream was being opened. + error: Error, + }, } /// Trait defining the interface for a user protocol. #[async_trait::async_trait] pub trait UserProtocol: Send { - /// Get user protocol name. - fn protocol(&self) -> ProtocolName; + /// Get user protocol name. + fn protocol(&self) -> ProtocolName; - /// Get user protocol codec. - fn codec(&self) -> ProtocolCodec; + /// Get user protocol codec. + fn codec(&self) -> ProtocolCodec; - /// Start the the user protocol event loop. - async fn run(self: Box, service: TransportService) -> crate::Result<()>; + /// Start the the user protocol event loop. + async fn run(self: Box, service: TransportService) -> crate::Result<()>; } diff --git a/src/protocol/notification/config.rs b/src/protocol/notification/config.rs index 36109150..b9dedc14 100644 --- a/src/protocol/notification/config.rs +++ b/src/protocol/notification/config.rs @@ -19,15 +19,15 @@ // DEALINGS IN THE SOFTWARE. use crate::{ - codec::ProtocolCodec, - protocol::notification::{ - handle::NotificationHandle, - types::{ - InnerNotificationEvent, NotificationCommand, ASYNC_CHANNEL_SIZE, SYNC_CHANNEL_SIZE, - }, - }, - types::protocol::ProtocolName, - PeerId, DEFAULT_CHANNEL_SIZE, + codec::ProtocolCodec, + protocol::notification::{ + handle::NotificationHandle, + types::{ + InnerNotificationEvent, NotificationCommand, ASYNC_CHANNEL_SIZE, SYNC_CHANNEL_SIZE, + }, + }, + types::protocol::ProtocolName, + PeerId, DEFAULT_CHANNEL_SIZE, }; use bytes::BytesMut; @@ -39,219 +39,219 @@ use std::sync::Arc; /// Notification configuration. #[derive(Debug)] pub struct Config { - /// Protocol name. - pub(crate) protocol_name: ProtocolName, + /// Protocol name. + pub(crate) protocol_name: ProtocolName, - /// Protocol codec. - pub(crate) codec: ProtocolCodec, + /// Protocol codec. + pub(crate) codec: ProtocolCodec, - /// Maximum notification size. - _max_notification_size: usize, + /// Maximum notification size. + _max_notification_size: usize, - /// Handshake bytes. - pub(crate) handshake: Arc>>, + /// Handshake bytes. + pub(crate) handshake: Arc>>, - /// Auto accept inbound substream. - pub(super) auto_accept: bool, + /// Auto accept inbound substream. + pub(super) auto_accept: bool, - /// Protocol aliases. - pub(crate) fallback_names: Vec, + /// Protocol aliases. + pub(crate) fallback_names: Vec, - /// TX channel passed to the protocol used for sending events. - pub(crate) event_tx: Sender, + /// TX channel passed to the protocol used for sending events. + pub(crate) event_tx: Sender, - /// TX channel for sending notifications from the connection handlers. - pub(crate) notif_tx: Sender<(PeerId, BytesMut)>, + /// TX channel for sending notifications from the connection handlers. + pub(crate) notif_tx: Sender<(PeerId, BytesMut)>, - /// RX channel passed to the protocol used for receiving commands. - pub(crate) command_rx: Receiver, + /// RX channel passed to the protocol used for receiving commands. + pub(crate) command_rx: Receiver, - /// Synchronous channel size. - pub(crate) sync_channel_size: usize, + /// Synchronous channel size. + pub(crate) sync_channel_size: usize, - /// Asynchronous channel size. - pub(crate) async_channel_size: usize, + /// Asynchronous channel size. + pub(crate) async_channel_size: usize, - /// Should `NotificationProtocol` dial the peer if there is no connection to them - /// when an outbound substream is requested. - pub(crate) should_dial: bool, + /// Should `NotificationProtocol` dial the peer if there is no connection to them + /// when an outbound substream is requested. + pub(crate) should_dial: bool, } impl Config { - /// Create new [`Config`]. - pub fn new( - protocol_name: ProtocolName, - max_notification_size: usize, - handshake: Vec, - fallback_names: Vec, - auto_accept: bool, - sync_channel_size: usize, - async_channel_size: usize, - should_dial: bool, - ) -> (Self, NotificationHandle) { - let (event_tx, event_rx) = channel(DEFAULT_CHANNEL_SIZE); - let (notif_tx, notif_rx) = channel(DEFAULT_CHANNEL_SIZE); - let (command_tx, command_rx) = channel(DEFAULT_CHANNEL_SIZE); - let handshake = Arc::new(RwLock::new(handshake)); - let handle = - NotificationHandle::new(event_rx, notif_rx, command_tx, Arc::clone(&handshake)); - - ( - Self { - protocol_name, - codec: ProtocolCodec::UnsignedVarint(Some(max_notification_size)), - _max_notification_size: max_notification_size, - auto_accept, - handshake, - fallback_names, - event_tx, - notif_tx, - command_rx, - should_dial, - sync_channel_size, - async_channel_size, - }, - handle, - ) - } - - /// Get protocol name. - pub(crate) fn protocol_name(&self) -> &ProtocolName { - &self.protocol_name - } - - /// Set handshake for the protocol. - /// - /// This function is used to work around an issue in Polkadot SDK and users - /// should not depend on its continued existence. - pub fn set_handshake(&mut self, handshake: Vec) { - let mut inner = self.handshake.write(); - *inner = handshake; - } + /// Create new [`Config`]. + pub fn new( + protocol_name: ProtocolName, + max_notification_size: usize, + handshake: Vec, + fallback_names: Vec, + auto_accept: bool, + sync_channel_size: usize, + async_channel_size: usize, + should_dial: bool, + ) -> (Self, NotificationHandle) { + let (event_tx, event_rx) = channel(DEFAULT_CHANNEL_SIZE); + let (notif_tx, notif_rx) = channel(DEFAULT_CHANNEL_SIZE); + let (command_tx, command_rx) = channel(DEFAULT_CHANNEL_SIZE); + let handshake = Arc::new(RwLock::new(handshake)); + let handle = + NotificationHandle::new(event_rx, notif_rx, command_tx, Arc::clone(&handshake)); + + ( + Self { + protocol_name, + codec: ProtocolCodec::UnsignedVarint(Some(max_notification_size)), + _max_notification_size: max_notification_size, + auto_accept, + handshake, + fallback_names, + event_tx, + notif_tx, + command_rx, + should_dial, + sync_channel_size, + async_channel_size, + }, + handle, + ) + } + + /// Get protocol name. + pub(crate) fn protocol_name(&self) -> &ProtocolName { + &self.protocol_name + } + + /// Set handshake for the protocol. + /// + /// This function is used to work around an issue in Polkadot SDK and users + /// should not depend on its continued existence. + pub fn set_handshake(&mut self, handshake: Vec) { + let mut inner = self.handshake.write(); + *inner = handshake; + } } /// Notification configuration builder. pub struct ConfigBuilder { - /// Protocol name. - protocol_name: ProtocolName, + /// Protocol name. + protocol_name: ProtocolName, - /// Maximum notification size. - max_notification_size: Option, + /// Maximum notification size. + max_notification_size: Option, - /// Handshake bytes. - handshake: Option>, + /// Handshake bytes. + handshake: Option>, - /// Should `NotificationProtocol` dial the peer if an outbound substream is requested but there - /// is no connection to the peer. - should_dial: bool, + /// Should `NotificationProtocol` dial the peer if an outbound substream is requested but there + /// is no connection to the peer. + should_dial: bool, - /// Fallback names. - fallback_names: Vec, + /// Fallback names. + fallback_names: Vec, - /// Auto accept inbound substream. - auto_accept_inbound_for_initiated: bool, + /// Auto accept inbound substream. + auto_accept_inbound_for_initiated: bool, - /// Synchronous channel size. - sync_channel_size: usize, + /// Synchronous channel size. + sync_channel_size: usize, - /// Asynchronous channel size. - async_channel_size: usize, + /// Asynchronous channel size. + async_channel_size: usize, } impl ConfigBuilder { - /// Create new [`ConfigBuilder`]. - pub fn new(protocol_name: ProtocolName) -> Self { - Self { - protocol_name, - max_notification_size: None, - handshake: None, - fallback_names: Vec::new(), - auto_accept_inbound_for_initiated: false, - sync_channel_size: SYNC_CHANNEL_SIZE, - async_channel_size: ASYNC_CHANNEL_SIZE, - should_dial: true, - } - } - - /// Set maximum notification size. - pub fn with_max_size(mut self, max_notification_size: usize) -> Self { - self.max_notification_size = Some(max_notification_size); - self - } - - /// Set handshake. - pub fn with_handshake(mut self, handshake: Vec) -> Self { - self.handshake = Some(handshake); - self - } - - /// Set fallback names. - pub fn with_fallback_names(mut self, fallback_names: Vec) -> Self { - self.fallback_names = fallback_names; - self - } - - /// Auto-accept inbound substreams for those connections which were initiated by the local - /// node. - /// - /// Connection in this context means a bidirectional substream pair between two peers over a - /// given protocol. - /// - /// By default, when a node starts a connection with a remote node and opens an outbound - /// substream to them, that substream is validated and if it's accepted, remote node sends - /// their handshake over that substream and opens another substream to local node. The - /// substream that was opened by the local node is used for sending data and the one opened - /// by the remote node is used for receiving data. - /// - /// By default, even if the local node was the one that opened the first substream, this inbound - /// substream coming from remote node must be validated as the handshake of the remote node - /// may reveal that it's not someone that the local node is willing to accept. - /// - /// To disable this behavior, auto accepting for the inbound substream can be enabled. If local - /// node is the one that opened the connection and it was accepted by the remote node, local - /// node is only notified via - /// [`NotificationStreamOpened`](super::types::NotificationEvent::NotificationStreamOpened). - pub fn with_auto_accept_inbound(mut self, auto_accept: bool) -> Self { - self.auto_accept_inbound_for_initiated = auto_accept; - self - } - - /// Configure size of the channel for sending synchronous notifications. - /// - /// Default value is `16`. - pub fn with_sync_channel_size(mut self, size: usize) -> Self { - self.sync_channel_size = size; - self - } - - /// Configure size of the channel for sending asynchronous notifications. - /// - /// Default value is `8`. - pub fn with_async_channel_size(mut self, size: usize) -> Self { - self.async_channel_size = size; - self - } - - /// Should `NotificationProtocol` attempt to dial the peer if an outbound substream is opened - /// but no connection to the peer exist. - /// - /// Dialing is enabled by default. - pub fn with_dialing_enabled(mut self, should_dial: bool) -> Self { - self.should_dial = should_dial; - self - } - - /// Build notification configuration. - pub fn build(mut self) -> (Config, NotificationHandle) { - Config::new( - self.protocol_name, - self.max_notification_size.take().expect("notification size to be specified"), - self.handshake.take().expect("handshake to be specified"), - self.fallback_names, - self.auto_accept_inbound_for_initiated, - self.sync_channel_size, - self.async_channel_size, - self.should_dial, - ) - } + /// Create new [`ConfigBuilder`]. + pub fn new(protocol_name: ProtocolName) -> Self { + Self { + protocol_name, + max_notification_size: None, + handshake: None, + fallback_names: Vec::new(), + auto_accept_inbound_for_initiated: false, + sync_channel_size: SYNC_CHANNEL_SIZE, + async_channel_size: ASYNC_CHANNEL_SIZE, + should_dial: true, + } + } + + /// Set maximum notification size. + pub fn with_max_size(mut self, max_notification_size: usize) -> Self { + self.max_notification_size = Some(max_notification_size); + self + } + + /// Set handshake. + pub fn with_handshake(mut self, handshake: Vec) -> Self { + self.handshake = Some(handshake); + self + } + + /// Set fallback names. + pub fn with_fallback_names(mut self, fallback_names: Vec) -> Self { + self.fallback_names = fallback_names; + self + } + + /// Auto-accept inbound substreams for those connections which were initiated by the local + /// node. + /// + /// Connection in this context means a bidirectional substream pair between two peers over a + /// given protocol. + /// + /// By default, when a node starts a connection with a remote node and opens an outbound + /// substream to them, that substream is validated and if it's accepted, remote node sends + /// their handshake over that substream and opens another substream to local node. The + /// substream that was opened by the local node is used for sending data and the one opened + /// by the remote node is used for receiving data. + /// + /// By default, even if the local node was the one that opened the first substream, this inbound + /// substream coming from remote node must be validated as the handshake of the remote node + /// may reveal that it's not someone that the local node is willing to accept. + /// + /// To disable this behavior, auto accepting for the inbound substream can be enabled. If local + /// node is the one that opened the connection and it was accepted by the remote node, local + /// node is only notified via + /// [`NotificationStreamOpened`](super::types::NotificationEvent::NotificationStreamOpened). + pub fn with_auto_accept_inbound(mut self, auto_accept: bool) -> Self { + self.auto_accept_inbound_for_initiated = auto_accept; + self + } + + /// Configure size of the channel for sending synchronous notifications. + /// + /// Default value is `16`. + pub fn with_sync_channel_size(mut self, size: usize) -> Self { + self.sync_channel_size = size; + self + } + + /// Configure size of the channel for sending asynchronous notifications. + /// + /// Default value is `8`. + pub fn with_async_channel_size(mut self, size: usize) -> Self { + self.async_channel_size = size; + self + } + + /// Should `NotificationProtocol` attempt to dial the peer if an outbound substream is opened + /// but no connection to the peer exist. + /// + /// Dialing is enabled by default. + pub fn with_dialing_enabled(mut self, should_dial: bool) -> Self { + self.should_dial = should_dial; + self + } + + /// Build notification configuration. + pub fn build(mut self) -> (Config, NotificationHandle) { + Config::new( + self.protocol_name, + self.max_notification_size.take().expect("notification size to be specified"), + self.handshake.take().expect("handshake to be specified"), + self.fallback_names, + self.auto_accept_inbound_for_initiated, + self.sync_channel_size, + self.async_channel_size, + self.should_dial, + ) + } } diff --git a/src/protocol/notification/connection.rs b/src/protocol/notification/connection.rs index 4ae09e95..819f305d 100644 --- a/src/protocol/notification/connection.rs +++ b/src/protocol/notification/connection.rs @@ -19,20 +19,20 @@ // DEALINGS IN THE SOFTWARE. use crate::{ - protocol::notification::handle::NotificationEventHandle, substream::Substream, PeerId, + protocol::notification::handle::NotificationEventHandle, substream::Substream, PeerId, }; use bytes::BytesMut; use futures::{FutureExt, SinkExt, Stream, StreamExt}; use tokio::sync::{ - mpsc::{Receiver, Sender}, - oneshot, + mpsc::{Receiver, Sender}, + oneshot, }; use tokio_util::sync::PollSender; use std::{ - pin::Pin, - task::{Context, Poll}, + pin::Pin, + task::{Context, Poll}, }; /// Logging target for the file. @@ -40,230 +40,233 @@ const LOG_TARGET: &str = "litep2p::notification::connection"; /// Bidirectional substream pair representing a connection to a remote peer. pub(crate) struct Connection { - /// Remote peer ID. - peer: PeerId, + /// Remote peer ID. + peer: PeerId, - /// Inbound substreams for receiving notifications. - inbound: Substream, + /// Inbound substreams for receiving notifications. + inbound: Substream, - /// Outbound substream for sending notifications. - outbound: Substream, + /// Outbound substream for sending notifications. + outbound: Substream, - /// Handle for sending notification events to user. - event_handle: NotificationEventHandle, + /// Handle for sending notification events to user. + event_handle: NotificationEventHandle, - /// TX channel used to notify [`NotificationProtocol`](super::NotificationProtocol) - /// that the connection has been closed. - conn_closed_tx: Sender, + /// TX channel used to notify [`NotificationProtocol`](super::NotificationProtocol) + /// that the connection has been closed. + conn_closed_tx: Sender, - /// TX channel for sending notifications. - notif_tx: PollSender<(PeerId, BytesMut)>, + /// TX channel for sending notifications. + notif_tx: PollSender<(PeerId, BytesMut)>, - /// Receiver for asynchronously sent notifications. - async_rx: Receiver>, + /// Receiver for asynchronously sent notifications. + async_rx: Receiver>, - /// Receiver for synchronously sent notifications. - sync_rx: Receiver>, + /// Receiver for synchronously sent notifications. + sync_rx: Receiver>, - /// Oneshot receiver used by [`NotificationProtocol`](super::NotificationProtocol) - /// to signal that local node wishes the close the connection. - rx: oneshot::Receiver<()>, + /// Oneshot receiver used by [`NotificationProtocol`](super::NotificationProtocol) + /// to signal that local node wishes the close the connection. + rx: oneshot::Receiver<()>, - /// Next notification to send, if any. - next_notification: Option>, + /// Next notification to send, if any. + next_notification: Option>, } /// Notify [`NotificationProtocol`](super::NotificationProtocol) that the connection was closed. #[derive(Debug)] pub enum NotifyProtocol { - /// Notify the protocol handler. - Yes, + /// Notify the protocol handler. + Yes, - /// Do not notify protocol handler. - No, + /// Do not notify protocol handler. + No, } impl Connection { - /// Create new [`Connection`]. - pub(crate) fn new( - peer: PeerId, - inbound: Substream, - outbound: Substream, - event_handle: NotificationEventHandle, - conn_closed_tx: Sender, - notif_tx: Sender<(PeerId, BytesMut)>, - async_rx: Receiver>, - sync_rx: Receiver>, - ) -> (Self, oneshot::Sender<()>) { - let (tx, rx) = oneshot::channel(); - - ( - Self { - rx, - peer, - sync_rx, - async_rx, - inbound, - outbound, - event_handle, - conn_closed_tx, - next_notification: None, - notif_tx: PollSender::new(notif_tx), - }, - tx, - ) - } - - /// Connection closed, clean up state. - /// - /// If [`NotificationProtocol`](super::NotificationProtocol) was the one that initiated - /// shut down, it's not notified of connection getting closed. - async fn close_connection(self, notify_protocol: NotifyProtocol) { - tracing::trace!( - target: LOG_TARGET, - peer = ?self.peer, - ?notify_protocol, - "close notification protocol", - ); - - let _ = self.inbound.close().await; - let _ = self.outbound.close().await; - - if std::matches!(notify_protocol, NotifyProtocol::Yes) { - let _ = self.conn_closed_tx.send(self.peer).await; - } - - self.event_handle.report_notification_stream_closed(self.peer).await; - } - - pub async fn start(mut self) { - tracing::debug!( - target: LOG_TARGET, - peer = ?self.peer, - "start connection event loop", - ); - - loop { - match self.next().await { - None | Some(ConnectionEvent::CloseConnection { notify: NotifyProtocol::Yes }) => - return self.close_connection(NotifyProtocol::Yes).await, - Some(ConnectionEvent::CloseConnection { notify: NotifyProtocol::No }) => - return self.close_connection(NotifyProtocol::No).await, - Some(ConnectionEvent::NotificationReceived { notification }) => { - if let Err(_) = self.notif_tx.send_item((self.peer, notification)) { - return self.close_connection(NotifyProtocol::Yes).await; - } - }, - } - } - } + /// Create new [`Connection`]. + pub(crate) fn new( + peer: PeerId, + inbound: Substream, + outbound: Substream, + event_handle: NotificationEventHandle, + conn_closed_tx: Sender, + notif_tx: Sender<(PeerId, BytesMut)>, + async_rx: Receiver>, + sync_rx: Receiver>, + ) -> (Self, oneshot::Sender<()>) { + let (tx, rx) = oneshot::channel(); + + ( + Self { + rx, + peer, + sync_rx, + async_rx, + inbound, + outbound, + event_handle, + conn_closed_tx, + next_notification: None, + notif_tx: PollSender::new(notif_tx), + }, + tx, + ) + } + + /// Connection closed, clean up state. + /// + /// If [`NotificationProtocol`](super::NotificationProtocol) was the one that initiated + /// shut down, it's not notified of connection getting closed. + async fn close_connection(self, notify_protocol: NotifyProtocol) { + tracing::trace!( + target: LOG_TARGET, + peer = ?self.peer, + ?notify_protocol, + "close notification protocol", + ); + + let _ = self.inbound.close().await; + let _ = self.outbound.close().await; + + if std::matches!(notify_protocol, NotifyProtocol::Yes) { + let _ = self.conn_closed_tx.send(self.peer).await; + } + + self.event_handle.report_notification_stream_closed(self.peer).await; + } + + pub async fn start(mut self) { + tracing::debug!( + target: LOG_TARGET, + peer = ?self.peer, + "start connection event loop", + ); + + loop { + match self.next().await { + None + | Some(ConnectionEvent::CloseConnection { + notify: NotifyProtocol::Yes, + }) => return self.close_connection(NotifyProtocol::Yes).await, + Some(ConnectionEvent::CloseConnection { + notify: NotifyProtocol::No, + }) => return self.close_connection(NotifyProtocol::No).await, + Some(ConnectionEvent::NotificationReceived { notification }) => { + if let Err(_) = self.notif_tx.send_item((self.peer, notification)) { + return self.close_connection(NotifyProtocol::Yes).await; + } + } + } + } + } } /// Connection events. pub enum ConnectionEvent { - /// Close connection. - /// - /// If `NotificationProtocol` requested [`Connection`] to be closed, it doesn't need to be - /// notified. If, on the other hand, connection closes because it encountered an error or one - /// of the substreams was closed, `NotificationProtocol` must be informed so it can inform the - /// user. - CloseConnection { - /// Whether to notify `NotificationProtocol` or not. - notify: NotifyProtocol, - }, - - /// Notification read from the inbound substream. - /// - /// NOTE: [`Connection`] uses `PollSender::send_item()` to send the notification to user. - /// `PollSender::poll_reserve()` must be called before calling `PollSender::send_item()` or it - /// will panic. `PollSender::poll_reserve()` is called in the `Stream` implementation below - /// before polling the inbound substream to ensure the channel has capacity to receive a - /// notification. - NotificationReceived { - /// Notification. - notification: BytesMut, - }, + /// Close connection. + /// + /// If `NotificationProtocol` requested [`Connection`] to be closed, it doesn't need to be + /// notified. If, on the other hand, connection closes because it encountered an error or one + /// of the substreams was closed, `NotificationProtocol` must be informed so it can inform the + /// user. + CloseConnection { + /// Whether to notify `NotificationProtocol` or not. + notify: NotifyProtocol, + }, + + /// Notification read from the inbound substream. + /// + /// NOTE: [`Connection`] uses `PollSender::send_item()` to send the notification to user. + /// `PollSender::poll_reserve()` must be called before calling `PollSender::send_item()` or it + /// will panic. `PollSender::poll_reserve()` is called in the `Stream` implementation below + /// before polling the inbound substream to ensure the channel has capacity to receive a + /// notification. + NotificationReceived { + /// Notification. + notification: BytesMut, + }, } impl Stream for Connection { - type Item = ConnectionEvent; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = Pin::into_inner(self); - - if let Poll::Ready(_) = this.rx.poll_unpin(cx) { - return Poll::Ready(Some(ConnectionEvent::CloseConnection { - notify: NotifyProtocol::No, - })); - } - - loop { - let notification = match this.next_notification.take() { - Some(notification) => Some(notification), - None => { - let future = async { - tokio::select! { - notification = this.async_rx.recv() => notification, - notification = this.sync_rx.recv() => notification, - } - }; - futures::pin_mut!(future); - - match future.poll_unpin(cx) { - Poll::Pending => None, - Poll::Ready(None) => - return Poll::Ready(Some(ConnectionEvent::CloseConnection { - notify: NotifyProtocol::Yes, - })), - Poll::Ready(Some(notification)) => Some(notification), - } - }, - }; - - let Some(notification) = notification else { - break; - }; - - match this.outbound.poll_ready_unpin(cx) { - Poll::Ready(Ok(())) => {}, - Poll::Pending => { - this.next_notification = Some(notification); - break; - }, - Poll::Ready(Err(_)) => - return Poll::Ready(Some(ConnectionEvent::CloseConnection { - notify: NotifyProtocol::Yes, - })), - } - - if let Err(_) = this.outbound.start_send_unpin(notification.into()) { - return Poll::Ready(Some(ConnectionEvent::CloseConnection { - notify: NotifyProtocol::Yes, - })); - } - } - - match this.outbound.poll_flush_unpin(cx) { - Poll::Ready(Err(_)) => - return Poll::Ready(Some(ConnectionEvent::CloseConnection { - notify: NotifyProtocol::Yes, - })), - Poll::Ready(Ok(())) | Poll::Pending => {}, - } - - if let Err(_) = futures::ready!(this.notif_tx.poll_reserve(cx)) { - return Poll::Ready(Some(ConnectionEvent::CloseConnection { - notify: NotifyProtocol::Yes, - })); - } - - match futures::ready!(this.inbound.poll_next_unpin(cx)) { - None | Some(Err(_)) => - return Poll::Ready(Some(ConnectionEvent::CloseConnection { - notify: NotifyProtocol::Yes, - })), - Some(Ok(notification)) => - return Poll::Ready(Some(ConnectionEvent::NotificationReceived { notification })), - } - } + type Item = ConnectionEvent; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = Pin::into_inner(self); + + if let Poll::Ready(_) = this.rx.poll_unpin(cx) { + return Poll::Ready(Some(ConnectionEvent::CloseConnection { + notify: NotifyProtocol::No, + })); + } + + loop { + let notification = match this.next_notification.take() { + Some(notification) => Some(notification), + None => { + let future = async { + tokio::select! { + notification = this.async_rx.recv() => notification, + notification = this.sync_rx.recv() => notification, + } + }; + futures::pin_mut!(future); + + match future.poll_unpin(cx) { + Poll::Pending => None, + Poll::Ready(None) => + return Poll::Ready(Some(ConnectionEvent::CloseConnection { + notify: NotifyProtocol::Yes, + })), + Poll::Ready(Some(notification)) => Some(notification), + } + } + }; + + let Some(notification) = notification else { + break; + }; + + match this.outbound.poll_ready_unpin(cx) { + Poll::Ready(Ok(())) => {} + Poll::Pending => { + this.next_notification = Some(notification); + break; + } + Poll::Ready(Err(_)) => + return Poll::Ready(Some(ConnectionEvent::CloseConnection { + notify: NotifyProtocol::Yes, + })), + } + + if let Err(_) = this.outbound.start_send_unpin(notification.into()) { + return Poll::Ready(Some(ConnectionEvent::CloseConnection { + notify: NotifyProtocol::Yes, + })); + } + } + + match this.outbound.poll_flush_unpin(cx) { + Poll::Ready(Err(_)) => + return Poll::Ready(Some(ConnectionEvent::CloseConnection { + notify: NotifyProtocol::Yes, + })), + Poll::Ready(Ok(())) | Poll::Pending => {} + } + + if let Err(_) = futures::ready!(this.notif_tx.poll_reserve(cx)) { + return Poll::Ready(Some(ConnectionEvent::CloseConnection { + notify: NotifyProtocol::Yes, + })); + } + + match futures::ready!(this.inbound.poll_next_unpin(cx)) { + None | Some(Err(_)) => + return Poll::Ready(Some(ConnectionEvent::CloseConnection { + notify: NotifyProtocol::Yes, + })), + Some(Ok(notification)) => + return Poll::Ready(Some(ConnectionEvent::NotificationReceived { notification })), + } + } } diff --git a/src/protocol/notification/handle.rs b/src/protocol/notification/handle.rs index c678d109..31fd03fc 100644 --- a/src/protocol/notification/handle.rs +++ b/src/protocol/notification/handle.rs @@ -19,28 +19,28 @@ // DEALINGS IN THE SOFTWARE. use crate::{ - error::Error, - protocol::notification::types::{ - Direction, InnerNotificationEvent, NotificationCommand, NotificationError, - NotificationEvent, ValidationResult, - }, - types::protocol::ProtocolName, - PeerId, + error::Error, + protocol::notification::types::{ + Direction, InnerNotificationEvent, NotificationCommand, NotificationError, + NotificationEvent, ValidationResult, + }, + types::protocol::ProtocolName, + PeerId, }; use bytes::BytesMut; use futures::Stream; use parking_lot::RwLock; use tokio::sync::{ - mpsc::{error::TrySendError, Receiver, Sender}, - oneshot, + mpsc::{error::TrySendError, Receiver, Sender}, + oneshot, }; use std::{ - collections::{HashMap, HashSet}, - pin::Pin, - sync::Arc, - task::{Context, Poll}, + collections::{HashMap, HashSet}, + pin::Pin, + sync::Arc, + task::{Context, Poll}, }; /// Logging target for the file. @@ -48,75 +48,75 @@ const LOG_TARGET: &str = "litep2p::notification::handle"; #[derive(Debug, Clone)] pub(crate) struct NotificationEventHandle { - tx: Sender, + tx: Sender, } impl NotificationEventHandle { - /// Create new [`NotificationEventHandle`]. - pub(crate) fn new(tx: Sender) -> Self { - Self { tx } - } - - /// Validate inbound substream. - pub(crate) async fn report_inbound_substream( - &self, - protocol: ProtocolName, - fallback: Option, - peer: PeerId, - handshake: Vec, - tx: oneshot::Sender, - ) { - let _ = self - .tx - .send(InnerNotificationEvent::ValidateSubstream { - protocol, - fallback, - peer, - handshake, - tx, - }) - .await; - } - - /// Notification stream opened. - pub(crate) async fn report_notification_stream_opened( - &self, - protocol: ProtocolName, - fallback: Option, - direction: Direction, - peer: PeerId, - handshake: Vec, - sink: NotificationSink, - ) { - let _ = self - .tx - .send(InnerNotificationEvent::NotificationStreamOpened { - protocol, - fallback, - direction, - peer, - handshake, - sink, - }) - .await; - } - - /// Notification stream closed. - pub(crate) async fn report_notification_stream_closed(&self, peer: PeerId) { - let _ = self.tx.send(InnerNotificationEvent::NotificationStreamClosed { peer }).await; - } - - /// Failed to open notification stream. - pub(crate) async fn report_notification_stream_open_failure( - &self, - peer: PeerId, - error: NotificationError, - ) { - let _ = self - .tx - .send(InnerNotificationEvent::NotificationStreamOpenFailure { peer, error }) - .await; - } + /// Create new [`NotificationEventHandle`]. + pub(crate) fn new(tx: Sender) -> Self { + Self { tx } + } + + /// Validate inbound substream. + pub(crate) async fn report_inbound_substream( + &self, + protocol: ProtocolName, + fallback: Option, + peer: PeerId, + handshake: Vec, + tx: oneshot::Sender, + ) { + let _ = self + .tx + .send(InnerNotificationEvent::ValidateSubstream { + protocol, + fallback, + peer, + handshake, + tx, + }) + .await; + } + + /// Notification stream opened. + pub(crate) async fn report_notification_stream_opened( + &self, + protocol: ProtocolName, + fallback: Option, + direction: Direction, + peer: PeerId, + handshake: Vec, + sink: NotificationSink, + ) { + let _ = self + .tx + .send(InnerNotificationEvent::NotificationStreamOpened { + protocol, + fallback, + direction, + peer, + handshake, + sink, + }) + .await; + } + + /// Notification stream closed. + pub(crate) async fn report_notification_stream_closed(&self, peer: PeerId) { + let _ = self.tx.send(InnerNotificationEvent::NotificationStreamClosed { peer }).await; + } + + /// Failed to open notification stream. + pub(crate) async fn report_notification_stream_open_failure( + &self, + peer: PeerId, + error: NotificationError, + ) { + let _ = self + .tx + .send(InnerNotificationEvent::NotificationStreamOpenFailure { peer, error }) + .await; + } } /// Notification sink. @@ -124,382 +124,393 @@ impl NotificationEventHandle { /// Allows the user to send notifications both synchronously and asynchronously. #[derive(Debug, Clone)] pub struct NotificationSink { - /// Peer ID. - peer: PeerId, + /// Peer ID. + peer: PeerId, - /// TX channel for sending notifications synchronously. - sync_tx: Sender>, + /// TX channel for sending notifications synchronously. + sync_tx: Sender>, - /// TX channel for sending notifications asynchronously. - async_tx: Sender>, + /// TX channel for sending notifications asynchronously. + async_tx: Sender>, } impl NotificationSink { - /// Create new [`NotificationSink`]. - pub(crate) fn new(peer: PeerId, sync_tx: Sender>, async_tx: Sender>) -> Self { - Self { peer, async_tx, sync_tx } - } - - /// Send notification to `peer` synchronously. - /// - /// If the channel is clogged, [`NotificationError::ChannelClogged`] is returned. - pub fn send_sync_notification(&self, notification: Vec) -> Result<(), NotificationError> { - self.sync_tx.try_send(notification).map_err(|error| match error { - TrySendError::Closed(_) => NotificationError::NoConnection, - TrySendError::Full(_) => NotificationError::ChannelClogged, - }) - } - - /// Send notification to `peer` asynchronously, waiting for the channel to have capacity - /// if it's clogged. - /// - /// Returns [`Error::PeerDoesntExist(PeerId)`](crate::error::Error::PeerDoesntExist) - /// if the connection has been closed. - pub async fn send_async_notification(&self, notification: Vec) -> crate::Result<()> { - self.async_tx - .send(notification) - .await - .map_err(|_| Error::PeerDoesntExist(self.peer)) - } + /// Create new [`NotificationSink`]. + pub(crate) fn new(peer: PeerId, sync_tx: Sender>, async_tx: Sender>) -> Self { + Self { + peer, + async_tx, + sync_tx, + } + } + + /// Send notification to `peer` synchronously. + /// + /// If the channel is clogged, [`NotificationError::ChannelClogged`] is returned. + pub fn send_sync_notification(&self, notification: Vec) -> Result<(), NotificationError> { + self.sync_tx.try_send(notification).map_err(|error| match error { + TrySendError::Closed(_) => NotificationError::NoConnection, + TrySendError::Full(_) => NotificationError::ChannelClogged, + }) + } + + /// Send notification to `peer` asynchronously, waiting for the channel to have capacity + /// if it's clogged. + /// + /// Returns [`Error::PeerDoesntExist(PeerId)`](crate::error::Error::PeerDoesntExist) + /// if the connection has been closed. + pub async fn send_async_notification(&self, notification: Vec) -> crate::Result<()> { + self.async_tx + .send(notification) + .await + .map_err(|_| Error::PeerDoesntExist(self.peer)) + } } /// Handle allowing the user protocol to interact with the notification protocol. #[derive(Debug)] pub struct NotificationHandle { - /// RX channel for receiving events from the notification protocol. - event_rx: Receiver, + /// RX channel for receiving events from the notification protocol. + event_rx: Receiver, - /// RX channel for receiving notifications from connection handlers. - notif_rx: Receiver<(PeerId, BytesMut)>, + /// RX channel for receiving notifications from connection handlers. + notif_rx: Receiver<(PeerId, BytesMut)>, - /// TX channel for sending commands to the notification protocol. - command_tx: Sender, + /// TX channel for sending commands to the notification protocol. + command_tx: Sender, - /// Peers. - peers: HashMap, + /// Peers. + peers: HashMap, - /// Clogged peers. - clogged: HashSet, + /// Clogged peers. + clogged: HashSet, - /// Pending validations. - pending_validations: HashMap>, + /// Pending validations. + pending_validations: HashMap>, - /// Handshake. - handshake: Arc>>, + /// Handshake. + handshake: Arc>>, } impl NotificationHandle { - /// Create new [`NotificationHandle`]. - pub(crate) fn new( - event_rx: Receiver, - notif_rx: Receiver<(PeerId, BytesMut)>, - command_tx: Sender, - handshake: Arc>>, - ) -> Self { - Self { - event_rx, - notif_rx, - command_tx, - handshake, - peers: HashMap::new(), - clogged: HashSet::new(), - pending_validations: HashMap::new(), - } - } - - /// Open substream to `peer`. - /// - /// Returns [`Error::PeerAlreadyExists(PeerId)`](crate::error::Error::PeerAlreadyExists) if - /// substream is already open to `peer`. - /// - /// If connection to peer is closed, `NotificationProtocol` tries to dial the peer and if the - /// dial succeeds, tries to open a substream. This behavior can be disabled with - /// [`ConfigBuilder::with_dialing_enabled(false)`](super::config::ConfigBuilder::with_dialing_enabled()). - pub async fn open_substream(&self, peer: PeerId) -> crate::Result<()> { - tracing::trace!(target: LOG_TARGET, ?peer, "open substream"); - - if self.peers.contains_key(&peer) { - return Err(Error::PeerAlreadyExists(peer)); - } - - self.command_tx - .send(NotificationCommand::OpenSubstream { peers: HashSet::from_iter([peer]) }) - .await - .map_or(Ok(()), |_| Ok(())) - } - - /// Open substreams to multiple peers. - /// - /// Similar to [`NotificationHandle::open_substream()`] but multiple substreams are initiated - /// using a single call to `NotificationProtocol`. - /// - /// Peers who are already connected are ignored and returned as `Err(HashSet>)`. - pub async fn open_substream_batch( - &self, - peers: impl Iterator, - ) -> Result<(), HashSet> { - let (to_add, to_ignore): (Vec<_>, Vec<_>) = peers - .map(|peer| match self.peers.contains_key(&peer) { - true => (None, Some(peer)), - false => (Some(peer), None), - }) - .unzip(); - - let to_add = to_add.into_iter().flatten().collect::>(); - let to_ignore = to_ignore.into_iter().flatten().collect::>(); - - tracing::trace!( - target: LOG_TARGET, - peers_to_add = ?to_add.len(), - peers_to_ignore = ?to_ignore.len(), - "open substream", - ); - - let _ = self.command_tx.send(NotificationCommand::OpenSubstream { peers: to_add }).await; - - match to_ignore.is_empty() { - true => Ok(()), - false => Err(to_ignore), - } - } - - /// Try to open substreams to multiple peers. - /// - /// Similar to [`NotificationHandle::open_substream()`] but multiple substreams are initiated - /// using a single call to `NotificationProtocol`. - /// - /// If the channel is clogged, peers for whom a connection is not yet open are returned as - /// `Err(HashSet)`. - pub fn try_open_substream_batch( - &self, - peers: impl Iterator, - ) -> Result<(), HashSet> { - let (to_add, to_ignore): (Vec<_>, Vec<_>) = peers - .map(|peer| match self.peers.contains_key(&peer) { - true => (None, Some(peer)), - false => (Some(peer), None), - }) - .unzip(); - - let to_add = to_add.into_iter().flatten().collect::>(); - let to_ignore = to_ignore.into_iter().flatten().collect::>(); - - tracing::trace!( - target: LOG_TARGET, - peers_to_add = ?to_add.len(), - peers_to_ignore = ?to_ignore.len(), - "open substream", - ); - - self.command_tx - .try_send(NotificationCommand::OpenSubstream { peers: to_add.clone() }) - .map_err(|_| to_add) - } - - /// Close substream to `peer`. - pub async fn close_substream(&self, peer: PeerId) { - tracing::trace!(target: LOG_TARGET, ?peer, "close substream"); - - if !self.peers.contains_key(&peer) { - return; - } - - let _ = self - .command_tx - .send(NotificationCommand::CloseSubstream { peers: HashSet::from_iter([peer]) }) - .await; - } - - /// Close substream to multiple peers. - /// - /// Similar to [`NotificationHandle::close_substream()`] but multiple substreams are closed - /// using a single call to `NotificationProtocol`. - pub async fn close_substream_batch(&self, peers: impl Iterator) { - let peers = peers - .filter_map(|peer| self.peers.contains_key(&peer).then_some(peer)) - .collect::>(); - - if peers.is_empty() { - return; - } - - tracing::trace!( - target: LOG_TARGET, - ?peers, - "close substreams", - ); - - let _ = self.command_tx.send(NotificationCommand::CloseSubstream { peers }).await; - } - - /// Try close substream to multiple peers. - /// - /// Similar to [`NotificationHandle::close_substream()`] but multiple substreams are closed - /// using a single call to `NotificationProtocol`. - /// - /// If the channel is clogged, `peers` is returned as `Err(HashSet)`. - /// - /// If `peers` is empty after filtering all already-connected peers, - /// `Err(HashMap::new())` is returned. - pub fn try_close_substream_batch( - &self, - peers: impl Iterator, - ) -> Result<(), HashSet> { - let peers = peers - .filter_map(|peer| self.peers.contains_key(&peer).then_some(peer)) - .collect::>(); - - if peers.is_empty() { - return Err(HashSet::new()); - } - - tracing::trace!( - target: LOG_TARGET, - ?peers, - "close substreams", - ); - - self.command_tx - .try_send(NotificationCommand::CloseSubstream { peers: peers.clone() }) - .map_err(|_| peers) - } - - /// Set new handshake. - pub fn set_handshake(&mut self, handshake: Vec) { - tracing::trace!(target: LOG_TARGET, ?handshake, "set handshake"); - - *self.handshake.write() = handshake; - } - - /// Send validation result to the notification protocol for an inbound substream received from - /// `peer`. - pub fn send_validation_result(&mut self, peer: PeerId, result: ValidationResult) { - tracing::trace!(target: LOG_TARGET, ?peer, ?result, "send validation result"); - - self.pending_validations.remove(&peer).map(|tx| tx.send(result)); - } - - /// Send notification to `peer` synchronously. - /// - /// If the channel is clogged, [`NotificationError::ChannelClogged`] is returned. - pub fn send_sync_notification( - &mut self, - peer: PeerId, - notification: Vec, - ) -> Result<(), NotificationError> { - match self.peers.get_mut(&peer) { - Some(sink) => match sink.send_sync_notification(notification) { - Ok(()) => Ok(()), - Err(error) => match error { - NotificationError::NoConnection => return Err(NotificationError::NoConnection), - NotificationError::ChannelClogged => { - let _ = self.clogged.insert(peer).then(|| { - self.command_tx.try_send(NotificationCommand::ForceClose { peer }) - }); - - Err(NotificationError::ChannelClogged) - }, - // sink doesn't emit any other `NotificationError`s - _ => unreachable!(), - }, - }, - None => Ok(()), - } - } - - /// Send notification to `peer` asynchronously, waiting for the channel to have capacity - /// if it's clogged. - /// - /// Returns [`Error::PeerDoesntExist(PeerId)`](crate::error::Error::PeerDoesntExist) if the - /// connection has been closed. - pub async fn send_async_notification( - &mut self, - peer: PeerId, - notification: Vec, - ) -> crate::Result<()> { - match self.peers.get_mut(&peer) { - Some(sink) => sink.send_async_notification(notification).await, - None => Err(Error::PeerDoesntExist(peer)), - } - } - - /// Get a copy of the underlying notification sink for the peer. - /// - /// `None` is returned if `peer` doesn't exist. - pub fn notification_sink(&self, peer: PeerId) -> Option { - self.peers.get(&peer).and_then(|sink| Some(sink.clone())) - } + /// Create new [`NotificationHandle`]. + pub(crate) fn new( + event_rx: Receiver, + notif_rx: Receiver<(PeerId, BytesMut)>, + command_tx: Sender, + handshake: Arc>>, + ) -> Self { + Self { + event_rx, + notif_rx, + command_tx, + handshake, + peers: HashMap::new(), + clogged: HashSet::new(), + pending_validations: HashMap::new(), + } + } + + /// Open substream to `peer`. + /// + /// Returns [`Error::PeerAlreadyExists(PeerId)`](crate::error::Error::PeerAlreadyExists) if + /// substream is already open to `peer`. + /// + /// If connection to peer is closed, `NotificationProtocol` tries to dial the peer and if the + /// dial succeeds, tries to open a substream. This behavior can be disabled with + /// [`ConfigBuilder::with_dialing_enabled(false)`](super::config::ConfigBuilder::with_dialing_enabled()). + pub async fn open_substream(&self, peer: PeerId) -> crate::Result<()> { + tracing::trace!(target: LOG_TARGET, ?peer, "open substream"); + + if self.peers.contains_key(&peer) { + return Err(Error::PeerAlreadyExists(peer)); + } + + self.command_tx + .send(NotificationCommand::OpenSubstream { + peers: HashSet::from_iter([peer]), + }) + .await + .map_or(Ok(()), |_| Ok(())) + } + + /// Open substreams to multiple peers. + /// + /// Similar to [`NotificationHandle::open_substream()`] but multiple substreams are initiated + /// using a single call to `NotificationProtocol`. + /// + /// Peers who are already connected are ignored and returned as `Err(HashSet>)`. + pub async fn open_substream_batch( + &self, + peers: impl Iterator, + ) -> Result<(), HashSet> { + let (to_add, to_ignore): (Vec<_>, Vec<_>) = peers + .map(|peer| match self.peers.contains_key(&peer) { + true => (None, Some(peer)), + false => (Some(peer), None), + }) + .unzip(); + + let to_add = to_add.into_iter().flatten().collect::>(); + let to_ignore = to_ignore.into_iter().flatten().collect::>(); + + tracing::trace!( + target: LOG_TARGET, + peers_to_add = ?to_add.len(), + peers_to_ignore = ?to_ignore.len(), + "open substream", + ); + + let _ = self.command_tx.send(NotificationCommand::OpenSubstream { peers: to_add }).await; + + match to_ignore.is_empty() { + true => Ok(()), + false => Err(to_ignore), + } + } + + /// Try to open substreams to multiple peers. + /// + /// Similar to [`NotificationHandle::open_substream()`] but multiple substreams are initiated + /// using a single call to `NotificationProtocol`. + /// + /// If the channel is clogged, peers for whom a connection is not yet open are returned as + /// `Err(HashSet)`. + pub fn try_open_substream_batch( + &self, + peers: impl Iterator, + ) -> Result<(), HashSet> { + let (to_add, to_ignore): (Vec<_>, Vec<_>) = peers + .map(|peer| match self.peers.contains_key(&peer) { + true => (None, Some(peer)), + false => (Some(peer), None), + }) + .unzip(); + + let to_add = to_add.into_iter().flatten().collect::>(); + let to_ignore = to_ignore.into_iter().flatten().collect::>(); + + tracing::trace!( + target: LOG_TARGET, + peers_to_add = ?to_add.len(), + peers_to_ignore = ?to_ignore.len(), + "open substream", + ); + + self.command_tx + .try_send(NotificationCommand::OpenSubstream { + peers: to_add.clone(), + }) + .map_err(|_| to_add) + } + + /// Close substream to `peer`. + pub async fn close_substream(&self, peer: PeerId) { + tracing::trace!(target: LOG_TARGET, ?peer, "close substream"); + + if !self.peers.contains_key(&peer) { + return; + } + + let _ = self + .command_tx + .send(NotificationCommand::CloseSubstream { + peers: HashSet::from_iter([peer]), + }) + .await; + } + + /// Close substream to multiple peers. + /// + /// Similar to [`NotificationHandle::close_substream()`] but multiple substreams are closed + /// using a single call to `NotificationProtocol`. + pub async fn close_substream_batch(&self, peers: impl Iterator) { + let peers = peers + .filter_map(|peer| self.peers.contains_key(&peer).then_some(peer)) + .collect::>(); + + if peers.is_empty() { + return; + } + + tracing::trace!( + target: LOG_TARGET, + ?peers, + "close substreams", + ); + + let _ = self.command_tx.send(NotificationCommand::CloseSubstream { peers }).await; + } + + /// Try close substream to multiple peers. + /// + /// Similar to [`NotificationHandle::close_substream()`] but multiple substreams are closed + /// using a single call to `NotificationProtocol`. + /// + /// If the channel is clogged, `peers` is returned as `Err(HashSet)`. + /// + /// If `peers` is empty after filtering all already-connected peers, + /// `Err(HashMap::new())` is returned. + pub fn try_close_substream_batch( + &self, + peers: impl Iterator, + ) -> Result<(), HashSet> { + let peers = peers + .filter_map(|peer| self.peers.contains_key(&peer).then_some(peer)) + .collect::>(); + + if peers.is_empty() { + return Err(HashSet::new()); + } + + tracing::trace!( + target: LOG_TARGET, + ?peers, + "close substreams", + ); + + self.command_tx + .try_send(NotificationCommand::CloseSubstream { + peers: peers.clone(), + }) + .map_err(|_| peers) + } + + /// Set new handshake. + pub fn set_handshake(&mut self, handshake: Vec) { + tracing::trace!(target: LOG_TARGET, ?handshake, "set handshake"); + + *self.handshake.write() = handshake; + } + + /// Send validation result to the notification protocol for an inbound substream received from + /// `peer`. + pub fn send_validation_result(&mut self, peer: PeerId, result: ValidationResult) { + tracing::trace!(target: LOG_TARGET, ?peer, ?result, "send validation result"); + + self.pending_validations.remove(&peer).map(|tx| tx.send(result)); + } + + /// Send notification to `peer` synchronously. + /// + /// If the channel is clogged, [`NotificationError::ChannelClogged`] is returned. + pub fn send_sync_notification( + &mut self, + peer: PeerId, + notification: Vec, + ) -> Result<(), NotificationError> { + match self.peers.get_mut(&peer) { + Some(sink) => match sink.send_sync_notification(notification) { + Ok(()) => Ok(()), + Err(error) => match error { + NotificationError::NoConnection => return Err(NotificationError::NoConnection), + NotificationError::ChannelClogged => { + let _ = self.clogged.insert(peer).then(|| { + self.command_tx.try_send(NotificationCommand::ForceClose { peer }) + }); + + Err(NotificationError::ChannelClogged) + } + // sink doesn't emit any other `NotificationError`s + _ => unreachable!(), + }, + }, + None => Ok(()), + } + } + + /// Send notification to `peer` asynchronously, waiting for the channel to have capacity + /// if it's clogged. + /// + /// Returns [`Error::PeerDoesntExist(PeerId)`](crate::error::Error::PeerDoesntExist) if the + /// connection has been closed. + pub async fn send_async_notification( + &mut self, + peer: PeerId, + notification: Vec, + ) -> crate::Result<()> { + match self.peers.get_mut(&peer) { + Some(sink) => sink.send_async_notification(notification).await, + None => Err(Error::PeerDoesntExist(peer)), + } + } + + /// Get a copy of the underlying notification sink for the peer. + /// + /// `None` is returned if `peer` doesn't exist. + pub fn notification_sink(&self, peer: PeerId) -> Option { + self.peers.get(&peer).and_then(|sink| Some(sink.clone())) + } } impl Stream for NotificationHandle { - type Item = NotificationEvent; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - loop { - match self.event_rx.poll_recv(cx) { - Poll::Pending => {}, - Poll::Ready(None) => return Poll::Ready(None), - Poll::Ready(Some(event)) => match event { - InnerNotificationEvent::NotificationStreamOpened { - protocol, - fallback, - direction, - peer, - handshake, - sink, - } => { - self.peers.insert(peer, sink); - - return Poll::Ready(Some(NotificationEvent::NotificationStreamOpened { - protocol, - fallback, - direction, - peer, - handshake, - })); - }, - InnerNotificationEvent::NotificationStreamClosed { peer } => { - self.peers.remove(&peer); - self.clogged.remove(&peer); - - return Poll::Ready(Some(NotificationEvent::NotificationStreamClosed { - peer, - })); - }, - InnerNotificationEvent::ValidateSubstream { - protocol, - fallback, - peer, - handshake, - tx, - } => { - self.pending_validations.insert(peer, tx); - - return Poll::Ready(Some(NotificationEvent::ValidateSubstream { - protocol, - fallback, - peer, - handshake, - })); - }, - InnerNotificationEvent::NotificationStreamOpenFailure { peer, error } => - return Poll::Ready(Some(NotificationEvent::NotificationStreamOpenFailure { - peer, - error, - })), - }, - } - - match futures::ready!(self.notif_rx.poll_recv(cx)) { - None => return Poll::Ready(None), - Some((peer, notification)) => - if self.peers.contains_key(&peer) { - return Poll::Ready(Some(NotificationEvent::NotificationReceived { - peer, - notification, - })); - }, - } - } - } + type Item = NotificationEvent; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + loop { + match self.event_rx.poll_recv(cx) { + Poll::Pending => {} + Poll::Ready(None) => return Poll::Ready(None), + Poll::Ready(Some(event)) => match event { + InnerNotificationEvent::NotificationStreamOpened { + protocol, + fallback, + direction, + peer, + handshake, + sink, + } => { + self.peers.insert(peer, sink); + + return Poll::Ready(Some(NotificationEvent::NotificationStreamOpened { + protocol, + fallback, + direction, + peer, + handshake, + })); + } + InnerNotificationEvent::NotificationStreamClosed { peer } => { + self.peers.remove(&peer); + self.clogged.remove(&peer); + + return Poll::Ready(Some(NotificationEvent::NotificationStreamClosed { + peer, + })); + } + InnerNotificationEvent::ValidateSubstream { + protocol, + fallback, + peer, + handshake, + tx, + } => { + self.pending_validations.insert(peer, tx); + + return Poll::Ready(Some(NotificationEvent::ValidateSubstream { + protocol, + fallback, + peer, + handshake, + })); + } + InnerNotificationEvent::NotificationStreamOpenFailure { peer, error } => + return Poll::Ready(Some( + NotificationEvent::NotificationStreamOpenFailure { peer, error }, + )), + }, + } + + match futures::ready!(self.notif_rx.poll_recv(cx)) { + None => return Poll::Ready(None), + Some((peer, notification)) => + if self.peers.contains_key(&peer) { + return Poll::Ready(Some(NotificationEvent::NotificationReceived { + peer, + notification, + })); + }, + } + } + } } diff --git a/src/protocol/notification/mod.rs b/src/protocol/notification/mod.rs index e3e29c53..589555e4 100644 --- a/src/protocol/notification/mod.rs +++ b/src/protocol/notification/mod.rs @@ -21,29 +21,29 @@ //! Notification protocol implementation. use crate::{ - error::Error, - executor::Executor, - protocol::{ - self, - notification::{ - connection::Connection, - handle::NotificationEventHandle, - negotiation::{HandshakeEvent, HandshakeService}, - types::NotificationCommand, - }, - TransportEvent, TransportService, - }, - substream::Substream, - types::{protocol::ProtocolName, SubstreamId}, - PeerId, DEFAULT_CHANNEL_SIZE, + error::Error, + executor::Executor, + protocol::{ + self, + notification::{ + connection::Connection, + handle::NotificationEventHandle, + negotiation::{HandshakeEvent, HandshakeService}, + types::NotificationCommand, + }, + TransportEvent, TransportService, + }, + substream::Substream, + types::{protocol::ProtocolName, SubstreamId}, + PeerId, DEFAULT_CHANNEL_SIZE, }; use bytes::BytesMut; use futures::{future::BoxFuture, stream::FuturesUnordered, StreamExt}; use multiaddr::Multiaddr; use tokio::sync::{ - mpsc::{channel, Receiver, Sender}, - oneshot, + mpsc::{channel, Receiver, Sender}, + oneshot, }; use std::{collections::HashMap, sync::Arc, time::Duration}; @@ -70,1678 +70,1725 @@ const LOG_TARGET: &str = "litep2p::notification"; /// See [`PeerState::PendingValidation.`] for more details. #[derive(Debug, PartialEq, Eq)] enum ConnectionState { - /// There is a active, transport-level connection open to the peer. - Open, + /// There is a active, transport-level connection open to the peer. + Open, - /// There is no transport-level connection open to the peer. - Closed, + /// There is no transport-level connection open to the peer. + Closed, } /// Inbound substream state. #[derive(Debug)] enum InboundState { - /// Substream is closed. - Closed, - - /// Handshake is being read from the remote node. - ReadingHandshake, - - /// Substream and its handshake are being validated by the user protocol. - Validating { - /// Inbound substream. - inbound: Substream, - }, - - /// Handshake is being sent to the remote node. - SendingHandshake, - - /// Substream is open. - Open { - /// Inbound substream. - inbound: Substream, - }, + /// Substream is closed. + Closed, + + /// Handshake is being read from the remote node. + ReadingHandshake, + + /// Substream and its handshake are being validated by the user protocol. + Validating { + /// Inbound substream. + inbound: Substream, + }, + + /// Handshake is being sent to the remote node. + SendingHandshake, + + /// Substream is open. + Open { + /// Inbound substream. + inbound: Substream, + }, } /// Outbound substream state. #[derive(Debug)] enum OutboundState { - /// Substream is closed. - Closed, - - /// Outbound substream initiated. - OutboundInitiated { - /// Substream ID. - substream: SubstreamId, - }, - - /// Substream is in the state of being negotiated. - /// - /// This process entails sending local node's handshake and reading back the remote node's - /// handshake if they've accepted the substream or detecting that the substream was closed - /// in case the substream was rejected. - Negotiating, - - /// Substream is open. - Open { - /// Received handshake. - handshake: Vec, - - /// Outbound substream. - outbound: Substream, - }, + /// Substream is closed. + Closed, + + /// Outbound substream initiated. + OutboundInitiated { + /// Substream ID. + substream: SubstreamId, + }, + + /// Substream is in the state of being negotiated. + /// + /// This process entails sending local node's handshake and reading back the remote node's + /// handshake if they've accepted the substream or detecting that the substream was closed + /// in case the substream was rejected. + Negotiating, + + /// Substream is open. + Open { + /// Received handshake. + handshake: Vec, + + /// Outbound substream. + outbound: Substream, + }, } impl OutboundState { - /// Get pending outboud substream ID, if it exists. - fn pending_open(&self) -> Option { - match &self { - OutboundState::OutboundInitiated { substream } => Some(*substream), - _ => None, - } - } + /// Get pending outboud substream ID, if it exists. + fn pending_open(&self) -> Option { + match &self { + OutboundState::OutboundInitiated { substream } => Some(*substream), + _ => None, + } + } } #[derive(Debug)] enum PeerState { - /// Peer state is poisoned due to invalid state transition. - Poisoned, - - /// Validation for an inbound substream is still pending. - /// - /// In order to enforce valid state transitions, `NotificationProtocol` keeps track of pending - /// validations across connectivity events (open/closed) and enforces that no activity happens - /// for any peer that is still awaiting validation for their inbound substream. - /// - /// If connection closes while the substream is being validated, instead of removing peer from - /// `peers`, the peer state is set as `ValidationPending` which indicates to the state machine - /// that a response for a inbound substream is pending validation. The substream itself will be - /// dead by the time validation is received if the peer state is `ValidationPending` since the - /// substream was part of a previous, now-closed substream but this state allows - /// `NotificationProtocol` to enforce correct state transitions by, e.g., rejecting new inbound - /// substream while a previous substream is still being validated or rejecting outbound - /// substreams on new connections if that same condition holds. - ValidationPending { - /// What is current connectivity state of the peer. - /// - /// If `state` is `ConnectionState::Closed` when the validation is finally received, peer - /// is removed from `peer` and if the `state` is `ConnectionState::Open`, peer is moved to - /// state `PeerState::Closed` and user is allowed to retry opening an outbound substream. - state: ConnectionState, - }, - - /// Connection to peer is closed. - Closed { - /// Connection might have been closed while there was an outbound substream still pending. - /// - /// To handle this state transition correctly in case the substream opens after the - /// connection is considered closed, store the `SubstreamId` to that it can be verified in - /// case the substream ever opens. - pending_open: Option, - }, - - /// Peer is being dialed in order to open an outbound substream to them. - Dialing, - - /// Outbound substream initiated. - OutboundInitiated { - /// Substream ID. - substream: SubstreamId, - }, - - /// Substream is being validated. - Validating { - /// Protocol. - protocol: ProtocolName, - - /// Fallback protocol, if the substream was negotiated using a fallback name. - fallback: Option, - - /// Outbound protocol state. - outbound: OutboundState, - - /// Inbound protocol state. - inbound: InboundState, - - /// Direction. - direction: Direction, - }, - - /// Notification stream has been opened. - Open { - /// `Oneshot::Sender` for shutting down the connection. - shutdown: oneshot::Sender<()>, - }, + /// Peer state is poisoned due to invalid state transition. + Poisoned, + + /// Validation for an inbound substream is still pending. + /// + /// In order to enforce valid state transitions, `NotificationProtocol` keeps track of pending + /// validations across connectivity events (open/closed) and enforces that no activity happens + /// for any peer that is still awaiting validation for their inbound substream. + /// + /// If connection closes while the substream is being validated, instead of removing peer from + /// `peers`, the peer state is set as `ValidationPending` which indicates to the state machine + /// that a response for a inbound substream is pending validation. The substream itself will be + /// dead by the time validation is received if the peer state is `ValidationPending` since the + /// substream was part of a previous, now-closed substream but this state allows + /// `NotificationProtocol` to enforce correct state transitions by, e.g., rejecting new inbound + /// substream while a previous substream is still being validated or rejecting outbound + /// substreams on new connections if that same condition holds. + ValidationPending { + /// What is current connectivity state of the peer. + /// + /// If `state` is `ConnectionState::Closed` when the validation is finally received, peer + /// is removed from `peer` and if the `state` is `ConnectionState::Open`, peer is moved to + /// state `PeerState::Closed` and user is allowed to retry opening an outbound substream. + state: ConnectionState, + }, + + /// Connection to peer is closed. + Closed { + /// Connection might have been closed while there was an outbound substream still pending. + /// + /// To handle this state transition correctly in case the substream opens after the + /// connection is considered closed, store the `SubstreamId` to that it can be verified in + /// case the substream ever opens. + pending_open: Option, + }, + + /// Peer is being dialed in order to open an outbound substream to them. + Dialing, + + /// Outbound substream initiated. + OutboundInitiated { + /// Substream ID. + substream: SubstreamId, + }, + + /// Substream is being validated. + Validating { + /// Protocol. + protocol: ProtocolName, + + /// Fallback protocol, if the substream was negotiated using a fallback name. + fallback: Option, + + /// Outbound protocol state. + outbound: OutboundState, + + /// Inbound protocol state. + inbound: InboundState, + + /// Direction. + direction: Direction, + }, + + /// Notification stream has been opened. + Open { + /// `Oneshot::Sender` for shutting down the connection. + shutdown: oneshot::Sender<()>, + }, } /// Peer context. #[derive(Debug)] struct PeerContext { - /// Peer state. - state: PeerState, + /// Peer state. + state: PeerState, } impl PeerContext { - /// Create new [`PeerContext`]. - fn new() -> Self { - Self { state: PeerState::Closed { pending_open: None } } - } + /// Create new [`PeerContext`]. + fn new() -> Self { + Self { + state: PeerState::Closed { pending_open: None }, + } + } } pub(crate) struct NotificationProtocol { - /// Transport service. - service: TransportService, + /// Transport service. + service: TransportService, - /// Protocol. - protocol: ProtocolName, + /// Protocol. + protocol: ProtocolName, - /// Auto accept inbound substream if the outbound substream was initiated by the local node. - auto_accept: bool, + /// Auto accept inbound substream if the outbound substream was initiated by the local node. + auto_accept: bool, - /// TX channel passed to the protocol used for sending events. - event_handle: NotificationEventHandle, + /// TX channel passed to the protocol used for sending events. + event_handle: NotificationEventHandle, - /// TX channel for sending shut down notifications from connection handlers to - /// [`NotificationProtocol`]. - shutdown_tx: Sender, + /// TX channel for sending shut down notifications from connection handlers to + /// [`NotificationProtocol`]. + shutdown_tx: Sender, - /// RX channel for receiving shutdown notifications from the connection handlers. - shutdown_rx: Receiver, + /// RX channel for receiving shutdown notifications from the connection handlers. + shutdown_rx: Receiver, - /// RX channel passed to the protocol used for receiving commands. - command_rx: Receiver, + /// RX channel passed to the protocol used for receiving commands. + command_rx: Receiver, - /// TX channel given to connection handlers for sending notifications. - notif_tx: Sender<(PeerId, BytesMut)>, + /// TX channel given to connection handlers for sending notifications. + notif_tx: Sender<(PeerId, BytesMut)>, - /// Connected peers. - peers: HashMap, + /// Connected peers. + peers: HashMap, - /// Pending outboudn substreams. - pending_outbound: HashMap, + /// Pending outboudn substreams. + pending_outbound: HashMap, - /// Handshaking service which reads and writes the handshakes to inbound - /// and outbound substreams asynchronously. - negotiation: HandshakeService, + /// Handshaking service which reads and writes the handshakes to inbound + /// and outbound substreams asynchronously. + negotiation: HandshakeService, - /// Synchronous channel size. - sync_channel_size: usize, + /// Synchronous channel size. + sync_channel_size: usize, - /// Asynchronous channel size. - async_channel_size: usize, + /// Asynchronous channel size. + async_channel_size: usize, - /// Executor for connection handlers. - executor: Arc, + /// Executor for connection handlers. + executor: Arc, - /// Pending substream validations. - pending_validations: FuturesUnordered>, + /// Pending substream validations. + pending_validations: FuturesUnordered>, - /// Timers for pending outbound substreams. - timers: FuturesUnordered>, + /// Timers for pending outbound substreams. + timers: FuturesUnordered>, - /// Should `NotificationProtocol` attempt to dial the peer. - should_dial: bool, + /// Should `NotificationProtocol` attempt to dial the peer. + should_dial: bool, } impl NotificationProtocol { - pub(crate) fn new( - service: TransportService, - config: Config, - executor: Arc, - ) -> Self { - let (shutdown_tx, shutdown_rx) = channel(DEFAULT_CHANNEL_SIZE); - - Self { - service, - shutdown_tx, - shutdown_rx, - executor, - peers: HashMap::new(), - protocol: config.protocol_name, - auto_accept: config.auto_accept, - pending_validations: FuturesUnordered::new(), - timers: FuturesUnordered::new(), - event_handle: NotificationEventHandle::new(config.event_tx), - notif_tx: config.notif_tx, - command_rx: config.command_rx, - pending_outbound: HashMap::new(), - negotiation: HandshakeService::new(config.handshake), - sync_channel_size: config.sync_channel_size, - async_channel_size: config.async_channel_size, - should_dial: config.should_dial, - } - } - - /// Connection established to remote node. - /// - /// If the peer already exists, the only valid state for it is `Dialing` as it indicates that - /// the user tried to open a substream to a peer who was not connected to local node. - /// - /// Any other state indicates that there's an error in the state transition logic. - async fn on_connection_established(&mut self, peer: PeerId) -> crate::Result<()> { - tracing::trace!(target: LOG_TARGET, ?peer, protocol = %self.protocol, "connection established"); - - let Some(context) = self.peers.get_mut(&peer) else { - self.peers.insert(peer, PeerContext::new()); - return Ok(()); - }; - - match std::mem::replace(&mut context.state, PeerState::Poisoned) { - PeerState::Dialing => { - tracing::trace!( - target: LOG_TARGET, - ?peer, - protocol = %self.protocol, - "dial succeeded, open substream to peer", - ); - - context.state = PeerState::Closed { pending_open: None }; - self.on_open_substream(peer).await - }, - // connection established but validation is still pending - // - // update the connection state so that `NotificationProtocol` can proceed - // to correct state after the validation result has beern received - PeerState::ValidationPending { state } => { - debug_assert_eq!(state, ConnectionState::Closed); - - tracing::debug!( - target: LOG_TARGET, - ?peer, - protocol = %self.protocol, - "new connection established while validation still pending", - ); - - context.state = PeerState::ValidationPending { state: ConnectionState::Open }; - - Ok(()) - }, - state => { - tracing::error!( - target: LOG_TARGET, - ?peer, - protocol = %self.protocol, - ?state, - "state mismatch: peer already exists", - ); - debug_assert!(false); - Err(Error::PeerAlreadyExists(peer)) - }, - } - } - - /// Connection closed to remote node. - /// - /// If the connection was considered open (both substreams were open), user is notified that - /// the notification stream was closed. - /// - /// If the connection was still in progress (either substream was not fully open), the user is - /// reported about it only if they had opened an outbound substream (outbound is either fully - /// open, it had been initiated or the substream was under negotiation). - async fn on_connection_closed(&mut self, peer: PeerId) -> crate::Result<()> { - tracing::trace!(target: LOG_TARGET, ?peer, protocol = %self.protocol, "connection closed"); - - let Some(context) = self.peers.remove(&peer) else { - tracing::error!( - target: LOG_TARGET, - ?peer, - protocol = %self.protocol, - "state mismatch: peer doesn't exist", - ); - debug_assert!(false); - return Err(Error::PeerDoesntExist(peer)); - }; - - // clean up all pending state for the peer - self.negotiation.remove_outbound(&peer); - self.negotiation.remove_inbound(&peer); - - match context.state { - // outbound initiated, report open failure to peer - PeerState::OutboundInitiated { .. } => { - self.event_handle - .report_notification_stream_open_failure(peer, NotificationError::Rejected) - .await; - }, - // substream fully open, report that the notification stream is closed - PeerState::Open { shutdown } => { - let _ = shutdown.send(()); - }, - // if the substream was being validated, user must be notified that the substream is - // now considered rejected if they had been made aware of the existence of the pending - // connection - PeerState::Validating { outbound, inbound, .. } => { - match (outbound, inbound) { - // substream was being validated by the protocol when the connection was closed - (OutboundState::Closed, InboundState::Validating { .. }) => { - tracing::debug!( - target: LOG_TARGET, - ?peer, - protocol = %self.protocol, - "connection closed while validation pending", - ); - - self.peers.insert( - peer, - PeerContext { - state: PeerState::ValidationPending { - state: ConnectionState::Closed, - }, - }, - ); - }, - // user either initiated an outbound substream or an outbound substream was - // opened/being opened as a result of an accepted inbound substream but was not - // yet fully open - // - // to have consistent state tracking in the user protocol, substream rejection - // must be reported to the user - ( - OutboundState::OutboundInitiated { .. } | - OutboundState::Negotiating | - OutboundState::Open { .. }, - _, - ) => { - tracing::debug!( - target: LOG_TARGET, - ?peer, - protocol = %self.protocol, - "connection closed outbound substream under negotiation", - ); - - self.event_handle - .report_notification_stream_open_failure( - peer, - NotificationError::Rejected, - ) - .await; - }, - (_, _) => {}, - } - }, - // pending validations must be tracked across connection open/close events - PeerState::ValidationPending { .. } => { - tracing::debug!( - target: LOG_TARGET, - ?peer, - protocol = %self.protocol, - "validation pending while connection closed", - ); - - self.peers.insert( - peer, - PeerContext { - state: PeerState::ValidationPending { state: ConnectionState::Closed }, - }, - ); - }, - _ => {}, - } - - Ok(()) - } - - /// Local node opened a substream to remote node. - /// - /// The connection can be in three different states: - /// - this is the first substream that was opened and thus the connection was initiated by the - /// local node - /// - this is a response to a previously received inbound substream which the local node - /// accepted and as a result, opened its own substream - /// - local and remote nodes opened substreams at the same time - /// - /// In the first case, the local node's handshake is sent to remote node and the substream is - /// polled in the background until they either send their handshake or close the substream. - /// - /// For the second case, the connection was initiated by the remote node and the substream was - /// accepted by the local node which initiated an outbound substream to the remote node. - /// The only valid states for this case are [`InboundState::Open`], - /// and [`InboundState::SendingHandshake`] as they imply - /// that the inbound substream have been accepted by the local node and this opened outbound - /// substream is a result of a valid state transition. - /// - /// For the third case, if the nodes have opened substreams at the same time, the outbound state - /// must be [`OutboundState::OutboundInitiated`] to ascertain that the an outbound substream was - /// actually opened. Any other state would be a state mismatch and would mean that the - /// connection is opening substreams without the permission of the protocol handler. - async fn on_outbound_substream( - &mut self, - protocol: ProtocolName, - fallback: Option, - peer: PeerId, - substream_id: SubstreamId, - outbound: Substream, - ) -> crate::Result<()> { - tracing::debug!( - target: LOG_TARGET, - ?peer, - ?protocol, - ?substream_id, - "handle outbound substream", - ); - - // peer must exist since an outbound substream was received from them - let context = self.peers.get_mut(&peer).expect("peer to exist"); - let pending_peer = self.pending_outbound.remove(&substream_id); - - match std::mem::replace(&mut context.state, PeerState::Poisoned) { - // the connection was initiated by the local node, send handshake to remote and wait to - // receive their handshake back - PeerState::OutboundInitiated { substream } => { - debug_assert!(substream == substream_id); - debug_assert!(pending_peer == Some(peer)); - - tracing::trace!( - target: LOG_TARGET, - ?peer, - protocol = %self.protocol, - ?fallback, - ?substream_id, - "negotiate outbound protocol", - ); - - self.negotiation.negotiate_outbound(peer, outbound); - context.state = PeerState::Validating { - protocol, - fallback, - inbound: InboundState::Closed, - outbound: OutboundState::Negotiating, - direction: Direction::Outbound, - }; - }, - PeerState::Validating { - protocol, - fallback, - inbound, - direction, - outbound: outbound_state, - } => { - // the inbound substream has been accepted by the local node since the handshake has - // been read and the local handshake has either already been sent or - // it's in the process of being sent. - match inbound { - InboundState::SendingHandshake | InboundState::Open { .. } => { - context.state = PeerState::Validating { - protocol, - fallback, - inbound, - direction, - outbound: OutboundState::Negotiating, - }; - self.negotiation.negotiate_outbound(peer, outbound); - }, - // nodes have opened substreams at the same time - inbound_state => match outbound_state { - OutboundState::OutboundInitiated { substream } => { - debug_assert!(substream == substream_id); - - context.state = PeerState::Validating { - protocol, - fallback, - direction, - inbound: inbound_state, - outbound: OutboundState::Negotiating, - }; - self.negotiation.negotiate_outbound(peer, outbound); - }, - // invalid state: more than one outbound substream has been opened - inner_state => { - tracing::error!( - target: LOG_TARGET, - ?peer, - %protocol, - ?substream_id, - ?inbound_state, - ?inner_state, - "invalid state, expected `OutboundInitiated`", - ); - - let _ = outbound.close().await; - debug_assert!(false); - }, - }, - } - }, - // the connection may have been closed while an outbound substream was pending - // if the outbound substream was initiated successfully, close it and reset - // `pending_open` - PeerState::Closed { pending_open } if pending_open == Some(substream_id) => { - let _ = outbound.close().await; - - context.state = PeerState::Closed { pending_open: None }; - }, - state => { - tracing::error!( - target: LOG_TARGET, - ?peer, - %protocol, - ?substream_id, - ?state, - "invalid state: more than one outbound substream opened", - ); - - let _ = outbound.close().await; - debug_assert!(false); - }, - } - - Ok(()) - } - - /// Remote opened a substream to local node. - /// - /// The peer can be in four different states for the inbound substream to be considered valid: - /// - the connection is closed - /// - conneection is open but substream validation from a previous connection is still pending - /// - outbound substream has been opened but not yet acknowledged by the remote peer - /// - outbound substream has been opened and acknowledged by the remote peer and it's being - /// negotiated - /// - /// If remote opened more than one substream, the new substream is simply discarded. - async fn on_inbound_substream( - &mut self, - protocol: ProtocolName, - fallback: Option, - peer: PeerId, - substream: Substream, - ) -> crate::Result<()> { - // peer must exist since an inbound substream was received from them - let context = self.peers.get_mut(&peer).expect("peer to exist"); - - tracing::debug!( - target: LOG_TARGET, - ?peer, - %protocol, - ?fallback, - state = ?context.state, - "handle inbound substream", - ); - - match std::mem::replace(&mut context.state, PeerState::Poisoned) { - // inbound substream of a previous connection is still pending validation, - // reject any new inbound substreams until an answer is heard from the user - state @ PeerState::ValidationPending { .. } => { - tracing::debug!( - target: LOG_TARGET, - ?peer, - %protocol, - ?fallback, - ?state, - "validation for previous substream still pending", - ); - - let _ = substream.close().await; - context.state = state; - }, - // outbound substream for previous connection still pending, reject inbound substream - // and wait for the outbound substream state to conclude as either succeeded or failed - // before accepting any inbound substreams. - PeerState::Closed { pending_open: Some(substream_id) } => { - tracing::debug!( - target: LOG_TARGET, - ?peer, - protocol = %self.protocol, - "received inbound substream while outbound substream opening, rejecting", - ); - let _ = substream.close().await; - - context.state = PeerState::Closed { pending_open: Some(substream_id) }; - }, - // the peer state is closed so this is a fresh inbound substream. - PeerState::Closed { pending_open: None } => { - self.negotiation.read_handshake(peer, substream); - - context.state = PeerState::Validating { - protocol, - fallback, - direction: Direction::Inbound, - inbound: InboundState::ReadingHandshake, - outbound: OutboundState::Closed, - }; - }, - // if the connection is under validation (so an outbound substream has been opened and - // it's still pending or under negotiation), the only valid state for the - // inbound state is closed as it indicates that there isn't an inbound substream yet for - // the remote node duplicate substreams are prohibited. - PeerState::Validating { - protocol, - fallback, - outbound, - direction, - inbound: InboundState::Closed, - } => { - self.negotiation.read_handshake(peer, substream); - - context.state = PeerState::Validating { - protocol, - fallback, - outbound, - direction, - inbound: InboundState::ReadingHandshake, - }; - }, - // outbound substream may have been initiated by the local node while a remote node also - // opened a substream roughly at the same time - PeerState::OutboundInitiated { substream: outbound } => { - self.negotiation.read_handshake(peer, substream); - - context.state = PeerState::Validating { - protocol, - fallback, - direction: Direction::Outbound, - outbound: OutboundState::OutboundInitiated { substream: outbound }, - inbound: InboundState::ReadingHandshake, - }; - }, - // new inbound substream opend while validation for the previous substream was still - // pending - // - // the old substream can be considered dead because remote wouldn't open a new substream - // to us unless they had discarded the previous substream. - // - // set peer state to `ValidationPending` to indicate that the peer is "blocked" until a - // validation for the substream is heard, blocking any further activity for - // the connection and once the validation is received and in case the - // substream is accepted, it will be reported as open failure to to the peer - // because the states have gone out of sync. - PeerState::Validating { - outbound: OutboundState::Closed, - inbound: InboundState::Validating { inbound: pending_substream }, - .. - } => { - tracing::debug!( - target: LOG_TARGET, - ?peer, - protocol = %self.protocol, - "remote opened substream while previous was still pending, connection failed", - ); - let _ = substream.close().await; - let _ = pending_substream.close().await; - - context.state = PeerState::ValidationPending { state: ConnectionState::Open }; - }, - // remote opened another inbound substream, close it and otherwise ignore the event - // as this is a non-serious protocol violation. - state => { - tracing::debug!( - target: LOG_TARGET, - ?peer, - %protocol, - ?fallback, - ?state, - "remote opened more than one inbound substreams, discarding", - ); - - let _ = substream.close().await; - context.state = state; - }, - } - - Ok(()) - } - - /// Failed to open substream to remote node. - /// - /// If the substream was initiated by the local node, it must be reported that the substream - /// failed to open. Otherwise the peer state can silently be converted to `Closed`. - async fn on_substream_open_failure(&mut self, substream_id: SubstreamId, error: Error) { - tracing::debug!( - target: LOG_TARGET, - protocol = %self.protocol, - ?substream_id, - ?error, - "failed to open substream" - ); - - let Some(peer) = self.pending_outbound.remove(&substream_id) else { - tracing::warn!( - target: LOG_TARGET, - protocol = %self.protocol, - ?substream_id, - "pending outbound substream doesn't exist", - ); - debug_assert!(false); - return; - }; - - // peer must exist since an outbound substream failure was received from them - let Some(context) = self.peers.get_mut(&peer) else { - tracing::warn!(target: LOG_TARGET, ?peer, "peer doesn't exist"); - debug_assert!(false); - return; - }; - - match &mut context.state { - PeerState::OutboundInitiated { .. } => { - context.state = PeerState::Closed { pending_open: None }; - - self.event_handle - .report_notification_stream_open_failure(peer, NotificationError::Rejected) - .await; - }, - // if the substream was accepted by the local node and as a result, an outbound - // substream was accepted as a result this should not be reported to local node - PeerState::Validating { outbound, .. } => { - self.negotiation.remove_inbound(&peer); - self.negotiation.remove_outbound(&peer); - - let pending_open = match outbound { - OutboundState::Closed => None, - OutboundState::OutboundInitiated { substream } => { - self.event_handle - .report_notification_stream_open_failure( - peer, - NotificationError::Rejected, - ) - .await; - - Some(*substream) - }, - OutboundState::Negotiating | OutboundState::Open { .. } => { - self.event_handle - .report_notification_stream_open_failure( - peer, - NotificationError::Rejected, - ) - .await; - - None - }, - }; - - context.state = PeerState::Closed { pending_open }; - }, - PeerState::Closed { pending_open } => { - tracing::debug!( - target: LOG_TARGET, - protocol = %self.protocol, - ?substream_id, - "substream open failure for a closed connection", - ); - debug_assert_eq!(pending_open, &Some(substream_id)); - context.state = PeerState::Closed { pending_open: None }; - }, - state => { - tracing::warn!( - target: LOG_TARGET, - protocol = %self.protocol, - ?substream_id, - ?state, - "invalid state for outbound substream open failure", - ); - context.state = PeerState::Closed { pending_open: None }; - debug_assert!(false); - }, - } - } - - /// Open substream to remote `peer`. - /// - /// Outbound substream can opened only if the `PeerState` is `Closed`. - /// By forcing the substream to be opened only if the state is currently closed, - /// `NotificationProtocol` can enfore more predictable state transitions. - /// - /// Other states either imply an invalid state transition ([`PeerState::Open`]) or that an - /// inbound substream has already been received and its currently being validated by the user. - async fn on_open_substream(&mut self, peer: PeerId) -> crate::Result<()> { - tracing::trace!(target: LOG_TARGET, ?peer, protocol = %self.protocol, "open substream"); - - let Some(context) = self.peers.get_mut(&peer) else { - if !self.should_dial { - tracing::debug!( - target: LOG_TARGET, - ?peer, - protocol = %self.protocol, - "connection to peer not open and dialing disabled", - ); - - self.event_handle - .report_notification_stream_open_failure(peer, NotificationError::DialFailure) - .await; - return Ok(()); - } - - match self.service.dial(&peer) { - Err(error) => { - tracing::debug!( - target: LOG_TARGET, - ?peer, - protocol = %self.protocol, - ?error, - "failed to dial peer", - ); - - self.event_handle - .report_notification_stream_open_failure( - peer, - NotificationError::DialFailure, - ) - .await; - - return Err(error); - }, - Ok(()) => { - tracing::trace!( - target: LOG_TARGET, - ?peer, - protocol = %self.protocol, - "started to dial peer", - ); - - self.peers.insert(peer, PeerContext { state: PeerState::Dialing }); - return Ok(()); - }, - } - }; - - match context.state { - // protocol can only request a new outbound substream to be opened if the state is - // `Closed` other states imply that it's already open - PeerState::Closed { pending_open: Some(substream_id) } => { - tracing::trace!( - target: LOG_TARGET, - ?peer, - protocol = %self.protocol, - ?substream_id, - "outbound substream opening, reusing pending open substream", - ); - - self.pending_outbound.insert(substream_id, peer); - context.state = PeerState::OutboundInitiated { substream: substream_id }; - }, - PeerState::Closed { .. } => match self.service.open_substream(peer) { - Ok(substream_id) => { - tracing::trace!( - target: LOG_TARGET, - ?peer, - protocol = %self.protocol, - ?substream_id, - "outbound substream opening", - ); - - self.pending_outbound.insert(substream_id, peer); - context.state = PeerState::OutboundInitiated { substream: substream_id }; - }, - Err(error) => { - tracing::debug!( - target: LOG_TARGET, - ?peer, - protocol = %self.protocol, - ?error, - "failed to open substream", - ); - - self.event_handle - .report_notification_stream_open_failure( - peer, - NotificationError::NoConnection, - ) - .await; - context.state = PeerState::Closed { pending_open: None }; - }, - }, - // while a validation is pending for an inbound substream, user is not allowed to open - // any outbound substreams until the old inbond substream is either accepted or rejected - PeerState::ValidationPending { .. } => { - tracing::trace!( - target: LOG_TARGET, - ?peer, - protocol = %self.protocol, - "validation still pending, rejecting outbound substream request", - ); - - self.event_handle - .report_notification_stream_open_failure( - peer, - NotificationError::ValidationPending, - ) - .await; - }, - _ => {}, - } - - Ok(()) - } - - /// Close substream to remote `peer`. - /// - /// This function can only be called if the substream was actually open, any other state is - /// unreachable as the user is unable to emit this command to [`NotificationProtocol`] unless - /// the connection has been fully opened. - async fn on_close_substream(&mut self, peer: PeerId) { - tracing::debug!(target: LOG_TARGET, ?peer, protocol = %self.protocol, "close substream"); - - let Some(context) = self.peers.get_mut(&peer) else { - tracing::debug!(target: LOG_TARGET, ?peer, "peer doesn't exist"); - return; - }; - - match std::mem::replace(&mut context.state, PeerState::Poisoned) { - PeerState::Open { shutdown } => { - let _ = shutdown.send(()); - - context.state = PeerState::Closed { pending_open: None }; - }, - state => { - tracing::debug!( - target: LOG_TARGET, - ?peer, - protocol = %self.protocol, - ?state, - "substream already closed", - ); - context.state = state; - }, - } - } - - /// Handle validation result. - /// - /// The validation result binary (accept/reject). If the node is rejected, the substreams are - /// discarded and state is set to `PeerState::Closed`. If there was an outbound substream in - /// progress while the connection was rejected by the user, the oubound state is discarded, - /// except for the substream ID of the substream which is kept for later use, in case the - /// substream happens to open. - /// - /// If the node is accepted and there is no outbound substream to them open yet, a new substream - /// is opened and once it opens, the local handshake will be sent to the remote peer and if - /// they also accept the substream the connection is considered fully open. - async fn on_validation_result( - &mut self, - peer: PeerId, - result: ValidationResult, - ) -> crate::Result<()> { - tracing::trace!( - target: LOG_TARGET, - ?peer, - protocol = %self.protocol, - ?result, - "handle validation result", - ); - - let Some(context) = self.peers.get_mut(&peer) else { - tracing::debug!(target: LOG_TARGET, ?peer, "peer doesn't exist"); - return Err(Error::PeerDoesntExist(peer)); - }; - - match std::mem::replace(&mut context.state, PeerState::Poisoned) { - PeerState::Validating { - protocol, - fallback, - outbound, - direction, - inbound: InboundState::Validating { inbound }, - } => match result { - // substream was rejected by the local node, if an outbound substream was under - // negotation, discard that data and if an outbound substream was - // initiated, save the `SubstreamId` of that substream and later if the substream - // is opened, the state can be corrected to `pending_open: None`. - ValidationResult::Reject => { - let _ = inbound.close().await; - self.negotiation.remove_outbound(&peer); - self.negotiation.remove_inbound(&peer); - context.state = PeerState::Closed { pending_open: outbound.pending_open() }; - - Ok(()) - }, - ValidationResult::Accept => match outbound { - // no outbound substream exists so initiate a new substream open and send the - // local handshake to remote node, indicating that the - // connection was accepted by the local node - OutboundState::Closed => match self.service.open_substream(peer) { - Ok(substream) => { - self.negotiation.send_handshake(peer, inbound); - self.pending_outbound.insert(substream, peer); - - context.state = PeerState::Validating { - protocol, - fallback, - direction, - inbound: InboundState::SendingHandshake, - outbound: OutboundState::OutboundInitiated { substream }, - }; - Ok(()) - }, - // failed to open outbound substream after accepting an inbound substream - // - // since the user was notified of this substream and they accepted it, - // they expecting some kind of answer (open success/failure). - // - // report to user that the substream failed to open so they can track the - // state transitions of the peer correctly - Err(error) => { - tracing::trace!( - target: LOG_TARGET, - ?peer, - protocol = %self.protocol, - ?result, - ?error, - "failed to open outbound substream for accepted substream", - ); - - let _ = inbound.close().await; - context.state = PeerState::Closed { pending_open: None }; - - self.event_handle - .report_notification_stream_open_failure( - peer, - NotificationError::Rejected, - ) - .await; - - Err(error) - }, - }, - // here the state is one of `OutboundState::{OutboundInitiated, Negotiating, - // Open}` so that state can be safely ignored and all that - // has to be done is to send the local handshake to remote - // node to indicate that the connection was accepted. - _ => { - self.negotiation.send_handshake(peer, inbound); - - context.state = PeerState::Validating { - protocol, - fallback, - direction, - inbound: InboundState::SendingHandshake, - outbound, - }; - Ok(()) - }, - }, - }, - // validation result received for an inbound substream which is now considered dead - // because while the substream was being validated, the connection had closed. - // - // if the substream was rejected and there is no active connection to the peer, - // just remove the peer from `peers` without informing user - // - // if the substream was accepted, the user must be informed that the substream failed to - // open. Depending on whether there is currently a connection open to the peer, either - // report `Rejected`/`NoConnection` and let the user try again. - PeerState::ValidationPending { state } => { - if let Some(error) = match state { - ConnectionState::Open => { - context.state = PeerState::Closed { pending_open: None }; - - std::matches!(result, ValidationResult::Accept) - .then_some(NotificationError::Rejected) - }, - ConnectionState::Closed => { - self.peers.remove(&peer); - - std::matches!(result, ValidationResult::Accept) - .then_some(NotificationError::NoConnection) - }, - } { - self.event_handle.report_notification_stream_open_failure(peer, error).await; - } - - Ok(()) - }, - // if the user incorrectly send a validation result for a peer that doesn't require - // validation, set state back to what it was and ignore the event - // - // the user protocol may send a stale validation result not because of a programming - // error but because it has a backlock of unhandled events, with one event potentially - // nullifying the need for substream validation, and is just temporarily out of sync - // with `NotificationProtocol` - state => { - tracing::debug!( - target: LOG_TARGET, - ?peer, - protocol = %self.protocol, - ?state, - "validation result received for peer that doesn't require validation", - ); - - context.state = state; - Ok(()) - }, - } - } - - /// Handle handshake event. - /// - /// There are three different handshake event types: - /// - outbound substream negotiated - /// - inbound substream negotiated - /// - substream negotiation error - /// - /// Neither outbound nor inbound substream negotiated automatically means that the connection is - /// considered open as both substreams must be fully negotiated for that to be the case. That is - /// why the peer state for inbound and outbound are set separately and at the end of the - /// function is the collective state of the substreams checked and if both substreams are - /// negotiated, the user informed that the connection is open. - /// - /// If the negotiation fails, the user may have to be informed of that. Outbound substream - /// failure always results in user getting notified since the existence of an outbound substream - /// means that the user has either initiated an outbound substreams or has accepted an inbound - /// substreams, resulting in an outbound substreams. - /// - /// Negotiation failure for inbound substreams which are in the state - /// [`InboundState::ReadingHandshake`] don't result in any notification because while the - /// handshake is being read from the substream, the user is oblivious to the fact that an - /// inbound substream has even been received. - async fn on_handshake_event(&mut self, peer: PeerId, event: HandshakeEvent) { - let Some(context) = self.peers.get_mut(&peer) else { - tracing::error!(target: LOG_TARGET, "invalid state: negotiation event received but peer doesn't exist"); - debug_assert!(false); - return; - }; - - tracing::trace!( - target: LOG_TARGET, - ?peer, - protocol = %self.protocol, - ?event, - "handle handshake event", - ); - - match event { - // either an inbound or outbound substream has been negotiated successfully - HandshakeEvent::Negotiated { peer, handshake, substream, direction } => match direction - { - // outbound substream was negotiated, the only valid state for peer is `Validating` - // and only valid state for `OutboundState` is `Negotiating` - negotiation::Direction::Outbound => { - self.negotiation.remove_outbound(&peer); - - match std::mem::replace(&mut context.state, PeerState::Poisoned) { - PeerState::Validating { - protocol, - fallback, - direction, - outbound: OutboundState::Negotiating, - inbound, - } => { - context.state = PeerState::Validating { - protocol, - fallback, - direction, - outbound: OutboundState::Open { handshake, outbound: substream }, - inbound, - }; - }, - state => { - tracing::warn!( - target: LOG_TARGET, - ?peer, - ?state, - "outbound substream negotiated but peer has invalid state", - ); - debug_assert!(false); - }, - } - }, - // inbound negotiation event completed - // - // the negotiation event can be on of two different types: - // - remote handshake was read from the substream - // - local handshake has been sent to remote node - // - // For the first case, the substream has to be validated by the local node. - // This means reporting the protocol name, potential negotiated fallback and the - // handshake. Local node will then either accept or reject the substream which is - // handled by [`NotificationProtocol::on_validation_result()`]. Compared to - // Substrate, litep2p requires both peers to validate the inbound handshake to allow - // more complex connection validation. If this is not necessary and the protocol - // wishes to auto-accept the inbound substreams that are a result of - // an outbound substream already accepted by the remote node, the - // substream validation is skipped and the local handshake is sent - // right away. - // - // For the second case, the local handshake was sent to remote node successfully and - // the inbound substream is considered open and if the outbound - // substream is open as well, the connection is fully open. - // - // Only valid states for [`InboundState`] are [`InboundState::ReadingHandshake`] and - // [`InboundState::SendingHandshake`] because otherwise the inbound - // substream cannot be in [`HandshakeService`](super::negotiation::HandshakeService) - // unless there is a logic bug in the state machine. - negotiation::Direction::Inbound => { - self.negotiation.remove_inbound(&peer); - - match std::mem::replace(&mut context.state, PeerState::Poisoned) { - PeerState::Validating { - protocol, - fallback, - direction, - outbound, - inbound: InboundState::ReadingHandshake, - } => { - if !std::matches!(outbound, OutboundState::Closed) && self.auto_accept { - tracing::trace!( - target: LOG_TARGET, - ?peer, - %protocol, - ?fallback, - ?direction, - ?outbound, - "auto-accept inbound substream", - ); - - self.negotiation.send_handshake(peer, substream); - context.state = PeerState::Validating { - protocol, - fallback, - direction, - inbound: InboundState::SendingHandshake, - outbound, - }; - - return; - } - - tracing::trace!( - target: LOG_TARGET, - ?peer, - %protocol, - ?fallback, - ?outbound, - "send inbound protocol for validation", - ); - - context.state = PeerState::Validating { - protocol: protocol.clone(), - fallback: fallback.clone(), - inbound: InboundState::Validating { inbound: substream }, - outbound, - direction, - }; - - let (tx, rx) = oneshot::channel(); - self.pending_validations.push(Box::pin(async move { - match rx.await { - Ok(ValidationResult::Accept) => - (peer, ValidationResult::Accept), - _ => (peer, ValidationResult::Reject), - } - })); - - self.event_handle - .report_inbound_substream( - protocol, - fallback, - peer, - handshake.into(), - tx, - ) - .await; - }, - PeerState::Validating { - protocol, - fallback, - direction, - inbound: InboundState::SendingHandshake, - outbound, - } => { - tracing::trace!( - target: LOG_TARGET, - ?peer, - %protocol, - ?fallback, - "inbound substream negotiated, waiting for outbound substream to complete", - ); - - context.state = PeerState::Validating { - protocol: protocol.clone(), - fallback: fallback.clone(), - inbound: InboundState::Open { inbound: substream }, - outbound, - direction, - }; - }, - _state => debug_assert!(false), - } - }, - }, - // error occurred during negotiation, eitehr for inbound or outbound substream - // user is notified of the error only if they've either initiated an outbound substream - // or if they accepted an inbound substream and as a result initiated an outbound - // substream. - HandshakeEvent::NegotiationError { peer, direction } => { - tracing::debug!( - target: LOG_TARGET, - ?peer, - protocol = %self.protocol, - ?direction, - state = ?context.state, - "failed to negotiate substream", - ); - let _ = self.negotiation.remove_outbound(&peer); - let _ = self.negotiation.remove_inbound(&peer); - - // if an outbound substream had been initiated (whatever its state is), it means - // that the user knows about the connection and must be notified that it failed to - // negotiate. - match std::mem::replace(&mut context.state, PeerState::Poisoned) { - PeerState::Validating { outbound, .. } => { - context.state = PeerState::Closed { pending_open: outbound.pending_open() }; - - // notify user if the outbound substream is not considered closed - if !std::matches!(outbound, OutboundState::Closed) { - return self - .event_handle - .report_notification_stream_open_failure( - peer, - NotificationError::Rejected, - ) - .await; - } - }, - _state => debug_assert!(false), - } - }, - } - - // if both inbound and outbound substreams are considered open, notify the user that - // a notification stream has been opened and set up for sending and receiving - // notifications to and from remote node - match std::mem::replace(&mut context.state, PeerState::Poisoned) { - PeerState::Validating { - protocol, - fallback, - direction, - outbound: OutboundState::Open { handshake, outbound }, - inbound: InboundState::Open { inbound }, - } => { - tracing::debug!( - target: LOG_TARGET, - ?peer, - %protocol, - ?fallback, - "notification stream opened", - ); - - let (async_tx, async_rx) = channel(self.async_channel_size); - let (sync_tx, sync_rx) = channel(self.sync_channel_size); - let sink = NotificationSink::new(peer, sync_tx, async_tx); - - // start connection handler for the peer which only deals with sending/receiving - // notifications - // - // the connection handler must be started only after the newly opened notification - // substream is reported to user because the connection handler - // might exit immediately after being started if remote closed the connection. - // - // if the order of events (open & close) is not ensured to be correct, the code - // handling the connectivity logic on the `NotificationHandle` side - // might get confused about the current state of the connection. - let shutdown_tx = self.shutdown_tx.clone(); - let (connection, shutdown) = Connection::new( - peer, - inbound, - outbound, - self.event_handle.clone(), - shutdown_tx.clone(), - self.notif_tx.clone(), - async_rx, - sync_rx, - ); - - context.state = PeerState::Open { shutdown }; - self.event_handle - .report_notification_stream_opened( - protocol, - fallback, - direction, - peer, - handshake.into(), - sink, - ) - .await; - - self.executor.run(Box::pin(async move { - connection.start().await; - })); - }, - state => { - tracing::trace!( - target: LOG_TARGET, - ?peer, - protocol = %self.protocol, - ?state, - "validation for substream still pending", - ); - self.timers.push(Box::pin(async move { - futures_timer::Delay::new(Duration::from_secs(5)).await; - peer - })); - - context.state = state; - }, - } - } - - /// Handle dial failure. - async fn on_dial_failure(&mut self, peer: PeerId, address: Multiaddr) { - tracing::trace!( - target: LOG_TARGET, - ?peer, - protocol = %self.protocol, - ?address, - "handle dial failure", - ); - - let Some(context) = self.peers.remove(&peer) else { - tracing::trace!( - target: LOG_TARGET, - ?peer, - protocol = %self.protocol, - ?address, - "dial failure for an unknown peer", - ); - return; - }; - - match context.state { - PeerState::Dialing => { - tracing::debug!(target: LOG_TARGET, ?peer, protocol = %self.protocol, ?address, "failed to dial peer"); - self.event_handle - .report_notification_stream_open_failure(peer, NotificationError::DialFailure) - .await; - }, - state => { - tracing::trace!( - target: LOG_TARGET, - ?peer, - protocol = %self.protocol, - ?state, - "dial failure for peer that's not being dialed", - ); - self.peers.insert(peer, PeerContext { state }); - }, - } - } - - /// Handle next notification event. - async fn next_event(&mut self) { - // biased select is used because the substream events must be prioritized above other events - // that is becaused a closed substream is detected by either `substreams` or `negotiation` - // and if that event is not handled with priority but, e.g., inbound substream is - // handled before, it can create a situation where the state machine gets confused - // about the peer's state. - tokio::select! { - biased; - - event = self.negotiation.next(), if !self.negotiation.is_empty() => { - let (peer, event) = event.expect("`HandshakeService` to return `Some(..)`"); - self.on_handshake_event(peer, event).await; - } - event = self.shutdown_rx.recv() => match event { - None => return, - Some(peer) => { - if let Some(context) = self.peers.get_mut(&peer) { - tracing::trace!( - target: LOG_TARGET, - ?peer, - protocol = %self.protocol, - "notification stream to peer closed", - ); - context.state = PeerState::Closed { pending_open: None }; - } - } - }, - // TODO: this could be combined with `Negotiation` - peer = self.timers.next(), if !self.timers.is_empty() => match peer { - Some(peer) => { - match self.peers.get_mut(&peer) { - Some(context) => match std::mem::replace(&mut context.state, PeerState::Poisoned) { - PeerState::Validating { - outbound: OutboundState::Open { outbound, .. }, - inbound: InboundState::Closed, - .. - } => { - tracing::debug!( - target: LOG_TARGET, - ?peer, - protocol = %self.protocol, - "peer didn't answer in 10 seconds, canceling substream and closing connection", - ); - context.state = PeerState::Closed { pending_open: None }; - - let _ = outbound.close().await; - self.event_handle - .report_notification_stream_open_failure(peer, NotificationError::Rejected) - .await; - - if let Err(error) = self.service.force_close(peer) { - tracing::debug!( - target: LOG_TARGET, - ?peer, - protocol = %self.protocol, - ?error, - "failed to force close connection", - ); - } - } - state => { - tracing::trace!( - target: LOG_TARGET, - ?peer, - protocol = %self.protocol, - ?state, - "ignore expired timer for peer", - ); - context.state = state; - } - } - None => tracing::debug!( - target: LOG_TARGET, - ?peer, - protocol = %self.protocol, - "peer doesn't exist anymore", - ), - } - } - None => return, - }, - event = self.service.next() => match event { - Some(TransportEvent::ConnectionEstablished { peer, .. }) => { - if let Err(error) = self.on_connection_established(peer).await { - tracing::debug!( - target: LOG_TARGET, - ?peer, - ?error, - "failed to register peer", - ); - } - } - Some(TransportEvent::ConnectionClosed { peer }) => { - if let Err(error) = self.on_connection_closed(peer).await { - tracing::debug!( - target: LOG_TARGET, - ?peer, - ?error, - "failed to disconnect peer", - ); - } - } - Some(TransportEvent::SubstreamOpened { - peer, - substream, - direction, - protocol, - fallback, - }) => match direction { - protocol::Direction::Inbound => { - if let Err(error) = self.on_inbound_substream(protocol, fallback, peer, substream).await { - tracing::debug!( - target: LOG_TARGET, - ?peer, - ?error, - "failed to handle inbound substream", - ); - } - } - protocol::Direction::Outbound(substream_id) => { - if let Err(error) = self - .on_outbound_substream(protocol, fallback, peer, substream_id, substream) - .await - { - tracing::debug!( - target: LOG_TARGET, - ?peer, - ?error, - "failed to handle outbound substream", - ); - } - } - }, - Some(TransportEvent::SubstreamOpenFailure { substream, error }) => { - self.on_substream_open_failure(substream, error).await; - } - Some(TransportEvent::DialFailure { peer, address }) => self.on_dial_failure(peer, address).await, - None => return, - }, - result = self.pending_validations.select_next_some(), if !self.pending_validations.is_empty() => { - if let Err(error) = self.on_validation_result(result.0, result.1).await { - tracing::debug!( - target: LOG_TARGET, - peer = ?result.0, - result = ?result.1, - ?error, - "failed to handle validation result", - ); - } - } - command = self.command_rx.recv() => match command { - None => { - tracing::debug!(target: LOG_TARGET, "user protocol has exited, exiting"); - return - } - Some(command) => match command { - NotificationCommand::OpenSubstream { peers } => { - for peer in peers { - if let Err(error) = self.on_open_substream(peer).await { - tracing::debug!( - target: LOG_TARGET, - ?peer, - ?error, - "failed to open substream", - ); - } - } - } - NotificationCommand::CloseSubstream { peers } => { - for peer in peers { - self.on_close_substream(peer).await; - } - } - NotificationCommand::ForceClose { peer } => { - let _ = self.service.force_close(peer); - } - } - }, - } - } - - /// Start [`NotificationProtocol`] event loop. - pub(crate) async fn run(mut self) { - tracing::debug!(target: LOG_TARGET, "starting notification event loop"); - - loop { - self.next_event().await; - } - } + pub(crate) fn new( + service: TransportService, + config: Config, + executor: Arc, + ) -> Self { + let (shutdown_tx, shutdown_rx) = channel(DEFAULT_CHANNEL_SIZE); + + Self { + service, + shutdown_tx, + shutdown_rx, + executor, + peers: HashMap::new(), + protocol: config.protocol_name, + auto_accept: config.auto_accept, + pending_validations: FuturesUnordered::new(), + timers: FuturesUnordered::new(), + event_handle: NotificationEventHandle::new(config.event_tx), + notif_tx: config.notif_tx, + command_rx: config.command_rx, + pending_outbound: HashMap::new(), + negotiation: HandshakeService::new(config.handshake), + sync_channel_size: config.sync_channel_size, + async_channel_size: config.async_channel_size, + should_dial: config.should_dial, + } + } + + /// Connection established to remote node. + /// + /// If the peer already exists, the only valid state for it is `Dialing` as it indicates that + /// the user tried to open a substream to a peer who was not connected to local node. + /// + /// Any other state indicates that there's an error in the state transition logic. + async fn on_connection_established(&mut self, peer: PeerId) -> crate::Result<()> { + tracing::trace!(target: LOG_TARGET, ?peer, protocol = %self.protocol, "connection established"); + + let Some(context) = self.peers.get_mut(&peer) else { + self.peers.insert(peer, PeerContext::new()); + return Ok(()); + }; + + match std::mem::replace(&mut context.state, PeerState::Poisoned) { + PeerState::Dialing => { + tracing::trace!( + target: LOG_TARGET, + ?peer, + protocol = %self.protocol, + "dial succeeded, open substream to peer", + ); + + context.state = PeerState::Closed { pending_open: None }; + self.on_open_substream(peer).await + } + // connection established but validation is still pending + // + // update the connection state so that `NotificationProtocol` can proceed + // to correct state after the validation result has beern received + PeerState::ValidationPending { state } => { + debug_assert_eq!(state, ConnectionState::Closed); + + tracing::debug!( + target: LOG_TARGET, + ?peer, + protocol = %self.protocol, + "new connection established while validation still pending", + ); + + context.state = PeerState::ValidationPending { + state: ConnectionState::Open, + }; + + Ok(()) + } + state => { + tracing::error!( + target: LOG_TARGET, + ?peer, + protocol = %self.protocol, + ?state, + "state mismatch: peer already exists", + ); + debug_assert!(false); + Err(Error::PeerAlreadyExists(peer)) + } + } + } + + /// Connection closed to remote node. + /// + /// If the connection was considered open (both substreams were open), user is notified that + /// the notification stream was closed. + /// + /// If the connection was still in progress (either substream was not fully open), the user is + /// reported about it only if they had opened an outbound substream (outbound is either fully + /// open, it had been initiated or the substream was under negotiation). + async fn on_connection_closed(&mut self, peer: PeerId) -> crate::Result<()> { + tracing::trace!(target: LOG_TARGET, ?peer, protocol = %self.protocol, "connection closed"); + + let Some(context) = self.peers.remove(&peer) else { + tracing::error!( + target: LOG_TARGET, + ?peer, + protocol = %self.protocol, + "state mismatch: peer doesn't exist", + ); + debug_assert!(false); + return Err(Error::PeerDoesntExist(peer)); + }; + + // clean up all pending state for the peer + self.negotiation.remove_outbound(&peer); + self.negotiation.remove_inbound(&peer); + + match context.state { + // outbound initiated, report open failure to peer + PeerState::OutboundInitiated { .. } => { + self.event_handle + .report_notification_stream_open_failure(peer, NotificationError::Rejected) + .await; + } + // substream fully open, report that the notification stream is closed + PeerState::Open { shutdown } => { + let _ = shutdown.send(()); + } + // if the substream was being validated, user must be notified that the substream is + // now considered rejected if they had been made aware of the existence of the pending + // connection + PeerState::Validating { + outbound, inbound, .. + } => { + match (outbound, inbound) { + // substream was being validated by the protocol when the connection was closed + (OutboundState::Closed, InboundState::Validating { .. }) => { + tracing::debug!( + target: LOG_TARGET, + ?peer, + protocol = %self.protocol, + "connection closed while validation pending", + ); + + self.peers.insert( + peer, + PeerContext { + state: PeerState::ValidationPending { + state: ConnectionState::Closed, + }, + }, + ); + } + // user either initiated an outbound substream or an outbound substream was + // opened/being opened as a result of an accepted inbound substream but was not + // yet fully open + // + // to have consistent state tracking in the user protocol, substream rejection + // must be reported to the user + ( + OutboundState::OutboundInitiated { .. } + | OutboundState::Negotiating + | OutboundState::Open { .. }, + _, + ) => { + tracing::debug!( + target: LOG_TARGET, + ?peer, + protocol = %self.protocol, + "connection closed outbound substream under negotiation", + ); + + self.event_handle + .report_notification_stream_open_failure( + peer, + NotificationError::Rejected, + ) + .await; + } + (_, _) => {} + } + } + // pending validations must be tracked across connection open/close events + PeerState::ValidationPending { .. } => { + tracing::debug!( + target: LOG_TARGET, + ?peer, + protocol = %self.protocol, + "validation pending while connection closed", + ); + + self.peers.insert( + peer, + PeerContext { + state: PeerState::ValidationPending { + state: ConnectionState::Closed, + }, + }, + ); + } + _ => {} + } + + Ok(()) + } + + /// Local node opened a substream to remote node. + /// + /// The connection can be in three different states: + /// - this is the first substream that was opened and thus the connection was initiated by the + /// local node + /// - this is a response to a previously received inbound substream which the local node + /// accepted and as a result, opened its own substream + /// - local and remote nodes opened substreams at the same time + /// + /// In the first case, the local node's handshake is sent to remote node and the substream is + /// polled in the background until they either send their handshake or close the substream. + /// + /// For the second case, the connection was initiated by the remote node and the substream was + /// accepted by the local node which initiated an outbound substream to the remote node. + /// The only valid states for this case are [`InboundState::Open`], + /// and [`InboundState::SendingHandshake`] as they imply + /// that the inbound substream have been accepted by the local node and this opened outbound + /// substream is a result of a valid state transition. + /// + /// For the third case, if the nodes have opened substreams at the same time, the outbound state + /// must be [`OutboundState::OutboundInitiated`] to ascertain that the an outbound substream was + /// actually opened. Any other state would be a state mismatch and would mean that the + /// connection is opening substreams without the permission of the protocol handler. + async fn on_outbound_substream( + &mut self, + protocol: ProtocolName, + fallback: Option, + peer: PeerId, + substream_id: SubstreamId, + outbound: Substream, + ) -> crate::Result<()> { + tracing::debug!( + target: LOG_TARGET, + ?peer, + ?protocol, + ?substream_id, + "handle outbound substream", + ); + + // peer must exist since an outbound substream was received from them + let context = self.peers.get_mut(&peer).expect("peer to exist"); + let pending_peer = self.pending_outbound.remove(&substream_id); + + match std::mem::replace(&mut context.state, PeerState::Poisoned) { + // the connection was initiated by the local node, send handshake to remote and wait to + // receive their handshake back + PeerState::OutboundInitiated { substream } => { + debug_assert!(substream == substream_id); + debug_assert!(pending_peer == Some(peer)); + + tracing::trace!( + target: LOG_TARGET, + ?peer, + protocol = %self.protocol, + ?fallback, + ?substream_id, + "negotiate outbound protocol", + ); + + self.negotiation.negotiate_outbound(peer, outbound); + context.state = PeerState::Validating { + protocol, + fallback, + inbound: InboundState::Closed, + outbound: OutboundState::Negotiating, + direction: Direction::Outbound, + }; + } + PeerState::Validating { + protocol, + fallback, + inbound, + direction, + outbound: outbound_state, + } => { + // the inbound substream has been accepted by the local node since the handshake has + // been read and the local handshake has either already been sent or + // it's in the process of being sent. + match inbound { + InboundState::SendingHandshake | InboundState::Open { .. } => { + context.state = PeerState::Validating { + protocol, + fallback, + inbound, + direction, + outbound: OutboundState::Negotiating, + }; + self.negotiation.negotiate_outbound(peer, outbound); + } + // nodes have opened substreams at the same time + inbound_state => match outbound_state { + OutboundState::OutboundInitiated { substream } => { + debug_assert!(substream == substream_id); + + context.state = PeerState::Validating { + protocol, + fallback, + direction, + inbound: inbound_state, + outbound: OutboundState::Negotiating, + }; + self.negotiation.negotiate_outbound(peer, outbound); + } + // invalid state: more than one outbound substream has been opened + inner_state => { + tracing::error!( + target: LOG_TARGET, + ?peer, + %protocol, + ?substream_id, + ?inbound_state, + ?inner_state, + "invalid state, expected `OutboundInitiated`", + ); + + let _ = outbound.close().await; + debug_assert!(false); + } + }, + } + } + // the connection may have been closed while an outbound substream was pending + // if the outbound substream was initiated successfully, close it and reset + // `pending_open` + PeerState::Closed { pending_open } if pending_open == Some(substream_id) => { + let _ = outbound.close().await; + + context.state = PeerState::Closed { pending_open: None }; + } + state => { + tracing::error!( + target: LOG_TARGET, + ?peer, + %protocol, + ?substream_id, + ?state, + "invalid state: more than one outbound substream opened", + ); + + let _ = outbound.close().await; + debug_assert!(false); + } + } + + Ok(()) + } + + /// Remote opened a substream to local node. + /// + /// The peer can be in four different states for the inbound substream to be considered valid: + /// - the connection is closed + /// - conneection is open but substream validation from a previous connection is still pending + /// - outbound substream has been opened but not yet acknowledged by the remote peer + /// - outbound substream has been opened and acknowledged by the remote peer and it's being + /// negotiated + /// + /// If remote opened more than one substream, the new substream is simply discarded. + async fn on_inbound_substream( + &mut self, + protocol: ProtocolName, + fallback: Option, + peer: PeerId, + substream: Substream, + ) -> crate::Result<()> { + // peer must exist since an inbound substream was received from them + let context = self.peers.get_mut(&peer).expect("peer to exist"); + + tracing::debug!( + target: LOG_TARGET, + ?peer, + %protocol, + ?fallback, + state = ?context.state, + "handle inbound substream", + ); + + match std::mem::replace(&mut context.state, PeerState::Poisoned) { + // inbound substream of a previous connection is still pending validation, + // reject any new inbound substreams until an answer is heard from the user + state @ PeerState::ValidationPending { .. } => { + tracing::debug!( + target: LOG_TARGET, + ?peer, + %protocol, + ?fallback, + ?state, + "validation for previous substream still pending", + ); + + let _ = substream.close().await; + context.state = state; + } + // outbound substream for previous connection still pending, reject inbound substream + // and wait for the outbound substream state to conclude as either succeeded or failed + // before accepting any inbound substreams. + PeerState::Closed { + pending_open: Some(substream_id), + } => { + tracing::debug!( + target: LOG_TARGET, + ?peer, + protocol = %self.protocol, + "received inbound substream while outbound substream opening, rejecting", + ); + let _ = substream.close().await; + + context.state = PeerState::Closed { + pending_open: Some(substream_id), + }; + } + // the peer state is closed so this is a fresh inbound substream. + PeerState::Closed { pending_open: None } => { + self.negotiation.read_handshake(peer, substream); + + context.state = PeerState::Validating { + protocol, + fallback, + direction: Direction::Inbound, + inbound: InboundState::ReadingHandshake, + outbound: OutboundState::Closed, + }; + } + // if the connection is under validation (so an outbound substream has been opened and + // it's still pending or under negotiation), the only valid state for the + // inbound state is closed as it indicates that there isn't an inbound substream yet for + // the remote node duplicate substreams are prohibited. + PeerState::Validating { + protocol, + fallback, + outbound, + direction, + inbound: InboundState::Closed, + } => { + self.negotiation.read_handshake(peer, substream); + + context.state = PeerState::Validating { + protocol, + fallback, + outbound, + direction, + inbound: InboundState::ReadingHandshake, + }; + } + // outbound substream may have been initiated by the local node while a remote node also + // opened a substream roughly at the same time + PeerState::OutboundInitiated { + substream: outbound, + } => { + self.negotiation.read_handshake(peer, substream); + + context.state = PeerState::Validating { + protocol, + fallback, + direction: Direction::Outbound, + outbound: OutboundState::OutboundInitiated { + substream: outbound, + }, + inbound: InboundState::ReadingHandshake, + }; + } + // new inbound substream opend while validation for the previous substream was still + // pending + // + // the old substream can be considered dead because remote wouldn't open a new substream + // to us unless they had discarded the previous substream. + // + // set peer state to `ValidationPending` to indicate that the peer is "blocked" until a + // validation for the substream is heard, blocking any further activity for + // the connection and once the validation is received and in case the + // substream is accepted, it will be reported as open failure to to the peer + // because the states have gone out of sync. + PeerState::Validating { + outbound: OutboundState::Closed, + inbound: + InboundState::Validating { + inbound: pending_substream, + }, + .. + } => { + tracing::debug!( + target: LOG_TARGET, + ?peer, + protocol = %self.protocol, + "remote opened substream while previous was still pending, connection failed", + ); + let _ = substream.close().await; + let _ = pending_substream.close().await; + + context.state = PeerState::ValidationPending { + state: ConnectionState::Open, + }; + } + // remote opened another inbound substream, close it and otherwise ignore the event + // as this is a non-serious protocol violation. + state => { + tracing::debug!( + target: LOG_TARGET, + ?peer, + %protocol, + ?fallback, + ?state, + "remote opened more than one inbound substreams, discarding", + ); + + let _ = substream.close().await; + context.state = state; + } + } + + Ok(()) + } + + /// Failed to open substream to remote node. + /// + /// If the substream was initiated by the local node, it must be reported that the substream + /// failed to open. Otherwise the peer state can silently be converted to `Closed`. + async fn on_substream_open_failure(&mut self, substream_id: SubstreamId, error: Error) { + tracing::debug!( + target: LOG_TARGET, + protocol = %self.protocol, + ?substream_id, + ?error, + "failed to open substream" + ); + + let Some(peer) = self.pending_outbound.remove(&substream_id) else { + tracing::warn!( + target: LOG_TARGET, + protocol = %self.protocol, + ?substream_id, + "pending outbound substream doesn't exist", + ); + debug_assert!(false); + return; + }; + + // peer must exist since an outbound substream failure was received from them + let Some(context) = self.peers.get_mut(&peer) else { + tracing::warn!(target: LOG_TARGET, ?peer, "peer doesn't exist"); + debug_assert!(false); + return; + }; + + match &mut context.state { + PeerState::OutboundInitiated { .. } => { + context.state = PeerState::Closed { pending_open: None }; + + self.event_handle + .report_notification_stream_open_failure(peer, NotificationError::Rejected) + .await; + } + // if the substream was accepted by the local node and as a result, an outbound + // substream was accepted as a result this should not be reported to local node + PeerState::Validating { outbound, .. } => { + self.negotiation.remove_inbound(&peer); + self.negotiation.remove_outbound(&peer); + + let pending_open = match outbound { + OutboundState::Closed => None, + OutboundState::OutboundInitiated { substream } => { + self.event_handle + .report_notification_stream_open_failure( + peer, + NotificationError::Rejected, + ) + .await; + + Some(*substream) + } + OutboundState::Negotiating | OutboundState::Open { .. } => { + self.event_handle + .report_notification_stream_open_failure( + peer, + NotificationError::Rejected, + ) + .await; + + None + } + }; + + context.state = PeerState::Closed { pending_open }; + } + PeerState::Closed { pending_open } => { + tracing::debug!( + target: LOG_TARGET, + protocol = %self.protocol, + ?substream_id, + "substream open failure for a closed connection", + ); + debug_assert_eq!(pending_open, &Some(substream_id)); + context.state = PeerState::Closed { pending_open: None }; + } + state => { + tracing::warn!( + target: LOG_TARGET, + protocol = %self.protocol, + ?substream_id, + ?state, + "invalid state for outbound substream open failure", + ); + context.state = PeerState::Closed { pending_open: None }; + debug_assert!(false); + } + } + } + + /// Open substream to remote `peer`. + /// + /// Outbound substream can opened only if the `PeerState` is `Closed`. + /// By forcing the substream to be opened only if the state is currently closed, + /// `NotificationProtocol` can enfore more predictable state transitions. + /// + /// Other states either imply an invalid state transition ([`PeerState::Open`]) or that an + /// inbound substream has already been received and its currently being validated by the user. + async fn on_open_substream(&mut self, peer: PeerId) -> crate::Result<()> { + tracing::trace!(target: LOG_TARGET, ?peer, protocol = %self.protocol, "open substream"); + + let Some(context) = self.peers.get_mut(&peer) else { + if !self.should_dial { + tracing::debug!( + target: LOG_TARGET, + ?peer, + protocol = %self.protocol, + "connection to peer not open and dialing disabled", + ); + + self.event_handle + .report_notification_stream_open_failure(peer, NotificationError::DialFailure) + .await; + return Ok(()); + } + + match self.service.dial(&peer) { + Err(error) => { + tracing::debug!( + target: LOG_TARGET, + ?peer, + protocol = %self.protocol, + ?error, + "failed to dial peer", + ); + + self.event_handle + .report_notification_stream_open_failure( + peer, + NotificationError::DialFailure, + ) + .await; + + return Err(error); + } + Ok(()) => { + tracing::trace!( + target: LOG_TARGET, + ?peer, + protocol = %self.protocol, + "started to dial peer", + ); + + self.peers.insert( + peer, + PeerContext { + state: PeerState::Dialing, + }, + ); + return Ok(()); + } + } + }; + + match context.state { + // protocol can only request a new outbound substream to be opened if the state is + // `Closed` other states imply that it's already open + PeerState::Closed { + pending_open: Some(substream_id), + } => { + tracing::trace!( + target: LOG_TARGET, + ?peer, + protocol = %self.protocol, + ?substream_id, + "outbound substream opening, reusing pending open substream", + ); + + self.pending_outbound.insert(substream_id, peer); + context.state = PeerState::OutboundInitiated { + substream: substream_id, + }; + } + PeerState::Closed { .. } => match self.service.open_substream(peer) { + Ok(substream_id) => { + tracing::trace!( + target: LOG_TARGET, + ?peer, + protocol = %self.protocol, + ?substream_id, + "outbound substream opening", + ); + + self.pending_outbound.insert(substream_id, peer); + context.state = PeerState::OutboundInitiated { + substream: substream_id, + }; + } + Err(error) => { + tracing::debug!( + target: LOG_TARGET, + ?peer, + protocol = %self.protocol, + ?error, + "failed to open substream", + ); + + self.event_handle + .report_notification_stream_open_failure( + peer, + NotificationError::NoConnection, + ) + .await; + context.state = PeerState::Closed { pending_open: None }; + } + }, + // while a validation is pending for an inbound substream, user is not allowed to open + // any outbound substreams until the old inbond substream is either accepted or rejected + PeerState::ValidationPending { .. } => { + tracing::trace!( + target: LOG_TARGET, + ?peer, + protocol = %self.protocol, + "validation still pending, rejecting outbound substream request", + ); + + self.event_handle + .report_notification_stream_open_failure( + peer, + NotificationError::ValidationPending, + ) + .await; + } + _ => {} + } + + Ok(()) + } + + /// Close substream to remote `peer`. + /// + /// This function can only be called if the substream was actually open, any other state is + /// unreachable as the user is unable to emit this command to [`NotificationProtocol`] unless + /// the connection has been fully opened. + async fn on_close_substream(&mut self, peer: PeerId) { + tracing::debug!(target: LOG_TARGET, ?peer, protocol = %self.protocol, "close substream"); + + let Some(context) = self.peers.get_mut(&peer) else { + tracing::debug!(target: LOG_TARGET, ?peer, "peer doesn't exist"); + return; + }; + + match std::mem::replace(&mut context.state, PeerState::Poisoned) { + PeerState::Open { shutdown } => { + let _ = shutdown.send(()); + + context.state = PeerState::Closed { pending_open: None }; + } + state => { + tracing::debug!( + target: LOG_TARGET, + ?peer, + protocol = %self.protocol, + ?state, + "substream already closed", + ); + context.state = state; + } + } + } + + /// Handle validation result. + /// + /// The validation result binary (accept/reject). If the node is rejected, the substreams are + /// discarded and state is set to `PeerState::Closed`. If there was an outbound substream in + /// progress while the connection was rejected by the user, the oubound state is discarded, + /// except for the substream ID of the substream which is kept for later use, in case the + /// substream happens to open. + /// + /// If the node is accepted and there is no outbound substream to them open yet, a new substream + /// is opened and once it opens, the local handshake will be sent to the remote peer and if + /// they also accept the substream the connection is considered fully open. + async fn on_validation_result( + &mut self, + peer: PeerId, + result: ValidationResult, + ) -> crate::Result<()> { + tracing::trace!( + target: LOG_TARGET, + ?peer, + protocol = %self.protocol, + ?result, + "handle validation result", + ); + + let Some(context) = self.peers.get_mut(&peer) else { + tracing::debug!(target: LOG_TARGET, ?peer, "peer doesn't exist"); + return Err(Error::PeerDoesntExist(peer)); + }; + + match std::mem::replace(&mut context.state, PeerState::Poisoned) { + PeerState::Validating { + protocol, + fallback, + outbound, + direction, + inbound: InboundState::Validating { inbound }, + } => match result { + // substream was rejected by the local node, if an outbound substream was under + // negotation, discard that data and if an outbound substream was + // initiated, save the `SubstreamId` of that substream and later if the substream + // is opened, the state can be corrected to `pending_open: None`. + ValidationResult::Reject => { + let _ = inbound.close().await; + self.negotiation.remove_outbound(&peer); + self.negotiation.remove_inbound(&peer); + context.state = PeerState::Closed { + pending_open: outbound.pending_open(), + }; + + Ok(()) + } + ValidationResult::Accept => match outbound { + // no outbound substream exists so initiate a new substream open and send the + // local handshake to remote node, indicating that the + // connection was accepted by the local node + OutboundState::Closed => match self.service.open_substream(peer) { + Ok(substream) => { + self.negotiation.send_handshake(peer, inbound); + self.pending_outbound.insert(substream, peer); + + context.state = PeerState::Validating { + protocol, + fallback, + direction, + inbound: InboundState::SendingHandshake, + outbound: OutboundState::OutboundInitiated { substream }, + }; + Ok(()) + } + // failed to open outbound substream after accepting an inbound substream + // + // since the user was notified of this substream and they accepted it, + // they expecting some kind of answer (open success/failure). + // + // report to user that the substream failed to open so they can track the + // state transitions of the peer correctly + Err(error) => { + tracing::trace!( + target: LOG_TARGET, + ?peer, + protocol = %self.protocol, + ?result, + ?error, + "failed to open outbound substream for accepted substream", + ); + + let _ = inbound.close().await; + context.state = PeerState::Closed { pending_open: None }; + + self.event_handle + .report_notification_stream_open_failure( + peer, + NotificationError::Rejected, + ) + .await; + + Err(error) + } + }, + // here the state is one of `OutboundState::{OutboundInitiated, Negotiating, + // Open}` so that state can be safely ignored and all that + // has to be done is to send the local handshake to remote + // node to indicate that the connection was accepted. + _ => { + self.negotiation.send_handshake(peer, inbound); + + context.state = PeerState::Validating { + protocol, + fallback, + direction, + inbound: InboundState::SendingHandshake, + outbound, + }; + Ok(()) + } + }, + }, + // validation result received for an inbound substream which is now considered dead + // because while the substream was being validated, the connection had closed. + // + // if the substream was rejected and there is no active connection to the peer, + // just remove the peer from `peers` without informing user + // + // if the substream was accepted, the user must be informed that the substream failed to + // open. Depending on whether there is currently a connection open to the peer, either + // report `Rejected`/`NoConnection` and let the user try again. + PeerState::ValidationPending { state } => { + if let Some(error) = match state { + ConnectionState::Open => { + context.state = PeerState::Closed { pending_open: None }; + + std::matches!(result, ValidationResult::Accept) + .then_some(NotificationError::Rejected) + } + ConnectionState::Closed => { + self.peers.remove(&peer); + + std::matches!(result, ValidationResult::Accept) + .then_some(NotificationError::NoConnection) + } + } { + self.event_handle.report_notification_stream_open_failure(peer, error).await; + } + + Ok(()) + } + // if the user incorrectly send a validation result for a peer that doesn't require + // validation, set state back to what it was and ignore the event + // + // the user protocol may send a stale validation result not because of a programming + // error but because it has a backlock of unhandled events, with one event potentially + // nullifying the need for substream validation, and is just temporarily out of sync + // with `NotificationProtocol` + state => { + tracing::debug!( + target: LOG_TARGET, + ?peer, + protocol = %self.protocol, + ?state, + "validation result received for peer that doesn't require validation", + ); + + context.state = state; + Ok(()) + } + } + } + + /// Handle handshake event. + /// + /// There are three different handshake event types: + /// - outbound substream negotiated + /// - inbound substream negotiated + /// - substream negotiation error + /// + /// Neither outbound nor inbound substream negotiated automatically means that the connection is + /// considered open as both substreams must be fully negotiated for that to be the case. That is + /// why the peer state for inbound and outbound are set separately and at the end of the + /// function is the collective state of the substreams checked and if both substreams are + /// negotiated, the user informed that the connection is open. + /// + /// If the negotiation fails, the user may have to be informed of that. Outbound substream + /// failure always results in user getting notified since the existence of an outbound substream + /// means that the user has either initiated an outbound substreams or has accepted an inbound + /// substreams, resulting in an outbound substreams. + /// + /// Negotiation failure for inbound substreams which are in the state + /// [`InboundState::ReadingHandshake`] don't result in any notification because while the + /// handshake is being read from the substream, the user is oblivious to the fact that an + /// inbound substream has even been received. + async fn on_handshake_event(&mut self, peer: PeerId, event: HandshakeEvent) { + let Some(context) = self.peers.get_mut(&peer) else { + tracing::error!(target: LOG_TARGET, "invalid state: negotiation event received but peer doesn't exist"); + debug_assert!(false); + return; + }; + + tracing::trace!( + target: LOG_TARGET, + ?peer, + protocol = %self.protocol, + ?event, + "handle handshake event", + ); + + match event { + // either an inbound or outbound substream has been negotiated successfully + HandshakeEvent::Negotiated { + peer, + handshake, + substream, + direction, + } => match direction { + // outbound substream was negotiated, the only valid state for peer is `Validating` + // and only valid state for `OutboundState` is `Negotiating` + negotiation::Direction::Outbound => { + self.negotiation.remove_outbound(&peer); + + match std::mem::replace(&mut context.state, PeerState::Poisoned) { + PeerState::Validating { + protocol, + fallback, + direction, + outbound: OutboundState::Negotiating, + inbound, + } => { + context.state = PeerState::Validating { + protocol, + fallback, + direction, + outbound: OutboundState::Open { + handshake, + outbound: substream, + }, + inbound, + }; + } + state => { + tracing::warn!( + target: LOG_TARGET, + ?peer, + ?state, + "outbound substream negotiated but peer has invalid state", + ); + debug_assert!(false); + } + } + } + // inbound negotiation event completed + // + // the negotiation event can be on of two different types: + // - remote handshake was read from the substream + // - local handshake has been sent to remote node + // + // For the first case, the substream has to be validated by the local node. + // This means reporting the protocol name, potential negotiated fallback and the + // handshake. Local node will then either accept or reject the substream which is + // handled by [`NotificationProtocol::on_validation_result()`]. Compared to + // Substrate, litep2p requires both peers to validate the inbound handshake to allow + // more complex connection validation. If this is not necessary and the protocol + // wishes to auto-accept the inbound substreams that are a result of + // an outbound substream already accepted by the remote node, the + // substream validation is skipped and the local handshake is sent + // right away. + // + // For the second case, the local handshake was sent to remote node successfully and + // the inbound substream is considered open and if the outbound + // substream is open as well, the connection is fully open. + // + // Only valid states for [`InboundState`] are [`InboundState::ReadingHandshake`] and + // [`InboundState::SendingHandshake`] because otherwise the inbound + // substream cannot be in [`HandshakeService`](super::negotiation::HandshakeService) + // unless there is a logic bug in the state machine. + negotiation::Direction::Inbound => { + self.negotiation.remove_inbound(&peer); + + match std::mem::replace(&mut context.state, PeerState::Poisoned) { + PeerState::Validating { + protocol, + fallback, + direction, + outbound, + inbound: InboundState::ReadingHandshake, + } => { + if !std::matches!(outbound, OutboundState::Closed) && self.auto_accept { + tracing::trace!( + target: LOG_TARGET, + ?peer, + %protocol, + ?fallback, + ?direction, + ?outbound, + "auto-accept inbound substream", + ); + + self.negotiation.send_handshake(peer, substream); + context.state = PeerState::Validating { + protocol, + fallback, + direction, + inbound: InboundState::SendingHandshake, + outbound, + }; + + return; + } + + tracing::trace!( + target: LOG_TARGET, + ?peer, + %protocol, + ?fallback, + ?outbound, + "send inbound protocol for validation", + ); + + context.state = PeerState::Validating { + protocol: protocol.clone(), + fallback: fallback.clone(), + inbound: InboundState::Validating { inbound: substream }, + outbound, + direction, + }; + + let (tx, rx) = oneshot::channel(); + self.pending_validations.push(Box::pin(async move { + match rx.await { + Ok(ValidationResult::Accept) => + (peer, ValidationResult::Accept), + _ => (peer, ValidationResult::Reject), + } + })); + + self.event_handle + .report_inbound_substream( + protocol, + fallback, + peer, + handshake.into(), + tx, + ) + .await; + } + PeerState::Validating { + protocol, + fallback, + direction, + inbound: InboundState::SendingHandshake, + outbound, + } => { + tracing::trace!( + target: LOG_TARGET, + ?peer, + %protocol, + ?fallback, + "inbound substream negotiated, waiting for outbound substream to complete", + ); + + context.state = PeerState::Validating { + protocol: protocol.clone(), + fallback: fallback.clone(), + inbound: InboundState::Open { inbound: substream }, + outbound, + direction, + }; + } + _state => debug_assert!(false), + } + } + }, + // error occurred during negotiation, eitehr for inbound or outbound substream + // user is notified of the error only if they've either initiated an outbound substream + // or if they accepted an inbound substream and as a result initiated an outbound + // substream. + HandshakeEvent::NegotiationError { peer, direction } => { + tracing::debug!( + target: LOG_TARGET, + ?peer, + protocol = %self.protocol, + ?direction, + state = ?context.state, + "failed to negotiate substream", + ); + let _ = self.negotiation.remove_outbound(&peer); + let _ = self.negotiation.remove_inbound(&peer); + + // if an outbound substream had been initiated (whatever its state is), it means + // that the user knows about the connection and must be notified that it failed to + // negotiate. + match std::mem::replace(&mut context.state, PeerState::Poisoned) { + PeerState::Validating { outbound, .. } => { + context.state = PeerState::Closed { + pending_open: outbound.pending_open(), + }; + + // notify user if the outbound substream is not considered closed + if !std::matches!(outbound, OutboundState::Closed) { + return self + .event_handle + .report_notification_stream_open_failure( + peer, + NotificationError::Rejected, + ) + .await; + } + } + _state => debug_assert!(false), + } + } + } + + // if both inbound and outbound substreams are considered open, notify the user that + // a notification stream has been opened and set up for sending and receiving + // notifications to and from remote node + match std::mem::replace(&mut context.state, PeerState::Poisoned) { + PeerState::Validating { + protocol, + fallback, + direction, + outbound: + OutboundState::Open { + handshake, + outbound, + }, + inbound: InboundState::Open { inbound }, + } => { + tracing::debug!( + target: LOG_TARGET, + ?peer, + %protocol, + ?fallback, + "notification stream opened", + ); + + let (async_tx, async_rx) = channel(self.async_channel_size); + let (sync_tx, sync_rx) = channel(self.sync_channel_size); + let sink = NotificationSink::new(peer, sync_tx, async_tx); + + // start connection handler for the peer which only deals with sending/receiving + // notifications + // + // the connection handler must be started only after the newly opened notification + // substream is reported to user because the connection handler + // might exit immediately after being started if remote closed the connection. + // + // if the order of events (open & close) is not ensured to be correct, the code + // handling the connectivity logic on the `NotificationHandle` side + // might get confused about the current state of the connection. + let shutdown_tx = self.shutdown_tx.clone(); + let (connection, shutdown) = Connection::new( + peer, + inbound, + outbound, + self.event_handle.clone(), + shutdown_tx.clone(), + self.notif_tx.clone(), + async_rx, + sync_rx, + ); + + context.state = PeerState::Open { shutdown }; + self.event_handle + .report_notification_stream_opened( + protocol, + fallback, + direction, + peer, + handshake.into(), + sink, + ) + .await; + + self.executor.run(Box::pin(async move { + connection.start().await; + })); + } + state => { + tracing::trace!( + target: LOG_TARGET, + ?peer, + protocol = %self.protocol, + ?state, + "validation for substream still pending", + ); + self.timers.push(Box::pin(async move { + futures_timer::Delay::new(Duration::from_secs(5)).await; + peer + })); + + context.state = state; + } + } + } + + /// Handle dial failure. + async fn on_dial_failure(&mut self, peer: PeerId, address: Multiaddr) { + tracing::trace!( + target: LOG_TARGET, + ?peer, + protocol = %self.protocol, + ?address, + "handle dial failure", + ); + + let Some(context) = self.peers.remove(&peer) else { + tracing::trace!( + target: LOG_TARGET, + ?peer, + protocol = %self.protocol, + ?address, + "dial failure for an unknown peer", + ); + return; + }; + + match context.state { + PeerState::Dialing => { + tracing::debug!(target: LOG_TARGET, ?peer, protocol = %self.protocol, ?address, "failed to dial peer"); + self.event_handle + .report_notification_stream_open_failure(peer, NotificationError::DialFailure) + .await; + } + state => { + tracing::trace!( + target: LOG_TARGET, + ?peer, + protocol = %self.protocol, + ?state, + "dial failure for peer that's not being dialed", + ); + self.peers.insert(peer, PeerContext { state }); + } + } + } + + /// Handle next notification event. + async fn next_event(&mut self) { + // biased select is used because the substream events must be prioritized above other events + // that is becaused a closed substream is detected by either `substreams` or `negotiation` + // and if that event is not handled with priority but, e.g., inbound substream is + // handled before, it can create a situation where the state machine gets confused + // about the peer's state. + tokio::select! { + biased; + + event = self.negotiation.next(), if !self.negotiation.is_empty() => { + let (peer, event) = event.expect("`HandshakeService` to return `Some(..)`"); + self.on_handshake_event(peer, event).await; + } + event = self.shutdown_rx.recv() => match event { + None => return, + Some(peer) => { + if let Some(context) = self.peers.get_mut(&peer) { + tracing::trace!( + target: LOG_TARGET, + ?peer, + protocol = %self.protocol, + "notification stream to peer closed", + ); + context.state = PeerState::Closed { pending_open: None }; + } + } + }, + // TODO: this could be combined with `Negotiation` + peer = self.timers.next(), if !self.timers.is_empty() => match peer { + Some(peer) => { + match self.peers.get_mut(&peer) { + Some(context) => match std::mem::replace(&mut context.state, PeerState::Poisoned) { + PeerState::Validating { + outbound: OutboundState::Open { outbound, .. }, + inbound: InboundState::Closed, + .. + } => { + tracing::debug!( + target: LOG_TARGET, + ?peer, + protocol = %self.protocol, + "peer didn't answer in 10 seconds, canceling substream and closing connection", + ); + context.state = PeerState::Closed { pending_open: None }; + + let _ = outbound.close().await; + self.event_handle + .report_notification_stream_open_failure(peer, NotificationError::Rejected) + .await; + + if let Err(error) = self.service.force_close(peer) { + tracing::debug!( + target: LOG_TARGET, + ?peer, + protocol = %self.protocol, + ?error, + "failed to force close connection", + ); + } + } + state => { + tracing::trace!( + target: LOG_TARGET, + ?peer, + protocol = %self.protocol, + ?state, + "ignore expired timer for peer", + ); + context.state = state; + } + } + None => tracing::debug!( + target: LOG_TARGET, + ?peer, + protocol = %self.protocol, + "peer doesn't exist anymore", + ), + } + } + None => return, + }, + event = self.service.next() => match event { + Some(TransportEvent::ConnectionEstablished { peer, .. }) => { + if let Err(error) = self.on_connection_established(peer).await { + tracing::debug!( + target: LOG_TARGET, + ?peer, + ?error, + "failed to register peer", + ); + } + } + Some(TransportEvent::ConnectionClosed { peer }) => { + if let Err(error) = self.on_connection_closed(peer).await { + tracing::debug!( + target: LOG_TARGET, + ?peer, + ?error, + "failed to disconnect peer", + ); + } + } + Some(TransportEvent::SubstreamOpened { + peer, + substream, + direction, + protocol, + fallback, + }) => match direction { + protocol::Direction::Inbound => { + if let Err(error) = self.on_inbound_substream(protocol, fallback, peer, substream).await { + tracing::debug!( + target: LOG_TARGET, + ?peer, + ?error, + "failed to handle inbound substream", + ); + } + } + protocol::Direction::Outbound(substream_id) => { + if let Err(error) = self + .on_outbound_substream(protocol, fallback, peer, substream_id, substream) + .await + { + tracing::debug!( + target: LOG_TARGET, + ?peer, + ?error, + "failed to handle outbound substream", + ); + } + } + }, + Some(TransportEvent::SubstreamOpenFailure { substream, error }) => { + self.on_substream_open_failure(substream, error).await; + } + Some(TransportEvent::DialFailure { peer, address }) => self.on_dial_failure(peer, address).await, + None => return, + }, + result = self.pending_validations.select_next_some(), if !self.pending_validations.is_empty() => { + if let Err(error) = self.on_validation_result(result.0, result.1).await { + tracing::debug!( + target: LOG_TARGET, + peer = ?result.0, + result = ?result.1, + ?error, + "failed to handle validation result", + ); + } + } + command = self.command_rx.recv() => match command { + None => { + tracing::debug!(target: LOG_TARGET, "user protocol has exited, exiting"); + return + } + Some(command) => match command { + NotificationCommand::OpenSubstream { peers } => { + for peer in peers { + if let Err(error) = self.on_open_substream(peer).await { + tracing::debug!( + target: LOG_TARGET, + ?peer, + ?error, + "failed to open substream", + ); + } + } + } + NotificationCommand::CloseSubstream { peers } => { + for peer in peers { + self.on_close_substream(peer).await; + } + } + NotificationCommand::ForceClose { peer } => { + let _ = self.service.force_close(peer); + } + } + }, + } + } + + /// Start [`NotificationProtocol`] event loop. + pub(crate) async fn run(mut self) { + tracing::debug!(target: LOG_TARGET, "starting notification event loop"); + + loop { + self.next_event().await; + } + } } diff --git a/src/protocol/notification/negotiation.rs b/src/protocol/notification/negotiation.rs index 0026467f..e76fe40a 100644 --- a/src/protocol/notification/negotiation.rs +++ b/src/protocol/notification/negotiation.rs @@ -27,11 +27,11 @@ use futures_timer::Delay; use parking_lot::RwLock; use std::{ - collections::{HashMap, VecDeque}, - pin::Pin, - sync::Arc, - task::{Context, Poll}, - time::Duration, + collections::{HashMap, VecDeque}, + pin::Pin, + sync::Arc, + task::{Context, Poll}, + time::Duration, }; /// Logging target for the file. @@ -43,375 +43,410 @@ const NEGOTIATION_TIMEOUT: Duration = Duration::from_secs(10); /// Substream direction. #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] pub enum Direction { - /// Outbound substream, opened by local node. - Outbound, + /// Outbound substream, opened by local node. + Outbound, - /// Inbound substream, opened by remote node. - Inbound, + /// Inbound substream, opened by remote node. + Inbound, } /// Events emitted by [`HandshakeService`]. #[derive(Debug)] pub enum HandshakeEvent { - /// Substream has been negotiated. - Negotiated { - /// Peer ID. - peer: PeerId, + /// Substream has been negotiated. + Negotiated { + /// Peer ID. + peer: PeerId, - /// Handshake. - handshake: Vec, + /// Handshake. + handshake: Vec, - /// Substream. - substream: Substream, + /// Substream. + substream: Substream, - /// Direction. - direction: Direction, - }, + /// Direction. + direction: Direction, + }, - /// Outbound substream has been negotiated. - NegotiationError { - /// Peer ID. - peer: PeerId, + /// Outbound substream has been negotiated. + NegotiationError { + /// Peer ID. + peer: PeerId, - /// Direction. - direction: Direction, - }, + /// Direction. + direction: Direction, + }, } /// Outbound substream's handshake state enum HandshakeState { - /// Send handshake to remote peer. - SendHandshake, + /// Send handshake to remote peer. + SendHandshake, - /// Sink is ready for the handshake to be sent. - SinkReady, + /// Sink is ready for the handshake to be sent. + SinkReady, - /// Handshake has been sent. - HandshakeSent, + /// Handshake has been sent. + HandshakeSent, - /// Read handshake from remote peer. - ReadHandshake, + /// Read handshake from remote peer. + ReadHandshake, } /// Handshake service. pub(crate) struct HandshakeService { - /// Handshake. - handshake: Arc>>, + /// Handshake. + handshake: Arc>>, - /// Pending outbound substreams. - /// Substreams: - substreams: HashMap<(PeerId, Direction), (Substream, Delay, HandshakeState)>, + /// Pending outbound substreams. + /// Substreams: + substreams: HashMap<(PeerId, Direction), (Substream, Delay, HandshakeState)>, - /// Ready substreams. - ready: VecDeque<(PeerId, Direction, Vec)>, + /// Ready substreams. + ready: VecDeque<(PeerId, Direction, Vec)>, } impl HandshakeService { - /// Create new [`HandshakeService`]. - pub fn new(handshake: Arc>>) -> Self { - Self { handshake, ready: VecDeque::new(), substreams: HashMap::new() } - } - - /// Remove outbound substream from [`HandshakeService`]. - pub fn remove_outbound(&mut self, peer: &PeerId) -> Option { - self.substreams - .remove(&(*peer, Direction::Outbound)) - .map(|(substream, _, _)| substream) - } - - /// Remove inbound substream from [`HandshakeService`]. - pub fn remove_inbound(&mut self, peer: &PeerId) -> Option { - self.substreams - .remove(&(*peer, Direction::Inbound)) - .map(|(substream, _, _)| substream) - } - - /// Negotiate outbound handshake. - pub fn negotiate_outbound(&mut self, peer: PeerId, substream: Substream) { - tracing::trace!(target: LOG_TARGET, ?peer, "negotiate outbound"); - - self.substreams.insert( - (peer, Direction::Outbound), - (substream, Delay::new(NEGOTIATION_TIMEOUT), HandshakeState::SendHandshake), - ); - } - - /// Read handshake from remote peer. - pub fn read_handshake(&mut self, peer: PeerId, substream: Substream) { - tracing::trace!(target: LOG_TARGET, ?peer, "read handshake"); - - self.substreams.insert( - (peer, Direction::Inbound), - (substream, Delay::new(NEGOTIATION_TIMEOUT), HandshakeState::ReadHandshake), - ); - } - - /// Write handshake to remote peer. - pub fn send_handshake(&mut self, peer: PeerId, substream: Substream) { - tracing::trace!(target: LOG_TARGET, ?peer, "send handshake"); - - self.substreams.insert( - (peer, Direction::Inbound), - (substream, Delay::new(NEGOTIATION_TIMEOUT), HandshakeState::SendHandshake), - ); - } - - /// Returns `true` if [`HandshakeService`] contains no elements. - pub fn is_empty(&self) -> bool { - self.substreams.is_empty() - } - - /// Pop event from the event queue. - /// - /// The substream may not exist in the queue anymore as it may have been removed - /// by `NotificationProtocol` if either one of the substreams failed to negotiate. - fn pop_event(&mut self) -> Option<(PeerId, HandshakeEvent)> { - while let Some((peer, direction, handshake)) = self.ready.pop_front() { - if let Some((substream, _, _)) = self.substreams.remove(&(peer, direction)) { - return Some(( - peer, - HandshakeEvent::Negotiated { peer, handshake, substream, direction }, - )); - } - } - - return None; - } + /// Create new [`HandshakeService`]. + pub fn new(handshake: Arc>>) -> Self { + Self { + handshake, + ready: VecDeque::new(), + substreams: HashMap::new(), + } + } + + /// Remove outbound substream from [`HandshakeService`]. + pub fn remove_outbound(&mut self, peer: &PeerId) -> Option { + self.substreams + .remove(&(*peer, Direction::Outbound)) + .map(|(substream, _, _)| substream) + } + + /// Remove inbound substream from [`HandshakeService`]. + pub fn remove_inbound(&mut self, peer: &PeerId) -> Option { + self.substreams + .remove(&(*peer, Direction::Inbound)) + .map(|(substream, _, _)| substream) + } + + /// Negotiate outbound handshake. + pub fn negotiate_outbound(&mut self, peer: PeerId, substream: Substream) { + tracing::trace!(target: LOG_TARGET, ?peer, "negotiate outbound"); + + self.substreams.insert( + (peer, Direction::Outbound), + ( + substream, + Delay::new(NEGOTIATION_TIMEOUT), + HandshakeState::SendHandshake, + ), + ); + } + + /// Read handshake from remote peer. + pub fn read_handshake(&mut self, peer: PeerId, substream: Substream) { + tracing::trace!(target: LOG_TARGET, ?peer, "read handshake"); + + self.substreams.insert( + (peer, Direction::Inbound), + ( + substream, + Delay::new(NEGOTIATION_TIMEOUT), + HandshakeState::ReadHandshake, + ), + ); + } + + /// Write handshake to remote peer. + pub fn send_handshake(&mut self, peer: PeerId, substream: Substream) { + tracing::trace!(target: LOG_TARGET, ?peer, "send handshake"); + + self.substreams.insert( + (peer, Direction::Inbound), + ( + substream, + Delay::new(NEGOTIATION_TIMEOUT), + HandshakeState::SendHandshake, + ), + ); + } + + /// Returns `true` if [`HandshakeService`] contains no elements. + pub fn is_empty(&self) -> bool { + self.substreams.is_empty() + } + + /// Pop event from the event queue. + /// + /// The substream may not exist in the queue anymore as it may have been removed + /// by `NotificationProtocol` if either one of the substreams failed to negotiate. + fn pop_event(&mut self) -> Option<(PeerId, HandshakeEvent)> { + while let Some((peer, direction, handshake)) = self.ready.pop_front() { + if let Some((substream, _, _)) = self.substreams.remove(&(peer, direction)) { + return Some(( + peer, + HandshakeEvent::Negotiated { + peer, + handshake, + substream, + direction, + }, + )); + } + } + + return None; + } } impl Stream for HandshakeService { - type Item = (PeerId, HandshakeEvent); - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let inner = Pin::into_inner(self); - - if let Some(event) = inner.pop_event() { - return Poll::Ready(Some(event)); - } - - if inner.substreams.is_empty() { - return Poll::Pending; - } - - 'outer: for ((peer, direction), (ref mut substream, ref mut timer, state)) in - inner.substreams.iter_mut() - { - if let Poll::Ready(()) = timer.poll_unpin(cx) { - return Poll::Ready(Some(( - *peer, - HandshakeEvent::NegotiationError { peer: *peer, direction: *direction }, - ))); - } - - loop { - let pinned = Pin::new(&mut *substream); - - match state { - HandshakeState::SendHandshake => match pinned.poll_ready(cx) { - Poll::Ready(Ok(())) => { - *state = HandshakeState::SinkReady; - continue; - }, - Poll::Ready(Err(_)) => - return Poll::Ready(Some(( - *peer, - HandshakeEvent::NegotiationError { - peer: *peer, - direction: *direction, - }, - ))), - Poll::Pending => continue 'outer, - }, - HandshakeState::SinkReady => { - match pinned.start_send((*inner.handshake.read()).clone().into()) { - Ok(()) => { - *state = HandshakeState::HandshakeSent; - continue; - }, - Err(_) => - return Poll::Ready(Some(( - *peer, - HandshakeEvent::NegotiationError { - peer: *peer, - direction: *direction, - }, - ))), - } - }, - HandshakeState::HandshakeSent => match pinned.poll_flush(cx) { - Poll::Ready(Ok(())) => match direction { - Direction::Outbound => { - *state = HandshakeState::ReadHandshake; - continue; - }, - Direction::Inbound => { - inner.ready.push_back((*peer, *direction, vec![])); - continue 'outer; - }, - }, - Poll::Ready(Err(_)) => - return Poll::Ready(Some(( - *peer, - HandshakeEvent::NegotiationError { - peer: *peer, - direction: *direction, - }, - ))), - Poll::Pending => continue 'outer, - }, - HandshakeState::ReadHandshake => match pinned.poll_next(cx) { - Poll::Ready(Some(Ok(handshake))) => { - inner.ready.push_back((*peer, *direction, handshake.freeze().into())); - continue 'outer; - }, - Poll::Ready(Some(Err(_))) | Poll::Ready(None) => { - return Poll::Ready(Some(( - *peer, - HandshakeEvent::NegotiationError { - peer: *peer, - direction: *direction, - }, - ))); - }, - Poll::Pending => continue 'outer, - }, - } - } - } - - if let Some((peer, direction, handshake)) = inner.ready.pop_front() { - let (substream, _, _) = - inner.substreams.remove(&(peer, direction)).expect("peer to exist"); - - return Poll::Ready(Some(( - peer, - HandshakeEvent::Negotiated { peer, handshake, substream, direction }, - ))); - } - - Poll::Pending - } + type Item = (PeerId, HandshakeEvent); + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let inner = Pin::into_inner(self); + + if let Some(event) = inner.pop_event() { + return Poll::Ready(Some(event)); + } + + if inner.substreams.is_empty() { + return Poll::Pending; + } + + 'outer: for ((peer, direction), (ref mut substream, ref mut timer, state)) in + inner.substreams.iter_mut() + { + if let Poll::Ready(()) = timer.poll_unpin(cx) { + return Poll::Ready(Some(( + *peer, + HandshakeEvent::NegotiationError { + peer: *peer, + direction: *direction, + }, + ))); + } + + loop { + let pinned = Pin::new(&mut *substream); + + match state { + HandshakeState::SendHandshake => match pinned.poll_ready(cx) { + Poll::Ready(Ok(())) => { + *state = HandshakeState::SinkReady; + continue; + } + Poll::Ready(Err(_)) => + return Poll::Ready(Some(( + *peer, + HandshakeEvent::NegotiationError { + peer: *peer, + direction: *direction, + }, + ))), + Poll::Pending => continue 'outer, + }, + HandshakeState::SinkReady => { + match pinned.start_send((*inner.handshake.read()).clone().into()) { + Ok(()) => { + *state = HandshakeState::HandshakeSent; + continue; + } + Err(_) => + return Poll::Ready(Some(( + *peer, + HandshakeEvent::NegotiationError { + peer: *peer, + direction: *direction, + }, + ))), + } + } + HandshakeState::HandshakeSent => match pinned.poll_flush(cx) { + Poll::Ready(Ok(())) => match direction { + Direction::Outbound => { + *state = HandshakeState::ReadHandshake; + continue; + } + Direction::Inbound => { + inner.ready.push_back((*peer, *direction, vec![])); + continue 'outer; + } + }, + Poll::Ready(Err(_)) => + return Poll::Ready(Some(( + *peer, + HandshakeEvent::NegotiationError { + peer: *peer, + direction: *direction, + }, + ))), + Poll::Pending => continue 'outer, + }, + HandshakeState::ReadHandshake => match pinned.poll_next(cx) { + Poll::Ready(Some(Ok(handshake))) => { + inner.ready.push_back((*peer, *direction, handshake.freeze().into())); + continue 'outer; + } + Poll::Ready(Some(Err(_))) | Poll::Ready(None) => { + return Poll::Ready(Some(( + *peer, + HandshakeEvent::NegotiationError { + peer: *peer, + direction: *direction, + }, + ))); + } + Poll::Pending => continue 'outer, + }, + } + } + } + + if let Some((peer, direction, handshake)) = inner.ready.pop_front() { + let (substream, _, _) = + inner.substreams.remove(&(peer, direction)).expect("peer to exist"); + + return Poll::Ready(Some(( + peer, + HandshakeEvent::Negotiated { + peer, + handshake, + substream, + direction, + }, + ))); + } + + Poll::Pending + } } #[cfg(test)] mod tests { - use super::*; - use crate::{ - mock::substream::{DummySubstream, MockSubstream}, - types::SubstreamId, - Error, - }; - use futures::StreamExt; - - #[tokio::test] - async fn substream_error_when_sending_handshake() { - let mut service = HandshakeService::new(Arc::new(RwLock::new(vec![1, 2, 3, 4]))); - - futures::future::poll_fn(|cx| match service.poll_next_unpin(cx) { - Poll::Pending => Poll::Ready(()), - _ => panic!("invalid event received"), - }) - .await; - - let mut substream = MockSubstream::new(); - substream.expect_poll_ready().times(1).return_once(|_| Poll::Ready(Ok(()))); - substream.expect_start_send().times(1).return_once(|_| Err(Error::Unknown)); - - let peer = PeerId::random(); - let substream = Substream::new_mock(peer, SubstreamId::from(0usize), Box::new(substream)); - - service.send_handshake(peer, substream); - match service.next().await { - Some(( - failed_peer, - HandshakeEvent::NegotiationError { peer: event_peer, direction }, - )) => { - assert_eq!(failed_peer, peer); - assert_eq!(event_peer, peer); - assert_eq!(direction, Direction::Inbound); - }, - _ => panic!("invalid event received"), - } - } - - #[tokio::test] - async fn substream_error_when_flushing_substream() { - let mut service = HandshakeService::new(Arc::new(RwLock::new(vec![1, 2, 3, 4]))); - - futures::future::poll_fn(|cx| match service.poll_next_unpin(cx) { - Poll::Pending => Poll::Ready(()), - _ => panic!("invalid event received"), - }) - .await; - - let mut substream = MockSubstream::new(); - substream.expect_poll_ready().times(1).return_once(|_| Poll::Ready(Ok(()))); - substream.expect_start_send().times(1).return_once(|_| Ok(())); - substream - .expect_poll_flush() - .times(1) - .return_once(|_| Poll::Ready(Err(Error::Unknown))); - - let peer = PeerId::random(); - let substream = Substream::new_mock(peer, SubstreamId::from(0usize), Box::new(substream)); - - service.send_handshake(peer, substream); - match service.next().await { - Some(( - failed_peer, - HandshakeEvent::NegotiationError { peer: event_peer, direction }, - )) => { - assert_eq!(failed_peer, peer); - assert_eq!(event_peer, peer); - assert_eq!(direction, Direction::Inbound); - }, - _ => panic!("invalid event received"), - } - } - - // inbound substream is negotiated and it pushed into `inner` but outbound substream fails to - // negotiate - #[tokio::test] - async fn pop_event_but_substream_doesnt_exist() { - let mut service = HandshakeService::new(Arc::new(RwLock::new(vec![1, 2, 3, 4]))); - let peer = PeerId::random(); - - // inbound substream has finished - service.ready.push_front((peer, Direction::Inbound, vec![])); - service.substreams.insert( - (peer, Direction::Inbound), - ( - Substream::new_mock( - peer, - SubstreamId::from(1337usize), - Box::new(DummySubstream::new()), - ), - Delay::new(NEGOTIATION_TIMEOUT), - HandshakeState::HandshakeSent, - ), - ); - service.substreams.insert( - (peer, Direction::Outbound), - ( - Substream::new_mock( - peer, - SubstreamId::from(1337usize), - Box::new(DummySubstream::new()), - ), - Delay::new(NEGOTIATION_TIMEOUT), - HandshakeState::SendHandshake, - ), - ); - - // outbound substream failed and `NotificationProtocol` removes - // both substreams from `HandshakeService` - assert!(service.remove_outbound(&peer).is_some()); - assert!(service.remove_inbound(&peer).is_some()); - - futures::future::poll_fn(|cx| match service.poll_next_unpin(cx) { - Poll::Pending => Poll::Ready(()), - _ => panic!("invalid event received"), - }) - .await - } + use super::*; + use crate::{ + mock::substream::{DummySubstream, MockSubstream}, + types::SubstreamId, + Error, + }; + use futures::StreamExt; + + #[tokio::test] + async fn substream_error_when_sending_handshake() { + let mut service = HandshakeService::new(Arc::new(RwLock::new(vec![1, 2, 3, 4]))); + + futures::future::poll_fn(|cx| match service.poll_next_unpin(cx) { + Poll::Pending => Poll::Ready(()), + _ => panic!("invalid event received"), + }) + .await; + + let mut substream = MockSubstream::new(); + substream.expect_poll_ready().times(1).return_once(|_| Poll::Ready(Ok(()))); + substream.expect_start_send().times(1).return_once(|_| Err(Error::Unknown)); + + let peer = PeerId::random(); + let substream = Substream::new_mock(peer, SubstreamId::from(0usize), Box::new(substream)); + + service.send_handshake(peer, substream); + match service.next().await { + Some(( + failed_peer, + HandshakeEvent::NegotiationError { + peer: event_peer, + direction, + }, + )) => { + assert_eq!(failed_peer, peer); + assert_eq!(event_peer, peer); + assert_eq!(direction, Direction::Inbound); + } + _ => panic!("invalid event received"), + } + } + + #[tokio::test] + async fn substream_error_when_flushing_substream() { + let mut service = HandshakeService::new(Arc::new(RwLock::new(vec![1, 2, 3, 4]))); + + futures::future::poll_fn(|cx| match service.poll_next_unpin(cx) { + Poll::Pending => Poll::Ready(()), + _ => panic!("invalid event received"), + }) + .await; + + let mut substream = MockSubstream::new(); + substream.expect_poll_ready().times(1).return_once(|_| Poll::Ready(Ok(()))); + substream.expect_start_send().times(1).return_once(|_| Ok(())); + substream + .expect_poll_flush() + .times(1) + .return_once(|_| Poll::Ready(Err(Error::Unknown))); + + let peer = PeerId::random(); + let substream = Substream::new_mock(peer, SubstreamId::from(0usize), Box::new(substream)); + + service.send_handshake(peer, substream); + match service.next().await { + Some(( + failed_peer, + HandshakeEvent::NegotiationError { + peer: event_peer, + direction, + }, + )) => { + assert_eq!(failed_peer, peer); + assert_eq!(event_peer, peer); + assert_eq!(direction, Direction::Inbound); + } + _ => panic!("invalid event received"), + } + } + + // inbound substream is negotiated and it pushed into `inner` but outbound substream fails to + // negotiate + #[tokio::test] + async fn pop_event_but_substream_doesnt_exist() { + let mut service = HandshakeService::new(Arc::new(RwLock::new(vec![1, 2, 3, 4]))); + let peer = PeerId::random(); + + // inbound substream has finished + service.ready.push_front((peer, Direction::Inbound, vec![])); + service.substreams.insert( + (peer, Direction::Inbound), + ( + Substream::new_mock( + peer, + SubstreamId::from(1337usize), + Box::new(DummySubstream::new()), + ), + Delay::new(NEGOTIATION_TIMEOUT), + HandshakeState::HandshakeSent, + ), + ); + service.substreams.insert( + (peer, Direction::Outbound), + ( + Substream::new_mock( + peer, + SubstreamId::from(1337usize), + Box::new(DummySubstream::new()), + ), + Delay::new(NEGOTIATION_TIMEOUT), + HandshakeState::SendHandshake, + ), + ); + + // outbound substream failed and `NotificationProtocol` removes + // both substreams from `HandshakeService` + assert!(service.remove_outbound(&peer).is_some()); + assert!(service.remove_inbound(&peer).is_some()); + + futures::future::poll_fn(|cx| match service.poll_next_unpin(cx) { + Poll::Pending => Poll::Ready(()), + _ => panic!("invalid event received"), + }) + .await + } } diff --git a/src/protocol/notification/tests/mod.rs b/src/protocol/notification/tests/mod.rs index 66131aee..43dd9121 100644 --- a/src/protocol/notification/tests/mod.rs +++ b/src/protocol/notification/tests/mod.rs @@ -21,17 +21,17 @@ use std::collections::HashSet; use crate::{ - crypto::ed25519::Keypair, - executor::DefaultExecutor, - protocol::{ - notification::{ - handle::NotificationHandle, Config as NotificationConfig, NotificationProtocol, - }, - InnerTransportEvent, ProtocolCommand, TransportService, - }, - transport::manager::TransportManager, - types::protocol::ProtocolName, - BandwidthSink, PeerId, + crypto::ed25519::Keypair, + executor::DefaultExecutor, + protocol::{ + notification::{ + handle::NotificationHandle, Config as NotificationConfig, NotificationProtocol, + }, + InnerTransportEvent, ProtocolCommand, TransportService, + }, + transport::manager::TransportManager, + types::protocol::ProtocolName, + BandwidthSink, PeerId, }; use tokio::sync::mpsc::{channel, Receiver, Sender}; @@ -42,45 +42,53 @@ mod notification; mod substream_validation; /// create new `NotificationProtocol` -fn make_notification_protocol( -) -> (NotificationProtocol, NotificationHandle, TransportManager, Sender) { - let (manager, handle) = - TransportManager::new(Keypair::generate(), HashSet::new(), BandwidthSink::new(), 8usize); +fn make_notification_protocol() -> ( + NotificationProtocol, + NotificationHandle, + TransportManager, + Sender, +) { + let (manager, handle) = TransportManager::new( + Keypair::generate(), + HashSet::new(), + BandwidthSink::new(), + 8usize, + ); - let peer = PeerId::random(); - let (transport_service, tx) = TransportService::new( - peer, - ProtocolName::from("/notif/1"), - Vec::new(), - std::sync::Arc::new(Default::default()), - handle, - ); - let (config, handle) = NotificationConfig::new( - ProtocolName::from("/notif/1"), - 1024usize, - vec![1, 2, 3, 4], - Vec::new(), - false, - 64, - 64, - true, - ); + let peer = PeerId::random(); + let (transport_service, tx) = TransportService::new( + peer, + ProtocolName::from("/notif/1"), + Vec::new(), + std::sync::Arc::new(Default::default()), + handle, + ); + let (config, handle) = NotificationConfig::new( + ProtocolName::from("/notif/1"), + 1024usize, + vec![1, 2, 3, 4], + Vec::new(), + false, + 64, + 64, + true, + ); - ( - NotificationProtocol::new( - transport_service, - config, - std::sync::Arc::new(DefaultExecutor {}), - ), - handle, - manager, - tx, - ) + ( + NotificationProtocol::new( + transport_service, + config, + std::sync::Arc::new(DefaultExecutor {}), + ), + handle, + manager, + tx, + ) } /// add new peer to `NotificationProtocol` fn add_peer() -> (PeerId, (), Receiver) { - let (_tx, rx) = channel(64); + let (_tx, rx) = channel(64); - (PeerId::random(), (), rx) + (PeerId::random(), (), rx) } diff --git a/src/protocol/notification/tests/notification.rs b/src/protocol/notification/tests/notification.rs index 4ae17a39..52b12f07 100644 --- a/src/protocol/notification/tests/notification.rs +++ b/src/protocol/notification/tests/notification.rs @@ -19,741 +19,770 @@ // DEALINGS IN THE SOFTWARE. use crate::{ - error::Error, - mock::substream::{DummySubstream, MockSubstream}, - protocol::{ - self, - connection::ConnectionHandle, - notification::{ - negotiation::HandshakeEvent, - tests::make_notification_protocol, - types::{Direction, NotificationError, NotificationEvent}, - ConnectionState, InboundState, NotificationProtocol, OutboundState, PeerContext, - PeerState, ValidationResult, - }, - InnerTransportEvent, ProtocolCommand, - }, - substream::Substream, - transport::Endpoint, - types::{protocol::ProtocolName, ConnectionId, SubstreamId}, - PeerId, + error::Error, + mock::substream::{DummySubstream, MockSubstream}, + protocol::{ + self, + connection::ConnectionHandle, + notification::{ + negotiation::HandshakeEvent, + tests::make_notification_protocol, + types::{Direction, NotificationError, NotificationEvent}, + ConnectionState, InboundState, NotificationProtocol, OutboundState, PeerContext, + PeerState, ValidationResult, + }, + InnerTransportEvent, ProtocolCommand, + }, + substream::Substream, + transport::Endpoint, + types::{protocol::ProtocolName, ConnectionId, SubstreamId}, + PeerId, }; use futures::StreamExt; use multiaddr::Multiaddr; use tokio::sync::{ - mpsc::{channel, Receiver, Sender}, - oneshot, + mpsc::{channel, Receiver, Sender}, + oneshot, }; use std::{task::Poll, time::Duration}; fn next_inbound_state(state: usize) -> InboundState { - match state { - 0 => InboundState::Closed, - 1 => InboundState::ReadingHandshake, - 2 => InboundState::Validating { - inbound: Substream::new_mock( - PeerId::random(), - SubstreamId::from(0usize), - Box::new(MockSubstream::new()), - ), - }, - 3 => InboundState::SendingHandshake, - 4 => InboundState::Open { - inbound: Substream::new_mock( - PeerId::random(), - SubstreamId::from(0usize), - Box::new(MockSubstream::new()), - ), - }, - _ => panic!(), - } + match state { + 0 => InboundState::Closed, + 1 => InboundState::ReadingHandshake, + 2 => InboundState::Validating { + inbound: Substream::new_mock( + PeerId::random(), + SubstreamId::from(0usize), + Box::new(MockSubstream::new()), + ), + }, + 3 => InboundState::SendingHandshake, + 4 => InboundState::Open { + inbound: Substream::new_mock( + PeerId::random(), + SubstreamId::from(0usize), + Box::new(MockSubstream::new()), + ), + }, + _ => panic!(), + } } fn next_outbound_state(state: usize) -> OutboundState { - match state { - 0 => OutboundState::Closed, - 1 => OutboundState::OutboundInitiated { substream: SubstreamId::new() }, - 2 => OutboundState::Negotiating, - 3 => OutboundState::Open { - handshake: vec![1, 3, 3, 7], - outbound: Substream::new_mock( - PeerId::random(), - SubstreamId::from(0usize), - Box::new(MockSubstream::new()), - ), - }, - _ => panic!(), - } + match state { + 0 => OutboundState::Closed, + 1 => OutboundState::OutboundInitiated { + substream: SubstreamId::new(), + }, + 2 => OutboundState::Negotiating, + 3 => OutboundState::Open { + handshake: vec![1, 3, 3, 7], + outbound: Substream::new_mock( + PeerId::random(), + SubstreamId::from(0usize), + Box::new(MockSubstream::new()), + ), + }, + _ => panic!(), + } } #[tokio::test] async fn connection_closed_for_outbound_open_substream() { - let peer = PeerId::random(); - - for i in 0..5 { - connection_closed( - peer, - PeerState::Validating { - direction: Direction::Inbound, - protocol: ProtocolName::from("/notif/1"), - fallback: None, - outbound: OutboundState::Open { - handshake: vec![1, 2, 3, 4], - outbound: Substream::new_mock( - PeerId::random(), - SubstreamId::from(0usize), - Box::new(MockSubstream::new()), - ), - }, - inbound: next_inbound_state(i), - }, - Some(NotificationEvent::NotificationStreamOpenFailure { - peer, - error: NotificationError::Rejected, - }), - ) - .await; - } + let peer = PeerId::random(); + + for i in 0..5 { + connection_closed( + peer, + PeerState::Validating { + direction: Direction::Inbound, + protocol: ProtocolName::from("/notif/1"), + fallback: None, + outbound: OutboundState::Open { + handshake: vec![1, 2, 3, 4], + outbound: Substream::new_mock( + PeerId::random(), + SubstreamId::from(0usize), + Box::new(MockSubstream::new()), + ), + }, + inbound: next_inbound_state(i), + }, + Some(NotificationEvent::NotificationStreamOpenFailure { + peer, + error: NotificationError::Rejected, + }), + ) + .await; + } } #[tokio::test] async fn connection_closed_for_outbound_initiated_substream() { - let peer = PeerId::random(); - - for i in 0..5 { - connection_closed( - peer, - PeerState::Validating { - direction: Direction::Inbound, - protocol: ProtocolName::from("/notif/1"), - fallback: None, - outbound: OutboundState::OutboundInitiated { substream: SubstreamId::from(0usize) }, - inbound: next_inbound_state(i), - }, - Some(NotificationEvent::NotificationStreamOpenFailure { - peer, - error: NotificationError::Rejected, - }), - ) - .await; - } + let peer = PeerId::random(); + + for i in 0..5 { + connection_closed( + peer, + PeerState::Validating { + direction: Direction::Inbound, + protocol: ProtocolName::from("/notif/1"), + fallback: None, + outbound: OutboundState::OutboundInitiated { + substream: SubstreamId::from(0usize), + }, + inbound: next_inbound_state(i), + }, + Some(NotificationEvent::NotificationStreamOpenFailure { + peer, + error: NotificationError::Rejected, + }), + ) + .await; + } } #[tokio::test] async fn connection_closed_for_outbound_negotiated_substream() { - let peer = PeerId::random(); - - for i in 0..5 { - connection_closed( - peer, - PeerState::Validating { - direction: Direction::Inbound, - protocol: ProtocolName::from("/notif/1"), - fallback: None, - outbound: OutboundState::Negotiating, - inbound: next_inbound_state(i), - }, - Some(NotificationEvent::NotificationStreamOpenFailure { - peer, - error: NotificationError::Rejected, - }), - ) - .await; - } + let peer = PeerId::random(); + + for i in 0..5 { + connection_closed( + peer, + PeerState::Validating { + direction: Direction::Inbound, + protocol: ProtocolName::from("/notif/1"), + fallback: None, + outbound: OutboundState::Negotiating, + inbound: next_inbound_state(i), + }, + Some(NotificationEvent::NotificationStreamOpenFailure { + peer, + error: NotificationError::Rejected, + }), + ) + .await; + } } #[tokio::test] async fn connection_closed_for_initiated_substream() { - let peer = PeerId::random(); - - connection_closed( - peer, - PeerState::OutboundInitiated { substream: SubstreamId::new() }, - Some(NotificationEvent::NotificationStreamOpenFailure { - peer, - error: NotificationError::Rejected, - }), - ) - .await; + let peer = PeerId::random(); + + connection_closed( + peer, + PeerState::OutboundInitiated { + substream: SubstreamId::new(), + }, + Some(NotificationEvent::NotificationStreamOpenFailure { + peer, + error: NotificationError::Rejected, + }), + ) + .await; } #[tokio::test] #[cfg(debug_assertions)] #[should_panic] async fn connection_established_twice() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); - let (mut notif, _handle, _sender, _tx) = make_notification_protocol(); - let peer = PeerId::random(); + let (mut notif, _handle, _sender, _tx) = make_notification_protocol(); + let peer = PeerId::random(); - assert!(notif.on_connection_established(peer).await.is_ok()); - assert!(notif.on_connection_established(peer).await.is_err()); + assert!(notif.on_connection_established(peer).await.is_ok()); + assert!(notif.on_connection_established(peer).await.is_err()); } #[tokio::test] #[cfg(debug_assertions)] #[should_panic] async fn connection_closed_twice() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); - let (mut notif, _handle, _sender, _tx) = make_notification_protocol(); - let peer = PeerId::random(); + let (mut notif, _handle, _sender, _tx) = make_notification_protocol(); + let peer = PeerId::random(); - assert!(notif.on_connection_closed(peer).await.is_ok()); - assert!(notif.on_connection_closed(peer).await.is_err()); + assert!(notif.on_connection_closed(peer).await.is_ok()); + assert!(notif.on_connection_closed(peer).await.is_err()); } #[tokio::test] #[cfg(debug_assertions)] #[should_panic] async fn substream_open_failure_for_unknown_substream() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); - let (mut notif, _handle, _sender, _tx) = make_notification_protocol(); + let (mut notif, _handle, _sender, _tx) = make_notification_protocol(); - notif.on_substream_open_failure(SubstreamId::new(), Error::Unknown).await; + notif.on_substream_open_failure(SubstreamId::new(), Error::Unknown).await; } #[tokio::test] async fn close_substream_to_unknown_peer() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); - let (mut notif, _handle, _sender, _tx) = make_notification_protocol(); - let peer = PeerId::random(); + let (mut notif, _handle, _sender, _tx) = make_notification_protocol(); + let peer = PeerId::random(); - assert!(!notif.peers.contains_key(&peer)); - notif.on_close_substream(peer).await; - assert!(!notif.peers.contains_key(&peer)); + assert!(!notif.peers.contains_key(&peer)); + notif.on_close_substream(peer).await; + assert!(!notif.peers.contains_key(&peer)); } #[tokio::test] #[cfg(debug_assertions)] #[should_panic] async fn handshake_event_unknown_peer() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (mut notif, _handle, _sender, _tx) = make_notification_protocol(); - let peer = PeerId::random(); - - assert!(!notif.peers.contains_key(&peer)); - notif - .on_handshake_event( - peer, - HandshakeEvent::Negotiated { - peer, - handshake: vec![1, 3, 3, 7], - substream: Substream::new_mock( - peer, - SubstreamId::from(0usize), - Box::new(DummySubstream::new()), - ), - direction: protocol::notification::negotiation::Direction::Inbound, - }, - ) - .await; - assert!(!notif.peers.contains_key(&peer)); + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (mut notif, _handle, _sender, _tx) = make_notification_protocol(); + let peer = PeerId::random(); + + assert!(!notif.peers.contains_key(&peer)); + notif + .on_handshake_event( + peer, + HandshakeEvent::Negotiated { + peer, + handshake: vec![1, 3, 3, 7], + substream: Substream::new_mock( + peer, + SubstreamId::from(0usize), + Box::new(DummySubstream::new()), + ), + direction: protocol::notification::negotiation::Direction::Inbound, + }, + ) + .await; + assert!(!notif.peers.contains_key(&peer)); } #[tokio::test] #[cfg(debug_assertions)] #[should_panic] async fn handshake_event_invalid_state_for_outbound_substream() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (mut notif, _handle, _sender, mut tx) = make_notification_protocol(); - let (peer, _receiver) = register_peer(&mut notif, &mut tx).await; - - notif - .on_handshake_event( - peer, - HandshakeEvent::Negotiated { - peer, - handshake: vec![1, 3, 3, 7], - substream: Substream::new_mock( - peer, - SubstreamId::from(0usize), - Box::new(DummySubstream::new()), - ), - direction: protocol::notification::negotiation::Direction::Outbound, - }, - ) - .await; + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (mut notif, _handle, _sender, mut tx) = make_notification_protocol(); + let (peer, _receiver) = register_peer(&mut notif, &mut tx).await; + + notif + .on_handshake_event( + peer, + HandshakeEvent::Negotiated { + peer, + handshake: vec![1, 3, 3, 7], + substream: Substream::new_mock( + peer, + SubstreamId::from(0usize), + Box::new(DummySubstream::new()), + ), + direction: protocol::notification::negotiation::Direction::Outbound, + }, + ) + .await; } #[tokio::test] #[cfg(debug_assertions)] #[should_panic] async fn substream_open_failure_for_unknown_peer() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); - let (mut notif, _handle, _sender, _tx) = make_notification_protocol(); - let peer = PeerId::random(); - let substream_id = SubstreamId::from(1337usize); + let (mut notif, _handle, _sender, _tx) = make_notification_protocol(); + let peer = PeerId::random(); + let substream_id = SubstreamId::from(1337usize); - notif.pending_outbound.insert(substream_id, peer); - notif.on_substream_open_failure(substream_id, Error::Unknown).await; + notif.pending_outbound.insert(substream_id, peer); + notif.on_substream_open_failure(substream_id, Error::Unknown).await; } #[tokio::test] async fn dial_failure_for_non_dialing_peer() { - let (mut notif, mut handle, _sender, mut tx) = make_notification_protocol(); - let (peer, _receiver) = register_peer(&mut notif, &mut tx).await; - - // dial failure for the peer even though it's not dialing - notif.on_dial_failure(peer, Multiaddr::empty()).await; - - assert!(std::matches!( - notif.peers.get(&peer), - Some(PeerContext { state: PeerState::Closed { .. } }) - )); - futures::future::poll_fn(|cx| match handle.poll_next_unpin(cx) { - Poll::Pending => Poll::Ready(()), - _ => panic!("invalid event"), - }) - .await; + let (mut notif, mut handle, _sender, mut tx) = make_notification_protocol(); + let (peer, _receiver) = register_peer(&mut notif, &mut tx).await; + + // dial failure for the peer even though it's not dialing + notif.on_dial_failure(peer, Multiaddr::empty()).await; + + assert!(std::matches!( + notif.peers.get(&peer), + Some(PeerContext { + state: PeerState::Closed { .. } + }) + )); + futures::future::poll_fn(|cx| match handle.poll_next_unpin(cx) { + Poll::Pending => Poll::Ready(()), + _ => panic!("invalid event"), + }) + .await; } // inbound state is ignored async fn connection_closed(peer: PeerId, state: PeerState, event: Option) { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); - let (mut notif, mut handle, _sender, _tx) = make_notification_protocol(); + let (mut notif, mut handle, _sender, _tx) = make_notification_protocol(); - notif.peers.insert(peer, PeerContext { state }); - notif.on_connection_closed(peer).await.unwrap(); + notif.peers.insert(peer, PeerContext { state }); + notif.on_connection_closed(peer).await.unwrap(); - if let Some(expected) = event { - assert_eq!(handle.next().await.unwrap(), expected); - } - assert!(!notif.peers.contains_key(&peer)) + if let Some(expected) = event { + assert_eq!(handle.next().await.unwrap(), expected); + } + assert!(!notif.peers.contains_key(&peer)) } // register new connection to `NotificationProtocol` async fn register_peer( - notif: &mut NotificationProtocol, - sender: &mut Sender, + notif: &mut NotificationProtocol, + sender: &mut Sender, ) -> (PeerId, Receiver) { - let peer = PeerId::random(); - let (conn_tx, conn_rx) = channel(64); - - sender - .send(InnerTransportEvent::ConnectionEstablished { - peer, - connection: ConnectionId::new(), - endpoint: Endpoint::dialer(Multiaddr::empty(), ConnectionId::from(0usize)), - sender: ConnectionHandle::new(ConnectionId::from(0usize), conn_tx), - }) - .await - .unwrap(); - - // poll the protocol to register the peer - notif.next_event().await; - - assert!(std::matches!( - notif.peers.get(&peer), - Some(PeerContext { state: PeerState::Closed { .. } }) - )); - - (peer, conn_rx) + let peer = PeerId::random(); + let (conn_tx, conn_rx) = channel(64); + + sender + .send(InnerTransportEvent::ConnectionEstablished { + peer, + connection: ConnectionId::new(), + endpoint: Endpoint::dialer(Multiaddr::empty(), ConnectionId::from(0usize)), + sender: ConnectionHandle::new(ConnectionId::from(0usize), conn_tx), + }) + .await + .unwrap(); + + // poll the protocol to register the peer + notif.next_event().await; + + assert!(std::matches!( + notif.peers.get(&peer), + Some(PeerContext { + state: PeerState::Closed { .. } + }) + )); + + (peer, conn_rx) } #[tokio::test] async fn open_substream_connection_closed() { - open_substream(PeerState::Closed { pending_open: None }, true).await; + open_substream(PeerState::Closed { pending_open: None }, true).await; } #[tokio::test] async fn open_substream_already_initiated() { - open_substream(PeerState::OutboundInitiated { substream: SubstreamId::new() }, false).await; + open_substream( + PeerState::OutboundInitiated { + substream: SubstreamId::new(), + }, + false, + ) + .await; } #[tokio::test] async fn open_substream_already_open() { - let (shutdown, _rx) = oneshot::channel(); - open_substream(PeerState::Open { shutdown }, false).await; + let (shutdown, _rx) = oneshot::channel(); + open_substream(PeerState::Open { shutdown }, false).await; } #[tokio::test] async fn open_substream_under_validation() { - for i in 0..5 { - for k in 0..4 { - open_substream( - PeerState::Validating { - direction: Direction::Inbound, - protocol: ProtocolName::from("/notif/1"), - fallback: None, - outbound: next_outbound_state(k), - inbound: next_inbound_state(i), - }, - false, - ) - .await; - } - } + for i in 0..5 { + for k in 0..4 { + open_substream( + PeerState::Validating { + direction: Direction::Inbound, + protocol: ProtocolName::from("/notif/1"), + fallback: None, + outbound: next_outbound_state(k), + inbound: next_inbound_state(i), + }, + false, + ) + .await; + } + } } async fn open_substream(state: PeerState, succeeds: bool) { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); - let (mut notif, _handle, _sender, mut tx) = make_notification_protocol(); - let (peer, mut receiver) = register_peer(&mut notif, &mut tx).await; + let (mut notif, _handle, _sender, mut tx) = make_notification_protocol(); + let (peer, mut receiver) = register_peer(&mut notif, &mut tx).await; - let context = notif.peers.get_mut(&peer).unwrap(); - context.state = state; + let context = notif.peers.get_mut(&peer).unwrap(); + context.state = state; - notif.on_open_substream(peer).await.unwrap(); - assert!(receiver.try_recv().is_ok() == succeeds); + notif.on_open_substream(peer).await.unwrap(); + assert!(receiver.try_recv().is_ok() == succeeds); } #[tokio::test] async fn open_substream_no_connection() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); - let (mut notif, _handle, _sender, _tx) = make_notification_protocol(); - assert!(notif.on_open_substream(PeerId::random()).await.is_err()); + let (mut notif, _handle, _sender, _tx) = make_notification_protocol(); + assert!(notif.on_open_substream(PeerId::random()).await.is_err()); } #[tokio::test] async fn remote_opens_multiple_inbound_substreams() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let protocol = ProtocolName::from("/notif/1"); - let (mut notif, _handle, _sender, mut tx) = make_notification_protocol(); - let (peer, _receiver) = register_peer(&mut notif, &mut tx).await; - - // open substream, poll the result and verify that the peer is in correct state - tx.send(InnerTransportEvent::SubstreamOpened { - peer, - protocol: protocol.clone(), - fallback: None, - direction: protocol::Direction::Inbound, - substream: Substream::new_mock( - PeerId::random(), - SubstreamId::from(0usize), - Box::new(DummySubstream::new()), - ), - }) - .await - .unwrap(); - notif.next_event().await; - - match notif.peers.get(&peer) { - Some(PeerContext { - state: - PeerState::Validating { - direction: Direction::Inbound, - protocol, - fallback: None, - outbound: OutboundState::Closed, - inbound: InboundState::ReadingHandshake, - }, - }) => { - assert_eq!(protocol, &ProtocolName::from("/notif/1")); - }, - state => panic!("invalid state: {state:?}"), - } - - // try to open another substream and verify it's discarded and the state is otherwise - // preserved - let mut substream = MockSubstream::new(); - substream.expect_poll_close().times(1).return_once(|_| Poll::Ready(Ok(()))); - - tx.send(InnerTransportEvent::SubstreamOpened { - peer, - protocol: protocol.clone(), - fallback: None, - direction: protocol::Direction::Inbound, - substream: Substream::new_mock( - PeerId::random(), - SubstreamId::from(0usize), - Box::new(substream), - ), - }) - .await - .unwrap(); - notif.next_event().await; - - match notif.peers.get(&peer) { - Some(PeerContext { - state: - PeerState::Validating { - direction: Direction::Inbound, - protocol, - fallback: None, - outbound: OutboundState::Closed, - inbound: InboundState::ReadingHandshake, - }, - }) => { - assert_eq!(protocol, &ProtocolName::from("/notif/1")); - }, - state => panic!("invalid state: {state:?}"), - } + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let protocol = ProtocolName::from("/notif/1"); + let (mut notif, _handle, _sender, mut tx) = make_notification_protocol(); + let (peer, _receiver) = register_peer(&mut notif, &mut tx).await; + + // open substream, poll the result and verify that the peer is in correct state + tx.send(InnerTransportEvent::SubstreamOpened { + peer, + protocol: protocol.clone(), + fallback: None, + direction: protocol::Direction::Inbound, + substream: Substream::new_mock( + PeerId::random(), + SubstreamId::from(0usize), + Box::new(DummySubstream::new()), + ), + }) + .await + .unwrap(); + notif.next_event().await; + + match notif.peers.get(&peer) { + Some(PeerContext { + state: + PeerState::Validating { + direction: Direction::Inbound, + protocol, + fallback: None, + outbound: OutboundState::Closed, + inbound: InboundState::ReadingHandshake, + }, + }) => { + assert_eq!(protocol, &ProtocolName::from("/notif/1")); + } + state => panic!("invalid state: {state:?}"), + } + + // try to open another substream and verify it's discarded and the state is otherwise + // preserved + let mut substream = MockSubstream::new(); + substream.expect_poll_close().times(1).return_once(|_| Poll::Ready(Ok(()))); + + tx.send(InnerTransportEvent::SubstreamOpened { + peer, + protocol: protocol.clone(), + fallback: None, + direction: protocol::Direction::Inbound, + substream: Substream::new_mock( + PeerId::random(), + SubstreamId::from(0usize), + Box::new(substream), + ), + }) + .await + .unwrap(); + notif.next_event().await; + + match notif.peers.get(&peer) { + Some(PeerContext { + state: + PeerState::Validating { + direction: Direction::Inbound, + protocol, + fallback: None, + outbound: OutboundState::Closed, + inbound: InboundState::ReadingHandshake, + }, + }) => { + assert_eq!(protocol, &ProtocolName::from("/notif/1")); + } + state => panic!("invalid state: {state:?}"), + } } #[tokio::test] async fn pending_outbound_tracked_correctly() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let protocol = ProtocolName::from("/notif/1"); - let (mut notif, _handle, _sender, mut tx) = make_notification_protocol(); - let (peer, _receiver) = register_peer(&mut notif, &mut tx).await; - - // open outbound substream - notif.on_open_substream(peer).await.unwrap(); - - match notif.peers.get(&peer) { - Some(PeerContext { state: PeerState::OutboundInitiated { substream } }) => { - assert_eq!(substream, &SubstreamId::new()); - }, - state => panic!("invalid state: {state:?}"), - } - - // then register inbound substream and verify that the state is changed to `Validating` - notif - .on_inbound_substream( - protocol.clone(), - None, - peer, - Substream::new_mock( - PeerId::random(), - SubstreamId::from(0usize), - Box::new(DummySubstream::new()), - ), - ) - .await - .unwrap(); - - match notif.peers.get(&peer) { - Some(PeerContext { - state: - PeerState::Validating { - direction: Direction::Outbound, - outbound: OutboundState::OutboundInitiated { .. }, - inbound: InboundState::ReadingHandshake, - .. - }, - }) => {}, - state => panic!("invalid state: {state:?}"), - } - - // then negotiation event for the inbound handshake - notif - .on_handshake_event( - peer, - HandshakeEvent::Negotiated { - peer, - handshake: vec![1, 3, 3, 7], - substream: Substream::new_mock( - PeerId::random(), - SubstreamId::from(0usize), - Box::new(DummySubstream::new()), - ), - direction: protocol::notification::negotiation::Direction::Inbound, - }, - ) - .await; - - match notif.peers.get(&peer) { - Some(PeerContext { - state: - PeerState::Validating { - direction: Direction::Outbound, - outbound: OutboundState::OutboundInitiated { .. }, - inbound: InboundState::Validating { .. }, - .. - }, - }) => {}, - state => panic!("invalid state: {state:?}"), - } - - // then reject the inbound peer even though an outbound substream was already established - notif.on_validation_result(peer, ValidationResult::Reject).await.unwrap(); - - match notif.peers.get(&peer) { - Some(PeerContext { state: PeerState::Closed { pending_open } }) => { - assert_eq!(pending_open, &Some(SubstreamId::new())); - }, - state => panic!("invalid state: {state:?}"), - } - - // finally the outbound substream registers, verify that `pending_open` is set to `None` - notif - .on_outbound_substream( - protocol, - None, - peer, - SubstreamId::new(), - Substream::new_mock( - PeerId::random(), - SubstreamId::from(0usize), - Box::new(DummySubstream::new()), - ), - ) - .await - .unwrap(); - - match notif.peers.get(&peer) { - Some(PeerContext { state: PeerState::Closed { pending_open } }) => { - assert!(pending_open.is_none()); - }, - state => panic!("invalid state: {state:?}"), - } + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let protocol = ProtocolName::from("/notif/1"); + let (mut notif, _handle, _sender, mut tx) = make_notification_protocol(); + let (peer, _receiver) = register_peer(&mut notif, &mut tx).await; + + // open outbound substream + notif.on_open_substream(peer).await.unwrap(); + + match notif.peers.get(&peer) { + Some(PeerContext { + state: PeerState::OutboundInitiated { substream }, + }) => { + assert_eq!(substream, &SubstreamId::new()); + } + state => panic!("invalid state: {state:?}"), + } + + // then register inbound substream and verify that the state is changed to `Validating` + notif + .on_inbound_substream( + protocol.clone(), + None, + peer, + Substream::new_mock( + PeerId::random(), + SubstreamId::from(0usize), + Box::new(DummySubstream::new()), + ), + ) + .await + .unwrap(); + + match notif.peers.get(&peer) { + Some(PeerContext { + state: + PeerState::Validating { + direction: Direction::Outbound, + outbound: OutboundState::OutboundInitiated { .. }, + inbound: InboundState::ReadingHandshake, + .. + }, + }) => {} + state => panic!("invalid state: {state:?}"), + } + + // then negotiation event for the inbound handshake + notif + .on_handshake_event( + peer, + HandshakeEvent::Negotiated { + peer, + handshake: vec![1, 3, 3, 7], + substream: Substream::new_mock( + PeerId::random(), + SubstreamId::from(0usize), + Box::new(DummySubstream::new()), + ), + direction: protocol::notification::negotiation::Direction::Inbound, + }, + ) + .await; + + match notif.peers.get(&peer) { + Some(PeerContext { + state: + PeerState::Validating { + direction: Direction::Outbound, + outbound: OutboundState::OutboundInitiated { .. }, + inbound: InboundState::Validating { .. }, + .. + }, + }) => {} + state => panic!("invalid state: {state:?}"), + } + + // then reject the inbound peer even though an outbound substream was already established + notif.on_validation_result(peer, ValidationResult::Reject).await.unwrap(); + + match notif.peers.get(&peer) { + Some(PeerContext { + state: PeerState::Closed { pending_open }, + }) => { + assert_eq!(pending_open, &Some(SubstreamId::new())); + } + state => panic!("invalid state: {state:?}"), + } + + // finally the outbound substream registers, verify that `pending_open` is set to `None` + notif + .on_outbound_substream( + protocol, + None, + peer, + SubstreamId::new(), + Substream::new_mock( + PeerId::random(), + SubstreamId::from(0usize), + Box::new(DummySubstream::new()), + ), + ) + .await + .unwrap(); + + match notif.peers.get(&peer) { + Some(PeerContext { + state: PeerState::Closed { pending_open }, + }) => { + assert!(pending_open.is_none()); + } + state => panic!("invalid state: {state:?}"), + } } #[tokio::test] async fn inbound_accepted_outbound_fails_to_open() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let protocol = ProtocolName::from("/notif/1"); - let (mut notif, mut handle, sender, mut tx) = make_notification_protocol(); - let (peer, receiver) = register_peer(&mut notif, &mut tx).await; - - // register inbound substream and verify that the state is `Validating` - notif - .on_inbound_substream( - protocol.clone(), - None, - peer, - Substream::new_mock( - PeerId::random(), - SubstreamId::from(0usize), - Box::new(DummySubstream::new()), - ), - ) - .await - .unwrap(); - - match notif.peers.get(&peer) { - Some(PeerContext { - state: - PeerState::Validating { - direction: Direction::Inbound, - outbound: OutboundState::Closed { .. }, - inbound: InboundState::ReadingHandshake, - .. - }, - }) => {}, - state => panic!("invalid state: {state:?}"), - } - - // then negotiation event for the inbound handshake - notif - .on_handshake_event( - peer, - HandshakeEvent::Negotiated { - peer, - handshake: vec![1, 3, 3, 7], - substream: Substream::new_mock( - PeerId::random(), - SubstreamId::from(0usize), - Box::new(DummySubstream::new()), - ), - direction: protocol::notification::negotiation::Direction::Inbound, - }, - ) - .await; - - match notif.peers.get(&peer) { - Some(PeerContext { - state: - PeerState::Validating { - direction: Direction::Inbound, - outbound: OutboundState::Closed { .. }, - inbound: InboundState::Validating { .. }, - .. - }, - }) => {}, - state => panic!("invalid state: {state:?}"), - } - - // discard the validation event - assert!(tokio::time::timeout(Duration::from_secs(5), handle.next()).await.is_ok()); - - // before the validation event is registered, close the connection - drop(sender); - drop(receiver); - drop(tx); - - // then reject the inbound peer even though an outbound substream was already established - assert!(notif.on_validation_result(peer, ValidationResult::Accept).await.is_err()); - - match notif.peers.get(&peer) { - Some(PeerContext { state: PeerState::Closed { pending_open } }) => { - assert!(pending_open.is_none()); - }, - state => panic!("invalid state: {state:?}"), - } - - // verify that the user is not reported anything - match tokio::time::timeout(Duration::from_secs(1), handle.next()).await { - Err(_) => panic!("unexpected timeout"), - Ok(Some(NotificationEvent::NotificationStreamOpenFailure { peer: event_peer, error })) => { - assert_eq!(peer, event_peer); - assert_eq!(error, NotificationError::Rejected) - }, - _ => panic!("invalid event"), - } + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let protocol = ProtocolName::from("/notif/1"); + let (mut notif, mut handle, sender, mut tx) = make_notification_protocol(); + let (peer, receiver) = register_peer(&mut notif, &mut tx).await; + + // register inbound substream and verify that the state is `Validating` + notif + .on_inbound_substream( + protocol.clone(), + None, + peer, + Substream::new_mock( + PeerId::random(), + SubstreamId::from(0usize), + Box::new(DummySubstream::new()), + ), + ) + .await + .unwrap(); + + match notif.peers.get(&peer) { + Some(PeerContext { + state: + PeerState::Validating { + direction: Direction::Inbound, + outbound: OutboundState::Closed { .. }, + inbound: InboundState::ReadingHandshake, + .. + }, + }) => {} + state => panic!("invalid state: {state:?}"), + } + + // then negotiation event for the inbound handshake + notif + .on_handshake_event( + peer, + HandshakeEvent::Negotiated { + peer, + handshake: vec![1, 3, 3, 7], + substream: Substream::new_mock( + PeerId::random(), + SubstreamId::from(0usize), + Box::new(DummySubstream::new()), + ), + direction: protocol::notification::negotiation::Direction::Inbound, + }, + ) + .await; + + match notif.peers.get(&peer) { + Some(PeerContext { + state: + PeerState::Validating { + direction: Direction::Inbound, + outbound: OutboundState::Closed { .. }, + inbound: InboundState::Validating { .. }, + .. + }, + }) => {} + state => panic!("invalid state: {state:?}"), + } + + // discard the validation event + assert!(tokio::time::timeout(Duration::from_secs(5), handle.next()).await.is_ok()); + + // before the validation event is registered, close the connection + drop(sender); + drop(receiver); + drop(tx); + + // then reject the inbound peer even though an outbound substream was already established + assert!(notif.on_validation_result(peer, ValidationResult::Accept).await.is_err()); + + match notif.peers.get(&peer) { + Some(PeerContext { + state: PeerState::Closed { pending_open }, + }) => { + assert!(pending_open.is_none()); + } + state => panic!("invalid state: {state:?}"), + } + + // verify that the user is not reported anything + match tokio::time::timeout(Duration::from_secs(1), handle.next()).await { + Err(_) => panic!("unexpected timeout"), + Ok(Some(NotificationEvent::NotificationStreamOpenFailure { + peer: event_peer, + error, + })) => { + assert_eq!(peer, event_peer); + assert_eq!(error, NotificationError::Rejected) + } + _ => panic!("invalid event"), + } } #[tokio::test] async fn open_substream_on_closed_connection() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (mut notif, mut handle, sender, mut tx) = make_notification_protocol(); - let (peer, receiver) = register_peer(&mut notif, &mut tx).await; - - // before processing the open substream event, close the connection - drop(sender); - drop(receiver); - drop(tx); - - // open outbound substream - notif.on_open_substream(peer).await.unwrap(); - - match notif.peers.get(&peer) { - Some(PeerContext { state: PeerState::Closed { pending_open: None } }) => {}, - state => panic!("invalid state: {state:?}"), - } - - match tokio::time::timeout(Duration::from_secs(5), handle.next()) - .await - .expect("operation to succeed") - { - Some(NotificationEvent::NotificationStreamOpenFailure { error, .. }) => { - assert_eq!(error, NotificationError::NoConnection); - }, - event => panic!("invalid event received: {event:?}"), - } + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (mut notif, mut handle, sender, mut tx) = make_notification_protocol(); + let (peer, receiver) = register_peer(&mut notif, &mut tx).await; + + // before processing the open substream event, close the connection + drop(sender); + drop(receiver); + drop(tx); + + // open outbound substream + notif.on_open_substream(peer).await.unwrap(); + + match notif.peers.get(&peer) { + Some(PeerContext { + state: PeerState::Closed { pending_open: None }, + }) => {} + state => panic!("invalid state: {state:?}"), + } + + match tokio::time::timeout(Duration::from_secs(5), handle.next()) + .await + .expect("operation to succeed") + { + Some(NotificationEvent::NotificationStreamOpenFailure { error, .. }) => { + assert_eq!(error, NotificationError::NoConnection); + } + event => panic!("invalid event received: {event:?}"), + } } // `NotificationHandle` may have an inconsistent view of the peer state and connection to peer may @@ -764,67 +793,69 @@ async fn open_substream_on_closed_connection() { // verify that `NotificationProtocol` ignores stale disconnection requests #[tokio::test] async fn close_already_closed_connection() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (mut notif, mut handle, _, mut tx) = make_notification_protocol(); - let (peer, _) = register_peer(&mut notif, &mut tx).await; - - notif.peers.insert( - peer, - PeerContext { - state: PeerState::Validating { - protocol: ProtocolName::from("/notif/1"), - fallback: None, - direction: Direction::Inbound, - outbound: OutboundState::Open { - handshake: vec![1, 2, 3, 4], - outbound: Substream::new_mock( - PeerId::random(), - SubstreamId::from(0usize), - Box::new(MockSubstream::new()), - ), - }, - inbound: InboundState::SendingHandshake, - }, - }, - ); - notif - .on_handshake_event( - peer, - HandshakeEvent::Negotiated { - peer, - handshake: vec![1], - substream: Substream::new_mock( - PeerId::random(), - SubstreamId::from(0usize), - Box::new(MockSubstream::new()), - ), - direction: protocol::notification::negotiation::Direction::Inbound, - }, - ) - .await; - - match handle.next().await { - Some(NotificationEvent::NotificationStreamOpened { .. }) => {}, - _ => panic!("invalid event received"), - } - - // close the substream but don't poll the `NotificationHandle` - notif.shutdown_tx.send(peer).await.unwrap(); - - // close the connection using the handle - handle.close_substream(peer).await; - - // process the events - notif.next_event().await; - notif.next_event().await; - - match notif.peers.get(&peer) { - Some(PeerContext { state: PeerState::Closed { pending_open: None } }) => {}, - state => panic!("invalid state: {state:?}"), - } + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (mut notif, mut handle, _, mut tx) = make_notification_protocol(); + let (peer, _) = register_peer(&mut notif, &mut tx).await; + + notif.peers.insert( + peer, + PeerContext { + state: PeerState::Validating { + protocol: ProtocolName::from("/notif/1"), + fallback: None, + direction: Direction::Inbound, + outbound: OutboundState::Open { + handshake: vec![1, 2, 3, 4], + outbound: Substream::new_mock( + PeerId::random(), + SubstreamId::from(0usize), + Box::new(MockSubstream::new()), + ), + }, + inbound: InboundState::SendingHandshake, + }, + }, + ); + notif + .on_handshake_event( + peer, + HandshakeEvent::Negotiated { + peer, + handshake: vec![1], + substream: Substream::new_mock( + PeerId::random(), + SubstreamId::from(0usize), + Box::new(MockSubstream::new()), + ), + direction: protocol::notification::negotiation::Direction::Inbound, + }, + ) + .await; + + match handle.next().await { + Some(NotificationEvent::NotificationStreamOpened { .. }) => {} + _ => panic!("invalid event received"), + } + + // close the substream but don't poll the `NotificationHandle` + notif.shutdown_tx.send(peer).await.unwrap(); + + // close the connection using the handle + handle.close_substream(peer).await; + + // process the events + notif.next_event().await; + notif.next_event().await; + + match notif.peers.get(&peer) { + Some(PeerContext { + state: PeerState::Closed { pending_open: None }, + }) => {} + state => panic!("invalid state: {state:?}"), + } } /// Notification state was not reset correctly if the outbound substream failed to open after @@ -832,64 +863,69 @@ async fn close_already_closed_connection() { /// twice, once when the failure occurred and again when the connection was closed. #[tokio::test] async fn open_failure_reported_once() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (mut notif, mut handle, _, mut tx) = make_notification_protocol(); - let (peer, _) = register_peer(&mut notif, &mut tx).await; - - // move `peer` to state where the inbound substream has been negotiated - // and the local node has initiated an outbound substream - notif.peers.insert( - peer, - PeerContext { - state: PeerState::Validating { - protocol: ProtocolName::from("/notif/1"), - fallback: None, - direction: Direction::Inbound, - outbound: OutboundState::OutboundInitiated { - substream: SubstreamId::from(1337usize), - }, - inbound: InboundState::Open { - inbound: Substream::new_mock( - peer, - SubstreamId::from(0usize), - Box::new(DummySubstream::new()), - ), - }, - }, - }, - ); - notif.pending_outbound.insert(SubstreamId::from(1337usize), peer); - - notif - .on_substream_open_failure(SubstreamId::from(1337usize), Error::Unknown) - .await; - - match handle.next().await { - Some(NotificationEvent::NotificationStreamOpenFailure { peer: failed_peer, error }) => { - assert_eq!(failed_peer, peer); - assert_eq!(error, NotificationError::Rejected); - }, - _ => panic!("invalid event received"), - } - - match notif.peers.get(&peer) { - Some(PeerContext { state: PeerState::Closed { pending_open } }) => { - assert_eq!(pending_open, &Some(SubstreamId::from(1337usize))); - }, - state => panic!("invalid state for peer: {state:?}"), - } - - // connection to `peer` is closed - notif.on_connection_closed(peer).await.unwrap(); - - futures::future::poll_fn(|cx| match handle.poll_next_unpin(cx) { - Poll::Pending => Poll::Ready(()), - result => panic!("didn't expect event from channel, got {result:?}"), - }) - .await; + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (mut notif, mut handle, _, mut tx) = make_notification_protocol(); + let (peer, _) = register_peer(&mut notif, &mut tx).await; + + // move `peer` to state where the inbound substream has been negotiated + // and the local node has initiated an outbound substream + notif.peers.insert( + peer, + PeerContext { + state: PeerState::Validating { + protocol: ProtocolName::from("/notif/1"), + fallback: None, + direction: Direction::Inbound, + outbound: OutboundState::OutboundInitiated { + substream: SubstreamId::from(1337usize), + }, + inbound: InboundState::Open { + inbound: Substream::new_mock( + peer, + SubstreamId::from(0usize), + Box::new(DummySubstream::new()), + ), + }, + }, + }, + ); + notif.pending_outbound.insert(SubstreamId::from(1337usize), peer); + + notif + .on_substream_open_failure(SubstreamId::from(1337usize), Error::Unknown) + .await; + + match handle.next().await { + Some(NotificationEvent::NotificationStreamOpenFailure { + peer: failed_peer, + error, + }) => { + assert_eq!(failed_peer, peer); + assert_eq!(error, NotificationError::Rejected); + } + _ => panic!("invalid event received"), + } + + match notif.peers.get(&peer) { + Some(PeerContext { + state: PeerState::Closed { pending_open }, + }) => { + assert_eq!(pending_open, &Some(SubstreamId::from(1337usize))); + } + state => panic!("invalid state for peer: {state:?}"), + } + + // connection to `peer` is closed + notif.on_connection_closed(peer).await.unwrap(); + + futures::future::poll_fn(|cx| match handle.poll_next_unpin(cx) { + Poll::Pending => Poll::Ready(()), + result => panic!("didn't expect event from channel, got {result:?}"), + }) + .await; } // inboud substrem was received and it was sent to user for validation @@ -900,67 +936,70 @@ async fn open_failure_reported_once() { // verify that the new substream is rejected and that the peer state is set to `ValidationPending` #[tokio::test] async fn second_inbound_substream_rejected() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (mut notif, mut handle, _, mut tx) = make_notification_protocol(); - let (peer, _) = register_peer(&mut notif, &mut tx).await; - - // move peer state to `Validating` - let mut substream1 = MockSubstream::new(); - substream1.expect_poll_close().times(1).return_once(|_| Poll::Ready(Ok(()))); - - notif.peers.insert( - peer, - PeerContext { - state: PeerState::Validating { - protocol: ProtocolName::from("/notif/1"), - fallback: None, - direction: Direction::Inbound, - outbound: OutboundState::Closed, - inbound: InboundState::Validating { - inbound: Substream::new_mock( - peer, - SubstreamId::from(0usize), - Box::new(substream1), - ), - }, - }, - }, - ); - - // open a new inbound substream because validation took so long that `peer` decided - // to open a new substream - let mut substream2 = MockSubstream::new(); - substream2.expect_poll_close().times(1).return_once(|_| Poll::Ready(Ok(()))); - notif - .on_inbound_substream( - ProtocolName::from("/notif/1"), - None, - peer, - Substream::new_mock(peer, SubstreamId::from(0usize), Box::new(substream2)), - ) - .await - .unwrap(); - - // verify that peer is moved to `ValidationPending` - match notif.peers.get(&peer) { - Some(PeerContext { - state: PeerState::ValidationPending { state: ConnectionState::Open }, - }) => {}, - state => panic!("invalid state for peer: {state:?}"), - } - - // user decide to reject the substream, verify that nothing is received over the event handle - notif.on_validation_result(peer, ValidationResult::Reject).await.unwrap(); - - notif.on_connection_closed(peer).await.unwrap(); - futures::future::poll_fn(|cx| match handle.poll_next_unpin(cx) { - Poll::Pending => Poll::Ready(()), - result => panic!("didn't expect event from channel, got {result:?}"), - }) - .await; + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (mut notif, mut handle, _, mut tx) = make_notification_protocol(); + let (peer, _) = register_peer(&mut notif, &mut tx).await; + + // move peer state to `Validating` + let mut substream1 = MockSubstream::new(); + substream1.expect_poll_close().times(1).return_once(|_| Poll::Ready(Ok(()))); + + notif.peers.insert( + peer, + PeerContext { + state: PeerState::Validating { + protocol: ProtocolName::from("/notif/1"), + fallback: None, + direction: Direction::Inbound, + outbound: OutboundState::Closed, + inbound: InboundState::Validating { + inbound: Substream::new_mock( + peer, + SubstreamId::from(0usize), + Box::new(substream1), + ), + }, + }, + }, + ); + + // open a new inbound substream because validation took so long that `peer` decided + // to open a new substream + let mut substream2 = MockSubstream::new(); + substream2.expect_poll_close().times(1).return_once(|_| Poll::Ready(Ok(()))); + notif + .on_inbound_substream( + ProtocolName::from("/notif/1"), + None, + peer, + Substream::new_mock(peer, SubstreamId::from(0usize), Box::new(substream2)), + ) + .await + .unwrap(); + + // verify that peer is moved to `ValidationPending` + match notif.peers.get(&peer) { + Some(PeerContext { + state: + PeerState::ValidationPending { + state: ConnectionState::Open, + }, + }) => {} + state => panic!("invalid state for peer: {state:?}"), + } + + // user decide to reject the substream, verify that nothing is received over the event handle + notif.on_validation_result(peer, ValidationResult::Reject).await.unwrap(); + + notif.on_connection_closed(peer).await.unwrap(); + futures::future::poll_fn(|cx| match handle.poll_next_unpin(cx) { + Poll::Pending => Poll::Ready(()), + result => panic!("didn't expect event from channel, got {result:?}"), + }) + .await; } // remote opened a substream, it was accepted by the local node and local node opened an outbound @@ -971,92 +1010,102 @@ async fn second_inbound_substream_rejected() { // connection is still pending #[tokio::test] async fn second_inbound_substream_opened_while_outbound_substream_was_opening() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (mut notif, mut handle, _zz, mut tx) = make_notification_protocol(); - let (peer, _zz) = register_peer(&mut notif, &mut tx).await; - - // move peer state to `Validating` - let mut substream1 = MockSubstream::new(); - substream1 - .expect_poll_ready() - .times(1) - .return_once(|_| Poll::Ready(Err(Error::Unknown))); - - notif.peers.insert( - peer, - PeerContext { - state: PeerState::Validating { - protocol: ProtocolName::from("/notif/1"), - fallback: None, - direction: Direction::Inbound, - outbound: OutboundState::Closed, - inbound: InboundState::Validating { - inbound: Substream::new_mock( - peer, - SubstreamId::from(0usize), - Box::new(substream1), - ), - }, - }, - }, - ); - - // accept the inbound substream which is now closed - notif.on_validation_result(peer, ValidationResult::Accept).await.unwrap(); - - // verify that peer is sending handshake and that outbound substream is opening - let substream_id = match notif.peers.get(&peer) { - Some(PeerContext { - state: - PeerState::Validating { - fallback: None, - direction: Direction::Inbound, - outbound: OutboundState::OutboundInitiated { substream }, - inbound: InboundState::SendingHandshake, - .. - }, - }) => *substream, - state => panic!("invalid state for peer: {state:?}"), - }; - - // poll the protocol and send handshake over the inbound substream - notif.next_event().await; - - // verify that peer is closed - match notif.peers.get(&peer) { - Some(PeerContext { state: PeerState::Closed { pending_open: Some(pending_open) } }) => { - assert_eq!(substream_id, *pending_open); - }, - state => panic!("invalid state for peer: {state:?}"), - } - - match handle.next().await { - Some(NotificationEvent::NotificationStreamOpenFailure { .. }) => {}, - _ => panic!("invalid event received"), - } - - // remote open second inbound substream - let mut substream2 = MockSubstream::new(); - substream2.expect_poll_close().times(1).return_once(|_| Poll::Ready(Ok(()))); - - notif - .on_inbound_substream( - ProtocolName::from("/notif/1"), - None, - peer, - Substream::new_mock(peer, SubstreamId::from(0usize), Box::new(substream2)), - ) - .await - .unwrap(); - - // verify that peer is still closed - match notif.peers.get(&peer) { - Some(PeerContext { state: PeerState::Closed { pending_open: Some(pending_open) } }) => { - assert_eq!(substream_id, *pending_open); - }, - state => panic!("invalid state for peer: {state:?}"), - } + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (mut notif, mut handle, _zz, mut tx) = make_notification_protocol(); + let (peer, _zz) = register_peer(&mut notif, &mut tx).await; + + // move peer state to `Validating` + let mut substream1 = MockSubstream::new(); + substream1 + .expect_poll_ready() + .times(1) + .return_once(|_| Poll::Ready(Err(Error::Unknown))); + + notif.peers.insert( + peer, + PeerContext { + state: PeerState::Validating { + protocol: ProtocolName::from("/notif/1"), + fallback: None, + direction: Direction::Inbound, + outbound: OutboundState::Closed, + inbound: InboundState::Validating { + inbound: Substream::new_mock( + peer, + SubstreamId::from(0usize), + Box::new(substream1), + ), + }, + }, + }, + ); + + // accept the inbound substream which is now closed + notif.on_validation_result(peer, ValidationResult::Accept).await.unwrap(); + + // verify that peer is sending handshake and that outbound substream is opening + let substream_id = match notif.peers.get(&peer) { + Some(PeerContext { + state: + PeerState::Validating { + fallback: None, + direction: Direction::Inbound, + outbound: OutboundState::OutboundInitiated { substream }, + inbound: InboundState::SendingHandshake, + .. + }, + }) => *substream, + state => panic!("invalid state for peer: {state:?}"), + }; + + // poll the protocol and send handshake over the inbound substream + notif.next_event().await; + + // verify that peer is closed + match notif.peers.get(&peer) { + Some(PeerContext { + state: + PeerState::Closed { + pending_open: Some(pending_open), + }, + }) => { + assert_eq!(substream_id, *pending_open); + } + state => panic!("invalid state for peer: {state:?}"), + } + + match handle.next().await { + Some(NotificationEvent::NotificationStreamOpenFailure { .. }) => {} + _ => panic!("invalid event received"), + } + + // remote open second inbound substream + let mut substream2 = MockSubstream::new(); + substream2.expect_poll_close().times(1).return_once(|_| Poll::Ready(Ok(()))); + + notif + .on_inbound_substream( + ProtocolName::from("/notif/1"), + None, + peer, + Substream::new_mock(peer, SubstreamId::from(0usize), Box::new(substream2)), + ) + .await + .unwrap(); + + // verify that peer is still closed + match notif.peers.get(&peer) { + Some(PeerContext { + state: + PeerState::Closed { + pending_open: Some(pending_open), + }, + }) => { + assert_eq!(substream_id, *pending_open); + } + state => panic!("invalid state for peer: {state:?}"), + } } diff --git a/src/protocol/notification/tests/substream_validation.rs b/src/protocol/notification/tests/substream_validation.rs index cfd7e5f0..cf2d6bb8 100644 --- a/src/protocol/notification/tests/substream_validation.rs +++ b/src/protocol/notification/tests/substream_validation.rs @@ -19,22 +19,22 @@ // DEALINGS IN THE SOFTWARE. use crate::{ - error::{Error, SubstreamError}, - mock::substream::MockSubstream, - protocol::{ - connection::ConnectionHandle, - notification::{ - negotiation::HandshakeEvent, - tests::{add_peer, make_notification_protocol}, - types::{Direction, NotificationEvent, ValidationResult}, - InboundState, OutboundState, PeerContext, PeerState, - }, - InnerTransportEvent, ProtocolCommand, - }, - substream::Substream, - transport::Endpoint, - types::{protocol::ProtocolName, ConnectionId, SubstreamId}, - PeerId, + error::{Error, SubstreamError}, + mock::substream::MockSubstream, + protocol::{ + connection::ConnectionHandle, + notification::{ + negotiation::HandshakeEvent, + tests::{add_peer, make_notification_protocol}, + types::{Direction, NotificationEvent, ValidationResult}, + InboundState, OutboundState, PeerContext, PeerState, + }, + InnerTransportEvent, ProtocolCommand, + }, + substream::Substream, + transport::Endpoint, + types::{protocol::ProtocolName, ConnectionId, SubstreamId}, + PeerId, }; use bytes::BytesMut; @@ -46,390 +46,419 @@ use std::task::Poll; #[tokio::test] async fn non_existent_peer() { - let (mut notif, _handle, _sender, _) = make_notification_protocol(); + let (mut notif, _handle, _sender, _) = make_notification_protocol(); - if let Err(err) = notif.on_validation_result(PeerId::random(), ValidationResult::Accept).await { - assert!(std::matches!(err, Error::PeerDoesntExist(_))); - } + if let Err(err) = notif.on_validation_result(PeerId::random(), ValidationResult::Accept).await { + assert!(std::matches!(err, Error::PeerDoesntExist(_))); + } } #[tokio::test] async fn substream_accepted() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (mut notif, mut handle, _sender, tx) = make_notification_protocol(); - let (peer, _service, _receiver) = add_peer(); - let handshake = BytesMut::from(&b"hello"[..]); - let mut substream = MockSubstream::new(); - substream - .expect_poll_next() - .times(1) - .return_once(|_| Poll::Ready(Some(Ok(BytesMut::from(&b"hello"[..]))))); - substream.expect_poll_ready().times(1).return_once(|_| Poll::Ready(Ok(()))); - substream.expect_start_send().times(1).return_once(|_| Ok(())); - substream.expect_poll_flush().times(1).return_once(|_| Poll::Ready(Ok(()))); - - let (proto_tx, mut proto_rx) = channel(256); - tx.send(InnerTransportEvent::ConnectionEstablished { - peer, - endpoint: Endpoint::dialer(Multiaddr::empty(), ConnectionId::from(0usize)), - sender: ConnectionHandle::new(ConnectionId::from(0usize), proto_tx.clone()), - connection: ConnectionId::from(0usize), - }) - .await - .unwrap(); - - // connect peer and verify it's in closed state - notif.next_event().await; - - match notif.peers.get(&peer).unwrap().state { - PeerState::Closed { .. } => {}, - _ => panic!("invalid state for peer"), - } - - // open inbound substream and verify that peer state has changed to `Validating` - notif - .on_inbound_substream( - ProtocolName::from("/notif/1"), - None, - peer, - Substream::new_mock(PeerId::random(), SubstreamId::from(0usize), Box::new(substream)), - ) - .await - .unwrap(); - - match ¬if.peers.get(&peer).unwrap().state { - PeerState::Validating { - direction: Direction::Inbound, - protocol: _, - fallback: None, - inbound: InboundState::ReadingHandshake, - outbound: OutboundState::Closed, - } => {}, - state => panic!("invalid state for peer: {state:?}"), - } - - // get negotiation event - let (peer, event) = notif.negotiation.next().await.unwrap(); - notif.on_handshake_event(peer, event).await; - - // user protocol receives the protocol accepts it - assert_eq!( - handle.next().await.unwrap(), - NotificationEvent::ValidateSubstream { - protocol: ProtocolName::from("/notif/1"), - fallback: None, - peer, - handshake: handshake.into() - }, - ); - notif.on_validation_result(peer, ValidationResult::Accept).await.unwrap(); - - // poll negotiation to finish the handshake - let (peer, event) = notif.negotiation.next().await.unwrap(); - notif.on_handshake_event(peer, event).await; - - // protocol asks for outbound substream to be opened and its state is changed accordingly - let ProtocolCommand::OpenSubstream { protocol, substream_id, .. } = - proto_rx.recv().await.unwrap() - else { - panic!("invalid commnd received"); - }; - assert_eq!(protocol, ProtocolName::from("/notif/1")); - assert_eq!(substream_id, SubstreamId::from(0usize)); - - let expected = SubstreamId::from(0usize); - - match ¬if.peers.get(&peer).unwrap().state { - PeerState::Validating { - direction: Direction::Inbound, - protocol: _, - fallback: None, - inbound: InboundState::Open { .. }, - outbound: OutboundState::OutboundInitiated { substream }, - } => { - assert_eq!(substream, &expected); - }, - state => panic!("invalid state for peer: {state:?}"), - } + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (mut notif, mut handle, _sender, tx) = make_notification_protocol(); + let (peer, _service, _receiver) = add_peer(); + let handshake = BytesMut::from(&b"hello"[..]); + let mut substream = MockSubstream::new(); + substream + .expect_poll_next() + .times(1) + .return_once(|_| Poll::Ready(Some(Ok(BytesMut::from(&b"hello"[..]))))); + substream.expect_poll_ready().times(1).return_once(|_| Poll::Ready(Ok(()))); + substream.expect_start_send().times(1).return_once(|_| Ok(())); + substream.expect_poll_flush().times(1).return_once(|_| Poll::Ready(Ok(()))); + + let (proto_tx, mut proto_rx) = channel(256); + tx.send(InnerTransportEvent::ConnectionEstablished { + peer, + endpoint: Endpoint::dialer(Multiaddr::empty(), ConnectionId::from(0usize)), + sender: ConnectionHandle::new(ConnectionId::from(0usize), proto_tx.clone()), + connection: ConnectionId::from(0usize), + }) + .await + .unwrap(); + + // connect peer and verify it's in closed state + notif.next_event().await; + + match notif.peers.get(&peer).unwrap().state { + PeerState::Closed { .. } => {} + _ => panic!("invalid state for peer"), + } + + // open inbound substream and verify that peer state has changed to `Validating` + notif + .on_inbound_substream( + ProtocolName::from("/notif/1"), + None, + peer, + Substream::new_mock( + PeerId::random(), + SubstreamId::from(0usize), + Box::new(substream), + ), + ) + .await + .unwrap(); + + match ¬if.peers.get(&peer).unwrap().state { + PeerState::Validating { + direction: Direction::Inbound, + protocol: _, + fallback: None, + inbound: InboundState::ReadingHandshake, + outbound: OutboundState::Closed, + } => {} + state => panic!("invalid state for peer: {state:?}"), + } + + // get negotiation event + let (peer, event) = notif.negotiation.next().await.unwrap(); + notif.on_handshake_event(peer, event).await; + + // user protocol receives the protocol accepts it + assert_eq!( + handle.next().await.unwrap(), + NotificationEvent::ValidateSubstream { + protocol: ProtocolName::from("/notif/1"), + fallback: None, + peer, + handshake: handshake.into() + }, + ); + notif.on_validation_result(peer, ValidationResult::Accept).await.unwrap(); + + // poll negotiation to finish the handshake + let (peer, event) = notif.negotiation.next().await.unwrap(); + notif.on_handshake_event(peer, event).await; + + // protocol asks for outbound substream to be opened and its state is changed accordingly + let ProtocolCommand::OpenSubstream { + protocol, + substream_id, + .. + } = proto_rx.recv().await.unwrap() + else { + panic!("invalid commnd received"); + }; + assert_eq!(protocol, ProtocolName::from("/notif/1")); + assert_eq!(substream_id, SubstreamId::from(0usize)); + + let expected = SubstreamId::from(0usize); + + match ¬if.peers.get(&peer).unwrap().state { + PeerState::Validating { + direction: Direction::Inbound, + protocol: _, + fallback: None, + inbound: InboundState::Open { .. }, + outbound: OutboundState::OutboundInitiated { substream }, + } => { + assert_eq!(substream, &expected); + } + state => panic!("invalid state for peer: {state:?}"), + } } #[tokio::test] async fn substream_rejected() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (mut notif, mut handle, _sender, _tx) = make_notification_protocol(); - let (peer, _service, mut receiver) = add_peer(); - let handshake = BytesMut::from(&b"hello"[..]); - let mut substream = MockSubstream::new(); - substream - .expect_poll_next() - .times(1) - .return_once(|_| Poll::Ready(Some(Ok(BytesMut::from(&b"hello"[..]))))); - substream.expect_poll_close().times(1).return_once(|_| Poll::Ready(Ok(()))); - - // connect peer and verify it's in closed state - notif.on_connection_established(peer).await.unwrap(); - - match notif.peers.get(&peer).unwrap().state { - PeerState::Closed { .. } => {}, - _ => panic!("invalid state for peer"), - } - - // open inbound substream and verify that peer state has changed to `Validating` - notif - .on_inbound_substream( - ProtocolName::from("/notif/1"), - None, - peer, - Substream::new_mock(PeerId::random(), SubstreamId::from(0usize), Box::new(substream)), - ) - .await - .unwrap(); - - match ¬if.peers.get(&peer).unwrap().state { - PeerState::Validating { - direction: Direction::Inbound, - protocol: _, - fallback: None, - inbound: InboundState::ReadingHandshake, - outbound: OutboundState::Closed, - } => {}, - state => panic!("invalid state for peer: {state:?}"), - } - - // get negotiation event - let (peer, event) = notif.negotiation.next().await.unwrap(); - notif.on_handshake_event(peer, event).await; - - // user protocol receives the protocol accepts it - assert_eq!( - handle.next().await.unwrap(), - NotificationEvent::ValidateSubstream { - protocol: ProtocolName::from("/notif/1"), - fallback: None, - peer, - handshake: handshake.into() - }, - ); - notif.on_validation_result(peer, ValidationResult::Reject).await.unwrap(); - - // substream is rejected so no outbound substraem is opened and peer is converted to closed - // state - match ¬if.peers.get(&peer).unwrap().state { - PeerState::Closed { .. } => {}, - state => panic!("invalid state for peer: {state:?}"), - } - - assert!(receiver.try_recv().is_err()); + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (mut notif, mut handle, _sender, _tx) = make_notification_protocol(); + let (peer, _service, mut receiver) = add_peer(); + let handshake = BytesMut::from(&b"hello"[..]); + let mut substream = MockSubstream::new(); + substream + .expect_poll_next() + .times(1) + .return_once(|_| Poll::Ready(Some(Ok(BytesMut::from(&b"hello"[..]))))); + substream.expect_poll_close().times(1).return_once(|_| Poll::Ready(Ok(()))); + + // connect peer and verify it's in closed state + notif.on_connection_established(peer).await.unwrap(); + + match notif.peers.get(&peer).unwrap().state { + PeerState::Closed { .. } => {} + _ => panic!("invalid state for peer"), + } + + // open inbound substream and verify that peer state has changed to `Validating` + notif + .on_inbound_substream( + ProtocolName::from("/notif/1"), + None, + peer, + Substream::new_mock( + PeerId::random(), + SubstreamId::from(0usize), + Box::new(substream), + ), + ) + .await + .unwrap(); + + match ¬if.peers.get(&peer).unwrap().state { + PeerState::Validating { + direction: Direction::Inbound, + protocol: _, + fallback: None, + inbound: InboundState::ReadingHandshake, + outbound: OutboundState::Closed, + } => {} + state => panic!("invalid state for peer: {state:?}"), + } + + // get negotiation event + let (peer, event) = notif.negotiation.next().await.unwrap(); + notif.on_handshake_event(peer, event).await; + + // user protocol receives the protocol accepts it + assert_eq!( + handle.next().await.unwrap(), + NotificationEvent::ValidateSubstream { + protocol: ProtocolName::from("/notif/1"), + fallback: None, + peer, + handshake: handshake.into() + }, + ); + notif.on_validation_result(peer, ValidationResult::Reject).await.unwrap(); + + // substream is rejected so no outbound substraem is opened and peer is converted to closed + // state + match ¬if.peers.get(&peer).unwrap().state { + PeerState::Closed { .. } => {} + state => panic!("invalid state for peer: {state:?}"), + } + + assert!(receiver.try_recv().is_err()); } #[tokio::test] async fn accept_fails_due_to_closed_substream() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (mut notif, mut handle, _sender, tx) = make_notification_protocol(); - let (peer, _service, _receiver) = add_peer(); - let handshake = BytesMut::from(&b"hello"[..]); - let mut substream = MockSubstream::new(); - substream - .expect_poll_next() - .times(1) - .return_once(|_| Poll::Ready(Some(Ok(BytesMut::from(&b"hello"[..]))))); - substream - .expect_poll_ready() - .times(1) - .return_once(|_| Poll::Ready(Err(Error::SubstreamError(SubstreamError::ConnectionClosed)))); - - let (proto_tx, _proto_rx) = channel(256); - tx.send(InnerTransportEvent::ConnectionEstablished { - peer, - endpoint: Endpoint::dialer(Multiaddr::empty(), ConnectionId::from(0usize)), - sender: ConnectionHandle::new(ConnectionId::from(0usize), proto_tx), - connection: ConnectionId::from(0usize), - }) - .await - .unwrap(); - - // connect peer and verify it's in closed state - notif.next_event().await; - - match notif.peers.get(&peer).unwrap().state { - PeerState::Closed { .. } => {}, - _ => panic!("invalid state for peer"), - } - - // open inbound substream and verify that peer state has changed to `InboundOpen` - notif - .on_inbound_substream( - ProtocolName::from("/notif/1"), - None, - peer, - Substream::new_mock(PeerId::random(), SubstreamId::from(0usize), Box::new(substream)), - ) - .await - .unwrap(); - - match ¬if.peers.get(&peer).unwrap().state { - PeerState::Validating { - direction: Direction::Inbound, - protocol: _, - fallback: None, - inbound: InboundState::ReadingHandshake, - outbound: OutboundState::Closed, - } => {}, - state => panic!("invalid state for peer: {state:?}"), - } - - // get negotiation event - let (peer, event) = notif.negotiation.next().await.unwrap(); - notif.on_handshake_event(peer, event).await; - - // user protocol receives the protocol accepts it - assert_eq!( - handle.next().await.unwrap(), - NotificationEvent::ValidateSubstream { - protocol: ProtocolName::from("/notif/1"), - fallback: None, - peer, - handshake: handshake.into() - }, - ); - - notif.on_validation_result(peer, ValidationResult::Accept).await.unwrap(); - - // get negotiation event - let (event_peer, event) = notif.negotiation.next().await.unwrap(); - match &event { - HandshakeEvent::NegotiationError { peer, .. } => { - assert_eq!(*peer, event_peer); - }, - event => panic!("invalid event for peer: {event:?}"), - } - notif.on_handshake_event(peer, event).await; - - // TODO: check state + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (mut notif, mut handle, _sender, tx) = make_notification_protocol(); + let (peer, _service, _receiver) = add_peer(); + let handshake = BytesMut::from(&b"hello"[..]); + let mut substream = MockSubstream::new(); + substream + .expect_poll_next() + .times(1) + .return_once(|_| Poll::Ready(Some(Ok(BytesMut::from(&b"hello"[..]))))); + substream + .expect_poll_ready() + .times(1) + .return_once(|_| Poll::Ready(Err(Error::SubstreamError(SubstreamError::ConnectionClosed)))); + + let (proto_tx, _proto_rx) = channel(256); + tx.send(InnerTransportEvent::ConnectionEstablished { + peer, + endpoint: Endpoint::dialer(Multiaddr::empty(), ConnectionId::from(0usize)), + sender: ConnectionHandle::new(ConnectionId::from(0usize), proto_tx), + connection: ConnectionId::from(0usize), + }) + .await + .unwrap(); + + // connect peer and verify it's in closed state + notif.next_event().await; + + match notif.peers.get(&peer).unwrap().state { + PeerState::Closed { .. } => {} + _ => panic!("invalid state for peer"), + } + + // open inbound substream and verify that peer state has changed to `InboundOpen` + notif + .on_inbound_substream( + ProtocolName::from("/notif/1"), + None, + peer, + Substream::new_mock( + PeerId::random(), + SubstreamId::from(0usize), + Box::new(substream), + ), + ) + .await + .unwrap(); + + match ¬if.peers.get(&peer).unwrap().state { + PeerState::Validating { + direction: Direction::Inbound, + protocol: _, + fallback: None, + inbound: InboundState::ReadingHandshake, + outbound: OutboundState::Closed, + } => {} + state => panic!("invalid state for peer: {state:?}"), + } + + // get negotiation event + let (peer, event) = notif.negotiation.next().await.unwrap(); + notif.on_handshake_event(peer, event).await; + + // user protocol receives the protocol accepts it + assert_eq!( + handle.next().await.unwrap(), + NotificationEvent::ValidateSubstream { + protocol: ProtocolName::from("/notif/1"), + fallback: None, + peer, + handshake: handshake.into() + }, + ); + + notif.on_validation_result(peer, ValidationResult::Accept).await.unwrap(); + + // get negotiation event + let (event_peer, event) = notif.negotiation.next().await.unwrap(); + match &event { + HandshakeEvent::NegotiationError { peer, .. } => { + assert_eq!(*peer, event_peer); + } + event => panic!("invalid event for peer: {event:?}"), + } + notif.on_handshake_event(peer, event).await; + + // TODO: check state } #[tokio::test] async fn accept_fails_due_to_closed_connection() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (mut notif, mut handle, _sender, tx) = make_notification_protocol(); - let (peer, _service, _receiver) = add_peer(); - let handshake = BytesMut::from(&b"hello"[..]); - let mut substream = MockSubstream::new(); - substream - .expect_poll_next() - .times(1) - .return_once(|_| Poll::Ready(Some(Ok(BytesMut::from(&b"hello"[..]))))); - substream.expect_poll_close().times(1).return_once(|_| Poll::Ready(Ok(()))); - - let (proto_tx, proto_rx) = channel(256); - tx.send(InnerTransportEvent::ConnectionEstablished { - peer, - endpoint: Endpoint::dialer(Multiaddr::empty(), ConnectionId::from(0usize)), - sender: ConnectionHandle::new(ConnectionId::from(0usize), proto_tx), - connection: ConnectionId::from(0usize), - }) - .await - .unwrap(); - - // connect peer and verify it's in closed state - notif.next_event().await; - - match notif.peers.get(&peer).unwrap().state { - PeerState::Closed { .. } => {}, - _ => panic!("invalid state for peer"), - } - - // open inbound substream and verify that peer state has changed to `InboundOpen` - notif - .on_inbound_substream( - ProtocolName::from("/notif/1"), - None, - peer, - Substream::new_mock(PeerId::random(), SubstreamId::from(0usize), Box::new(substream)), - ) - .await - .unwrap(); - - match ¬if.peers.get(&peer).unwrap().state { - PeerState::Validating { - direction: Direction::Inbound, - protocol: _, - fallback: None, - inbound: InboundState::ReadingHandshake, - outbound: OutboundState::Closed, - } => {}, - state => panic!("invalid state for peer: {state:?}"), - } - - // get negotiation event - let (peer, event) = notif.negotiation.next().await.unwrap(); - notif.on_handshake_event(peer, event).await; - - // user protocol receives the protocol accepts it - assert_eq!( - handle.next().await.unwrap(), - NotificationEvent::ValidateSubstream { - protocol: ProtocolName::from("/notif/1"), - fallback: None, - peer, - handshake: handshake.into() - }, - ); - - // drop the connection and verify that the protocol doesn't make any outbound substream - // requests and instead marks the connection as closed - drop(proto_rx); - - assert!(notif.on_validation_result(peer, ValidationResult::Accept).await.is_err()); - - match ¬if.peers.get(&peer).unwrap().state { - PeerState::Closed { .. } => {}, - state => panic!("invalid state for peer: {state:?}"), - } + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (mut notif, mut handle, _sender, tx) = make_notification_protocol(); + let (peer, _service, _receiver) = add_peer(); + let handshake = BytesMut::from(&b"hello"[..]); + let mut substream = MockSubstream::new(); + substream + .expect_poll_next() + .times(1) + .return_once(|_| Poll::Ready(Some(Ok(BytesMut::from(&b"hello"[..]))))); + substream.expect_poll_close().times(1).return_once(|_| Poll::Ready(Ok(()))); + + let (proto_tx, proto_rx) = channel(256); + tx.send(InnerTransportEvent::ConnectionEstablished { + peer, + endpoint: Endpoint::dialer(Multiaddr::empty(), ConnectionId::from(0usize)), + sender: ConnectionHandle::new(ConnectionId::from(0usize), proto_tx), + connection: ConnectionId::from(0usize), + }) + .await + .unwrap(); + + // connect peer and verify it's in closed state + notif.next_event().await; + + match notif.peers.get(&peer).unwrap().state { + PeerState::Closed { .. } => {} + _ => panic!("invalid state for peer"), + } + + // open inbound substream and verify that peer state has changed to `InboundOpen` + notif + .on_inbound_substream( + ProtocolName::from("/notif/1"), + None, + peer, + Substream::new_mock( + PeerId::random(), + SubstreamId::from(0usize), + Box::new(substream), + ), + ) + .await + .unwrap(); + + match ¬if.peers.get(&peer).unwrap().state { + PeerState::Validating { + direction: Direction::Inbound, + protocol: _, + fallback: None, + inbound: InboundState::ReadingHandshake, + outbound: OutboundState::Closed, + } => {} + state => panic!("invalid state for peer: {state:?}"), + } + + // get negotiation event + let (peer, event) = notif.negotiation.next().await.unwrap(); + notif.on_handshake_event(peer, event).await; + + // user protocol receives the protocol accepts it + assert_eq!( + handle.next().await.unwrap(), + NotificationEvent::ValidateSubstream { + protocol: ProtocolName::from("/notif/1"), + fallback: None, + peer, + handshake: handshake.into() + }, + ); + + // drop the connection and verify that the protocol doesn't make any outbound substream + // requests and instead marks the connection as closed + drop(proto_rx); + + assert!(notif.on_validation_result(peer, ValidationResult::Accept).await.is_err()); + + match ¬if.peers.get(&peer).unwrap().state { + PeerState::Closed { .. } => {} + state => panic!("invalid state for peer: {state:?}"), + } } #[tokio::test] #[should_panic] #[cfg(debug_assertions)] async fn open_substream_accepted() { - use tokio::sync::oneshot; + use tokio::sync::oneshot; - let (mut notif, _handle, _sender, _tx) = make_notification_protocol(); - let (peer, _service, _receiver) = add_peer(); - let (shutdown, _rx) = oneshot::channel(); + let (mut notif, _handle, _sender, _tx) = make_notification_protocol(); + let (peer, _service, _receiver) = add_peer(); + let (shutdown, _rx) = oneshot::channel(); - notif.peers.insert(peer, PeerContext { state: PeerState::Open { shutdown } }); + notif.peers.insert( + peer, + PeerContext { + state: PeerState::Open { shutdown }, + }, + ); - // try to accept a closed substream - notif.on_close_substream(peer).await; + // try to accept a closed substream + notif.on_close_substream(peer).await; - assert!(notif.on_validation_result(peer, ValidationResult::Accept).await.is_err()); + assert!(notif.on_validation_result(peer, ValidationResult::Accept).await.is_err()); } #[tokio::test] #[should_panic] #[cfg(debug_assertions)] async fn open_substream_rejected() { - let (mut notif, _handle, _sender, _tx) = make_notification_protocol(); - let (peer, _service, _receiver) = add_peer(); - let (shutdown, _rx) = oneshot::channel(); + let (mut notif, _handle, _sender, _tx) = make_notification_protocol(); + let (peer, _service, _receiver) = add_peer(); + let (shutdown, _rx) = oneshot::channel(); - notif.peers.insert(peer, PeerContext { state: PeerState::Open { shutdown } }); + notif.peers.insert( + peer, + PeerContext { + state: PeerState::Open { shutdown }, + }, + ); - // try to reject a closed substream - notif.on_close_substream(peer).await; + // try to reject a closed substream + notif.on_close_substream(peer).await; - assert!(notif.on_validation_result(peer, ValidationResult::Reject).await.is_err()); + assert!(notif.on_validation_result(peer, ValidationResult::Reject).await.is_err()); } diff --git a/src/protocol/notification/types.rs b/src/protocol/notification/types.rs index 04c6c9d8..114751e8 100644 --- a/src/protocol/notification/types.rs +++ b/src/protocol/notification/types.rs @@ -19,7 +19,7 @@ // DEALINGS IN THE SOFTWARE. use crate::{ - protocol::notification::handle::NotificationSink, types::protocol::ProtocolName, PeerId, + protocol::notification::handle::NotificationSink, types::protocol::ProtocolName, PeerId, }; use bytes::BytesMut; @@ -36,185 +36,185 @@ pub(super) const ASYNC_CHANNEL_SIZE: usize = 8; /// Direction of the connection. #[derive(Debug, Copy, Clone, PartialEq, Eq)] pub enum Direction { - /// Connection is considered inbound, i.e., it was initiated by the remote node. - Inbound, + /// Connection is considered inbound, i.e., it was initiated by the remote node. + Inbound, - /// Connection is considered outbound, i.e., it was initiated by the local node. - Outbound, + /// Connection is considered outbound, i.e., it was initiated by the local node. + Outbound, } /// Validation result. #[derive(Debug, Copy, Clone, PartialEq, Eq)] pub enum ValidationResult { - /// Accept the inbound substream. - Accept, + /// Accept the inbound substream. + Accept, - /// Reject the inbound substream. - Reject, + /// Reject the inbound substream. + Reject, } /// Notification error. #[derive(Debug, Clone, PartialEq, Eq)] pub enum NotificationError { - /// Remote rejected the substream. - Rejected, + /// Remote rejected the substream. + Rejected, - /// Connection to peer doesn't exist. - NoConnection, + /// Connection to peer doesn't exist. + NoConnection, - /// Synchronous notification channel is clogged. - ChannelClogged, + /// Synchronous notification channel is clogged. + ChannelClogged, - /// Validation for a previous substream still pending. - ValidationPending, + /// Validation for a previous substream still pending. + ValidationPending, - /// Failed to dial peer. - DialFailure, + /// Failed to dial peer. + DialFailure, - /// Notification protocol has been closed. - EssentialTaskClosed, + /// Notification protocol has been closed. + EssentialTaskClosed, } /// Notification events. pub(crate) enum InnerNotificationEvent { - /// Validate substream. - ValidateSubstream { - /// Protocol name. - protocol: ProtocolName, + /// Validate substream. + ValidateSubstream { + /// Protocol name. + protocol: ProtocolName, - /// Fallback, if the substream was negotiated using a fallback protocol. - fallback: Option, + /// Fallback, if the substream was negotiated using a fallback protocol. + fallback: Option, - /// Peer ID. - peer: PeerId, + /// Peer ID. + peer: PeerId, - /// Handshake. - handshake: Vec, + /// Handshake. + handshake: Vec, - /// `oneshot::Sender` for sending the validation result back to the protocol. - tx: oneshot::Sender, - }, + /// `oneshot::Sender` for sending the validation result back to the protocol. + tx: oneshot::Sender, + }, - /// Notification stream opened. - NotificationStreamOpened { - /// Protocol name. - protocol: ProtocolName, + /// Notification stream opened. + NotificationStreamOpened { + /// Protocol name. + protocol: ProtocolName, - /// Fallback, if the substream was negotiated using a fallback protocol. - fallback: Option, + /// Fallback, if the substream was negotiated using a fallback protocol. + fallback: Option, - /// Direction of the substream. - direction: Direction, + /// Direction of the substream. + direction: Direction, - /// Peer ID. - peer: PeerId, + /// Peer ID. + peer: PeerId, - /// Handshake. - handshake: Vec, + /// Handshake. + handshake: Vec, - /// Notification sink. - sink: NotificationSink, - }, + /// Notification sink. + sink: NotificationSink, + }, - /// Notification stream closed. - NotificationStreamClosed { - /// Peer ID. - peer: PeerId, - }, + /// Notification stream closed. + NotificationStreamClosed { + /// Peer ID. + peer: PeerId, + }, - /// Failed to open notification stream. - NotificationStreamOpenFailure { - /// Peer ID. - peer: PeerId, + /// Failed to open notification stream. + NotificationStreamOpenFailure { + /// Peer ID. + peer: PeerId, - /// Error. - error: NotificationError, - }, + /// Error. + error: NotificationError, + }, } /// Notification events. #[derive(Debug, Clone, PartialEq, Eq)] pub enum NotificationEvent { - /// Validate substream. - ValidateSubstream { - /// Protocol name. - protocol: ProtocolName, - - /// Fallback, if the substream was negotiated using a fallback protocol. - fallback: Option, - - /// Peer ID. - peer: PeerId, - - /// Handshake. - handshake: Vec, - }, - - /// Notification stream opened. - NotificationStreamOpened { - /// Protocol name. - protocol: ProtocolName, - - /// Fallback, if the substream was negotiated using a fallback protocol. - fallback: Option, - - /// Direction of the substream. - /// - /// [`Direction::Inbound`](crate::protocol::Direction::Outbound) indicates that the - /// substream was opened by the remote peer and - /// [`Direction::Outbound`](crate::protocol::Direction::Outbound) that it was - /// opened by the local node. - direction: Direction, - - /// Peer ID. - peer: PeerId, - - /// Handshake. - handshake: Vec, - }, - - /// Notification stream closed. - NotificationStreamClosed { - /// Peer ID. - peer: PeerId, - }, - - /// Failed to open notification stream. - NotificationStreamOpenFailure { - /// Peer ID. - peer: PeerId, - - /// Error. - error: NotificationError, - }, - - /// Notification received. - NotificationReceived { - /// Peer ID. - peer: PeerId, - - /// Notification. - notification: BytesMut, - }, + /// Validate substream. + ValidateSubstream { + /// Protocol name. + protocol: ProtocolName, + + /// Fallback, if the substream was negotiated using a fallback protocol. + fallback: Option, + + /// Peer ID. + peer: PeerId, + + /// Handshake. + handshake: Vec, + }, + + /// Notification stream opened. + NotificationStreamOpened { + /// Protocol name. + protocol: ProtocolName, + + /// Fallback, if the substream was negotiated using a fallback protocol. + fallback: Option, + + /// Direction of the substream. + /// + /// [`Direction::Inbound`](crate::protocol::Direction::Outbound) indicates that the + /// substream was opened by the remote peer and + /// [`Direction::Outbound`](crate::protocol::Direction::Outbound) that it was + /// opened by the local node. + direction: Direction, + + /// Peer ID. + peer: PeerId, + + /// Handshake. + handshake: Vec, + }, + + /// Notification stream closed. + NotificationStreamClosed { + /// Peer ID. + peer: PeerId, + }, + + /// Failed to open notification stream. + NotificationStreamOpenFailure { + /// Peer ID. + peer: PeerId, + + /// Error. + error: NotificationError, + }, + + /// Notification received. + NotificationReceived { + /// Peer ID. + peer: PeerId, + + /// Notification. + notification: BytesMut, + }, } /// Notification commands sent to the protocol. pub(crate) enum NotificationCommand { - /// Open substreams to one or more peers. - OpenSubstream { - /// Peer IDs. - peers: HashSet, - }, - - /// Close substreams to one or more peers. - CloseSubstream { - /// Peer IDs. - peers: HashSet, - }, - - /// Force close the connection because notification channel is clogged. - ForceClose { - /// Peer to disconnect. - peer: PeerId, - }, + /// Open substreams to one or more peers. + OpenSubstream { + /// Peer IDs. + peers: HashSet, + }, + + /// Close substreams to one or more peers. + CloseSubstream { + /// Peer IDs. + peers: HashSet, + }, + + /// Force close the connection because notification channel is clogged. + ForceClose { + /// Peer to disconnect. + peer: PeerId, + }, } diff --git a/src/protocol/protocol_set.rs b/src/protocol/protocol_set.rs index d480d406..1988d95a 100644 --- a/src/protocol/protocol_set.rs +++ b/src/protocol/protocol_set.rs @@ -19,19 +19,19 @@ // DEALINGS IN THE SOFTWARE. use crate::{ - codec::ProtocolCodec, - error::Error, - protocol::{ - connection::{ConnectionHandle, Permit}, - Direction, TransportEvent, - }, - substream::Substream, - transport::{ - manager::{ProtocolContext, TransportManagerEvent}, - Endpoint, - }, - types::{protocol::ProtocolName, ConnectionId, SubstreamId}, - PeerId, + codec::ProtocolCodec, + error::Error, + protocol::{ + connection::{ConnectionHandle, Permit}, + Direction, TransportEvent, + }, + substream::Substream, + transport::{ + manager::{ProtocolContext, TransportManagerEvent}, + Endpoint, + }, + types::{protocol::ProtocolName, ConnectionId, SubstreamId}, + PeerId, }; use futures::{stream::FuturesUnordered, Stream, StreamExt}; @@ -39,14 +39,14 @@ use multiaddr::Multiaddr; use tokio::sync::mpsc::{channel, Receiver, Sender}; use std::{ - collections::HashMap, - fmt::Debug, - pin::Pin, - sync::{ - atomic::{AtomicUsize, Ordering}, - Arc, - }, - task::{Context, Poll}, + collections::HashMap, + fmt::Debug, + pin::Pin, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }, + task::{Context, Poll}, }; /// Logging target for the file. @@ -55,146 +55,152 @@ const LOG_TARGET: &str = "litep2p::protocol-set"; /// Events emitted by the underlying transport protocols. #[derive(Debug)] pub enum InnerTransportEvent { - /// Connection established to `peer`. - ConnectionEstablished { - /// Peer ID. - peer: PeerId, - - /// Connection ID. - connection: ConnectionId, - - /// Endpoint. - endpoint: Endpoint, - - /// Handle for communicating with the connection. - sender: ConnectionHandle, - }, - - /// Connection closed. - ConnectionClosed { - /// Peer ID. - peer: PeerId, - - /// Connection ID. - connection: ConnectionId, - }, - - /// Failed to dial peer. - /// - /// This is reported to that protocol which initiated the connection. - DialFailure { - /// Peer ID. - peer: PeerId, - - /// Dialed address. - address: Multiaddr, - }, - - /// Substream opened for `peer`. - SubstreamOpened { - /// Peer ID. - peer: PeerId, - - /// Protocol name. - /// - /// One protocol handler may handle multiple sub-protocols (such as `/ipfs/identify/1.0.0` - /// and `/ipfs/identify/push/1.0.0`) or it may have aliases which should be handled by - /// the same protocol handler. When the substream is sent from transport to the protocol - /// handler, the protocol name that was used to negotiate the substream is also sent so - /// the protocol can handle the substream appropriately. - protocol: ProtocolName, - - /// Fallback name. - /// - /// If the substream was negotiated using a fallback name of the main protocol, - /// `fallback` is `Some`. - fallback: Option, - - /// Substream direction. - /// - /// Informs the protocol whether the substream is inbound (opened by the remote node) - /// or outbound (opened by the local node). This allows the protocol to distinguish - /// between the two types of substreams and execute correct code for the substream. - /// - /// Outbound substreams also contain the substream ID which allows the protocol to - /// distinguish between different outbound substreams. - direction: Direction, - - /// Substream. - substream: Substream, - }, - - /// Failed to open substream. - /// - /// Substream open failures are reported only for outbound substreams. - SubstreamOpenFailure { - /// Substream ID. - substream: SubstreamId, - - /// Error that occurred when the substream was being opened. - error: Error, - }, + /// Connection established to `peer`. + ConnectionEstablished { + /// Peer ID. + peer: PeerId, + + /// Connection ID. + connection: ConnectionId, + + /// Endpoint. + endpoint: Endpoint, + + /// Handle for communicating with the connection. + sender: ConnectionHandle, + }, + + /// Connection closed. + ConnectionClosed { + /// Peer ID. + peer: PeerId, + + /// Connection ID. + connection: ConnectionId, + }, + + /// Failed to dial peer. + /// + /// This is reported to that protocol which initiated the connection. + DialFailure { + /// Peer ID. + peer: PeerId, + + /// Dialed address. + address: Multiaddr, + }, + + /// Substream opened for `peer`. + SubstreamOpened { + /// Peer ID. + peer: PeerId, + + /// Protocol name. + /// + /// One protocol handler may handle multiple sub-protocols (such as `/ipfs/identify/1.0.0` + /// and `/ipfs/identify/push/1.0.0`) or it may have aliases which should be handled by + /// the same protocol handler. When the substream is sent from transport to the protocol + /// handler, the protocol name that was used to negotiate the substream is also sent so + /// the protocol can handle the substream appropriately. + protocol: ProtocolName, + + /// Fallback name. + /// + /// If the substream was negotiated using a fallback name of the main protocol, + /// `fallback` is `Some`. + fallback: Option, + + /// Substream direction. + /// + /// Informs the protocol whether the substream is inbound (opened by the remote node) + /// or outbound (opened by the local node). This allows the protocol to distinguish + /// between the two types of substreams and execute correct code for the substream. + /// + /// Outbound substreams also contain the substream ID which allows the protocol to + /// distinguish between different outbound substreams. + direction: Direction, + + /// Substream. + substream: Substream, + }, + + /// Failed to open substream. + /// + /// Substream open failures are reported only for outbound substreams. + SubstreamOpenFailure { + /// Substream ID. + substream: SubstreamId, + + /// Error that occurred when the substream was being opened. + error: Error, + }, } impl From for TransportEvent { - fn from(event: InnerTransportEvent) -> Self { - match event { - InnerTransportEvent::DialFailure { peer, address } => - TransportEvent::DialFailure { peer, address }, - InnerTransportEvent::SubstreamOpened { - peer, - protocol, - fallback, - direction, - substream, - } => TransportEvent::SubstreamOpened { peer, protocol, fallback, direction, substream }, - InnerTransportEvent::SubstreamOpenFailure { substream, error } => - TransportEvent::SubstreamOpenFailure { substream, error }, - event => panic!("cannot convert {event:?}"), - } - } + fn from(event: InnerTransportEvent) -> Self { + match event { + InnerTransportEvent::DialFailure { peer, address } => + TransportEvent::DialFailure { peer, address }, + InnerTransportEvent::SubstreamOpened { + peer, + protocol, + fallback, + direction, + substream, + } => TransportEvent::SubstreamOpened { + peer, + protocol, + fallback, + direction, + substream, + }, + InnerTransportEvent::SubstreamOpenFailure { substream, error } => + TransportEvent::SubstreamOpenFailure { substream, error }, + event => panic!("cannot convert {event:?}"), + } + } } /// Events emitted by the installed protocols to transport. #[derive(Debug)] pub enum ProtocolCommand { - /// Open substream. - OpenSubstream { - /// Protocol name. - protocol: ProtocolName, - - /// Fallback names. - /// - /// If the protocol has changed its name but wishes to suppor the old name(s), it must - /// provide the old protocol names in `fallback_names`. These are fed into - /// `multistream-select` which them attempts to negotiate a protocol for the substream - /// using one of the provided names and if the substream is negotiated successfully, will - /// report back the actual protocol name that was negotiated, in case the protocol - /// needs to deal with the old version of the protocol in different way compared to - /// the new version. - fallback_names: Vec, - - /// Substream ID. - /// - /// Protocol allocates an ephemeral ID for outbound substreams which allows it to track - /// the state of its pending substream. The ID is given back to protocol in - /// [`TransportEvent::SubstreamOpened`]/[`TransportEvent::SubstreamOpenFailure`]. - /// - /// This allows the protocol to distinguish inbound substreams from outbound substreams - /// and associate incoming substreams with whatever logic it has. - substream_id: SubstreamId, - - /// Connection permit. - /// - /// `Permit` allows the connection to be kept open while the permit is held and it is given - /// to the substream to hold once it has been opened. When the substream is dropped, the - /// permit is dropped and the connection may be closed if no other permit is being - /// held. - permit: Permit, - }, - - /// Forcibly close the connection, even if other protocols have substreams open over it. - ForceClose, + /// Open substream. + OpenSubstream { + /// Protocol name. + protocol: ProtocolName, + + /// Fallback names. + /// + /// If the protocol has changed its name but wishes to suppor the old name(s), it must + /// provide the old protocol names in `fallback_names`. These are fed into + /// `multistream-select` which them attempts to negotiate a protocol for the substream + /// using one of the provided names and if the substream is negotiated successfully, will + /// report back the actual protocol name that was negotiated, in case the protocol + /// needs to deal with the old version of the protocol in different way compared to + /// the new version. + fallback_names: Vec, + + /// Substream ID. + /// + /// Protocol allocates an ephemeral ID for outbound substreams which allows it to track + /// the state of its pending substream. The ID is given back to protocol in + /// [`TransportEvent::SubstreamOpened`]/[`TransportEvent::SubstreamOpenFailure`]. + /// + /// This allows the protocol to distinguish inbound substreams from outbound substreams + /// and associate incoming substreams with whatever logic it has. + substream_id: SubstreamId, + + /// Connection permit. + /// + /// `Permit` allows the connection to be kept open while the permit is held and it is given + /// to the substream to hold once it has been opened. When the substream is dropped, the + /// permit is dropped and the connection may be closed if no other permit is being + /// held. + permit: Permit, + }, + + /// Forcibly close the connection, even if other protocols have substreams open over it. + ForceClose, } /// Supported protocol information. @@ -202,344 +208,354 @@ pub enum ProtocolCommand { /// Each connection gets a copy of [`ProtocolSet`] which allows it to interact /// directly with installed protocols. pub struct ProtocolSet { - /// Installed protocols. - pub(crate) protocols: HashMap, - mgr_tx: Sender, - connection: ConnectionHandle, - rx: Receiver, - next_substream_id: Arc, - fallback_names: HashMap, + /// Installed protocols. + pub(crate) protocols: HashMap, + mgr_tx: Sender, + connection: ConnectionHandle, + rx: Receiver, + next_substream_id: Arc, + fallback_names: HashMap, } impl ProtocolSet { - pub fn new( - connection_id: ConnectionId, - mgr_tx: Sender, - next_substream_id: Arc, - protocols: HashMap, - ) -> Self { - let (tx, rx) = channel(256); - - let fallback_names = protocols - .iter() - .map(|(protocol, context)| { - context - .fallback_names - .iter() - .map(|fallback| (fallback.clone(), protocol.clone())) - .collect::>() - }) - .flatten() - .collect(); - - ProtocolSet { - rx, - mgr_tx, - protocols, - next_substream_id, - fallback_names, - connection: ConnectionHandle::new(connection_id, tx), - } - } - - /// Try to acquire permit to keep the connection open. - pub fn try_get_permit(&mut self) -> Option { - self.connection.try_get_permit() - } - - /// Get next substream ID. - pub fn next_substream_id(&self) -> SubstreamId { - SubstreamId::from(self.next_substream_id.fetch_add(1usize, Ordering::Relaxed)) - } - - /// Get the list of all supported protocols. - pub fn protocols(&self) -> Vec { - self.protocols - .keys() - .cloned() - .chain(self.fallback_names.keys().cloned()) - .collect() - } - - /// Report to `protocol` that substream was opened for `peer`. - pub async fn report_substream_open( - &mut self, - peer: PeerId, - protocol: ProtocolName, - direction: Direction, - substream: Substream, - ) -> crate::Result<()> { - tracing::debug!(target: LOG_TARGET, %protocol, ?peer, ?direction, "substream opened"); - - let (protocol, fallback) = match self.fallback_names.get(&protocol) { - Some(main_protocol) => (main_protocol.clone(), Some(protocol)), - None => (protocol, None), - }; - - self.protocols - .get_mut(&protocol) - .ok_or(Error::ProtocolNotSupported(protocol.to_string()))? - .tx - .send(InnerTransportEvent::SubstreamOpened { - peer, - protocol: protocol.clone(), - fallback, - direction, - substream, - }) - .await - .map_err(From::from) - } - - /// Get codec used by the protocol. - pub fn protocol_codec(&self, protocol: &ProtocolName) -> ProtocolCodec { - // NOTE: `protocol` must exist in `self.protocol` as it was negotiated - // using the protocols from this set - self.protocols - .get(self.fallback_names.get(&protocol).map_or(protocol, |protocol| protocol)) - .expect("protocol to exist") - .codec - } - - /// Report to `protocol` that connection failed to open substream for `peer`. - pub async fn report_substream_open_failure( - &mut self, - protocol: ProtocolName, - substream: SubstreamId, - error: Error, - ) -> crate::Result<()> { - tracing::debug!( - target: LOG_TARGET, - %protocol, - ?substream, - ?error, - "failed to open substream", - ); - - self.protocols - .get_mut(&protocol) - .ok_or(Error::ProtocolNotSupported(protocol.to_string()))? - .tx - .send(InnerTransportEvent::SubstreamOpenFailure { substream, error }) - .await - .map_err(From::from) - } - - /// Report to protocols that a connection was established. - pub(crate) async fn report_connection_established( - &mut self, - peer: PeerId, - endpoint: Endpoint, - ) -> crate::Result<()> { - let connection_handle = self.connection.downgrade(); - let mut futures = self - .protocols - .iter() - .map(|(_, sender)| { - let endpoint = endpoint.clone(); - let connection_handle = connection_handle.clone(); - - async move { - sender - .tx - .send(InnerTransportEvent::ConnectionEstablished { - peer, - connection: endpoint.connection_id(), - endpoint, - sender: connection_handle, - }) - .await - } - }) - .collect::>(); - - while !futures.is_empty() { - if let Some(Err(error)) = futures.next().await { - return Err(error.into()); - } - } - - Ok(()) - } - - /// Report to protocols that a connection was closed. - pub(crate) async fn report_connection_closed( - &mut self, - peer: PeerId, - connection_id: ConnectionId, - ) -> crate::Result<()> { - let mut futures = self - .protocols - .iter() - .map(|(_, sender)| async move { - sender - .tx - .send(InnerTransportEvent::ConnectionClosed { peer, connection: connection_id }) - .await - }) - .collect::>(); - - while !futures.is_empty() { - if let Some(Err(error)) = futures.next().await { - return Err(error.into()); - } - } - - self.mgr_tx - .send(TransportManagerEvent::ConnectionClosed { peer, connection: connection_id }) - .await - .map_err(From::from) - } + pub fn new( + connection_id: ConnectionId, + mgr_tx: Sender, + next_substream_id: Arc, + protocols: HashMap, + ) -> Self { + let (tx, rx) = channel(256); + + let fallback_names = protocols + .iter() + .map(|(protocol, context)| { + context + .fallback_names + .iter() + .map(|fallback| (fallback.clone(), protocol.clone())) + .collect::>() + }) + .flatten() + .collect(); + + ProtocolSet { + rx, + mgr_tx, + protocols, + next_substream_id, + fallback_names, + connection: ConnectionHandle::new(connection_id, tx), + } + } + + /// Try to acquire permit to keep the connection open. + pub fn try_get_permit(&mut self) -> Option { + self.connection.try_get_permit() + } + + /// Get next substream ID. + pub fn next_substream_id(&self) -> SubstreamId { + SubstreamId::from(self.next_substream_id.fetch_add(1usize, Ordering::Relaxed)) + } + + /// Get the list of all supported protocols. + pub fn protocols(&self) -> Vec { + self.protocols + .keys() + .cloned() + .chain(self.fallback_names.keys().cloned()) + .collect() + } + + /// Report to `protocol` that substream was opened for `peer`. + pub async fn report_substream_open( + &mut self, + peer: PeerId, + protocol: ProtocolName, + direction: Direction, + substream: Substream, + ) -> crate::Result<()> { + tracing::debug!(target: LOG_TARGET, %protocol, ?peer, ?direction, "substream opened"); + + let (protocol, fallback) = match self.fallback_names.get(&protocol) { + Some(main_protocol) => (main_protocol.clone(), Some(protocol)), + None => (protocol, None), + }; + + self.protocols + .get_mut(&protocol) + .ok_or(Error::ProtocolNotSupported(protocol.to_string()))? + .tx + .send(InnerTransportEvent::SubstreamOpened { + peer, + protocol: protocol.clone(), + fallback, + direction, + substream, + }) + .await + .map_err(From::from) + } + + /// Get codec used by the protocol. + pub fn protocol_codec(&self, protocol: &ProtocolName) -> ProtocolCodec { + // NOTE: `protocol` must exist in `self.protocol` as it was negotiated + // using the protocols from this set + self.protocols + .get(self.fallback_names.get(&protocol).map_or(protocol, |protocol| protocol)) + .expect("protocol to exist") + .codec + } + + /// Report to `protocol` that connection failed to open substream for `peer`. + pub async fn report_substream_open_failure( + &mut self, + protocol: ProtocolName, + substream: SubstreamId, + error: Error, + ) -> crate::Result<()> { + tracing::debug!( + target: LOG_TARGET, + %protocol, + ?substream, + ?error, + "failed to open substream", + ); + + self.protocols + .get_mut(&protocol) + .ok_or(Error::ProtocolNotSupported(protocol.to_string()))? + .tx + .send(InnerTransportEvent::SubstreamOpenFailure { substream, error }) + .await + .map_err(From::from) + } + + /// Report to protocols that a connection was established. + pub(crate) async fn report_connection_established( + &mut self, + peer: PeerId, + endpoint: Endpoint, + ) -> crate::Result<()> { + let connection_handle = self.connection.downgrade(); + let mut futures = self + .protocols + .iter() + .map(|(_, sender)| { + let endpoint = endpoint.clone(); + let connection_handle = connection_handle.clone(); + + async move { + sender + .tx + .send(InnerTransportEvent::ConnectionEstablished { + peer, + connection: endpoint.connection_id(), + endpoint, + sender: connection_handle, + }) + .await + } + }) + .collect::>(); + + while !futures.is_empty() { + if let Some(Err(error)) = futures.next().await { + return Err(error.into()); + } + } + + Ok(()) + } + + /// Report to protocols that a connection was closed. + pub(crate) async fn report_connection_closed( + &mut self, + peer: PeerId, + connection_id: ConnectionId, + ) -> crate::Result<()> { + let mut futures = self + .protocols + .iter() + .map(|(_, sender)| async move { + sender + .tx + .send(InnerTransportEvent::ConnectionClosed { + peer, + connection: connection_id, + }) + .await + }) + .collect::>(); + + while !futures.is_empty() { + if let Some(Err(error)) = futures.next().await { + return Err(error.into()); + } + } + + self.mgr_tx + .send(TransportManagerEvent::ConnectionClosed { + peer, + connection: connection_id, + }) + .await + .map_err(From::from) + } } impl Stream for ProtocolSet { - type Item = ProtocolCommand; + type Item = ProtocolCommand; - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.rx.poll_recv(cx) - } + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.rx.poll_recv(cx) + } } #[cfg(test)] mod tests { - use super::*; - use crate::mock::substream::MockSubstream; - use std::collections::HashSet; - - #[tokio::test] - async fn fallback_is_provided() { - let (tx, _rx) = channel(64); - let (tx1, _rx1) = channel(64); - - let mut protocol_set = ProtocolSet::new( - ConnectionId::from(0usize), - tx, - Default::default(), - HashMap::from_iter([( - ProtocolName::from("/notif/1"), - ProtocolContext { - tx: tx1, - codec: ProtocolCodec::Identity(32), - fallback_names: vec![ - ProtocolName::from("/notif/1/fallback/1"), - ProtocolName::from("/notif/1/fallback/2"), - ], - }, - )]), - ); - - let expected_protocols = HashSet::from([ - ProtocolName::from("/notif/1"), - ProtocolName::from("/notif/1/fallback/1"), - ProtocolName::from("/notif/1/fallback/2"), - ]); - - for protocol in protocol_set.protocols().iter() { - assert!(expected_protocols.contains(protocol)); - } - - protocol_set - .report_substream_open( - PeerId::random(), - ProtocolName::from("/notif/1/fallback/2"), - Direction::Inbound, - Substream::new_mock( - PeerId::random(), - SubstreamId::from(0usize), - Box::new(MockSubstream::new()), - ), - ) - .await - .unwrap(); - } - - #[tokio::test] - async fn main_protocol_reported_if_main_protocol_negotiated() { - let (tx, _rx) = channel(64); - let (tx1, mut rx1) = channel(64); - - let mut protocol_set = ProtocolSet::new( - ConnectionId::from(0usize), - tx, - Default::default(), - HashMap::from_iter([( - ProtocolName::from("/notif/1"), - ProtocolContext { - tx: tx1, - codec: ProtocolCodec::Identity(32), - fallback_names: vec![ - ProtocolName::from("/notif/1/fallback/1"), - ProtocolName::from("/notif/1/fallback/2"), - ], - }, - )]), - ); - - protocol_set - .report_substream_open( - PeerId::random(), - ProtocolName::from("/notif/1"), - Direction::Inbound, - Substream::new_mock( - PeerId::random(), - SubstreamId::from(0usize), - Box::new(MockSubstream::new()), - ), - ) - .await - .unwrap(); - - match rx1.recv().await.unwrap() { - InnerTransportEvent::SubstreamOpened { protocol, fallback, .. } => { - assert!(fallback.is_none()); - assert_eq!(protocol, ProtocolName::from("/notif/1")); - }, - _ => panic!("invalid event received"), - } - } - - #[tokio::test] - async fn fallback_is_reported_to_protocol() { - let (tx, _rx) = channel(64); - let (tx1, mut rx1) = channel(64); - - let mut protocol_set = ProtocolSet::new( - ConnectionId::from(0usize), - tx, - Default::default(), - HashMap::from_iter([( - ProtocolName::from("/notif/1"), - ProtocolContext { - tx: tx1, - codec: ProtocolCodec::Identity(32), - fallback_names: vec![ - ProtocolName::from("/notif/1/fallback/1"), - ProtocolName::from("/notif/1/fallback/2"), - ], - }, - )]), - ); - - protocol_set - .report_substream_open( - PeerId::random(), - ProtocolName::from("/notif/1/fallback/2"), - Direction::Inbound, - Substream::new_mock( - PeerId::random(), - SubstreamId::from(0usize), - Box::new(MockSubstream::new()), - ), - ) - .await - .unwrap(); - - match rx1.recv().await.unwrap() { - InnerTransportEvent::SubstreamOpened { protocol, fallback, .. } => { - assert_eq!(fallback, Some(ProtocolName::from("/notif/1/fallback/2"))); - assert_eq!(protocol, ProtocolName::from("/notif/1")); - }, - _ => panic!("invalid event received"), - } - } + use super::*; + use crate::mock::substream::MockSubstream; + use std::collections::HashSet; + + #[tokio::test] + async fn fallback_is_provided() { + let (tx, _rx) = channel(64); + let (tx1, _rx1) = channel(64); + + let mut protocol_set = ProtocolSet::new( + ConnectionId::from(0usize), + tx, + Default::default(), + HashMap::from_iter([( + ProtocolName::from("/notif/1"), + ProtocolContext { + tx: tx1, + codec: ProtocolCodec::Identity(32), + fallback_names: vec![ + ProtocolName::from("/notif/1/fallback/1"), + ProtocolName::from("/notif/1/fallback/2"), + ], + }, + )]), + ); + + let expected_protocols = HashSet::from([ + ProtocolName::from("/notif/1"), + ProtocolName::from("/notif/1/fallback/1"), + ProtocolName::from("/notif/1/fallback/2"), + ]); + + for protocol in protocol_set.protocols().iter() { + assert!(expected_protocols.contains(protocol)); + } + + protocol_set + .report_substream_open( + PeerId::random(), + ProtocolName::from("/notif/1/fallback/2"), + Direction::Inbound, + Substream::new_mock( + PeerId::random(), + SubstreamId::from(0usize), + Box::new(MockSubstream::new()), + ), + ) + .await + .unwrap(); + } + + #[tokio::test] + async fn main_protocol_reported_if_main_protocol_negotiated() { + let (tx, _rx) = channel(64); + let (tx1, mut rx1) = channel(64); + + let mut protocol_set = ProtocolSet::new( + ConnectionId::from(0usize), + tx, + Default::default(), + HashMap::from_iter([( + ProtocolName::from("/notif/1"), + ProtocolContext { + tx: tx1, + codec: ProtocolCodec::Identity(32), + fallback_names: vec![ + ProtocolName::from("/notif/1/fallback/1"), + ProtocolName::from("/notif/1/fallback/2"), + ], + }, + )]), + ); + + protocol_set + .report_substream_open( + PeerId::random(), + ProtocolName::from("/notif/1"), + Direction::Inbound, + Substream::new_mock( + PeerId::random(), + SubstreamId::from(0usize), + Box::new(MockSubstream::new()), + ), + ) + .await + .unwrap(); + + match rx1.recv().await.unwrap() { + InnerTransportEvent::SubstreamOpened { + protocol, fallback, .. + } => { + assert!(fallback.is_none()); + assert_eq!(protocol, ProtocolName::from("/notif/1")); + } + _ => panic!("invalid event received"), + } + } + + #[tokio::test] + async fn fallback_is_reported_to_protocol() { + let (tx, _rx) = channel(64); + let (tx1, mut rx1) = channel(64); + + let mut protocol_set = ProtocolSet::new( + ConnectionId::from(0usize), + tx, + Default::default(), + HashMap::from_iter([( + ProtocolName::from("/notif/1"), + ProtocolContext { + tx: tx1, + codec: ProtocolCodec::Identity(32), + fallback_names: vec![ + ProtocolName::from("/notif/1/fallback/1"), + ProtocolName::from("/notif/1/fallback/2"), + ], + }, + )]), + ); + + protocol_set + .report_substream_open( + PeerId::random(), + ProtocolName::from("/notif/1/fallback/2"), + Direction::Inbound, + Substream::new_mock( + PeerId::random(), + SubstreamId::from(0usize), + Box::new(MockSubstream::new()), + ), + ) + .await + .unwrap(); + + match rx1.recv().await.unwrap() { + InnerTransportEvent::SubstreamOpened { + protocol, fallback, .. + } => { + assert_eq!(fallback, Some(ProtocolName::from("/notif/1/fallback/2"))); + assert_eq!(protocol, ProtocolName::from("/notif/1")); + } + _ => panic!("invalid event received"), + } + } } diff --git a/src/protocol/request_response/config.rs b/src/protocol/request_response/config.rs index ca02ca7e..a44b1238 100644 --- a/src/protocol/request_response/config.rs +++ b/src/protocol/request_response/config.rs @@ -19,153 +19,153 @@ // DEALINGS IN THE SOFTWARE. use crate::{ - codec::ProtocolCodec, - protocol::request_response::{ - handle::{InnerRequestResponseEvent, RequestResponseCommand, RequestResponseHandle}, - REQUEST_TIMEOUT, - }, - types::protocol::ProtocolName, - DEFAULT_CHANNEL_SIZE, + codec::ProtocolCodec, + protocol::request_response::{ + handle::{InnerRequestResponseEvent, RequestResponseCommand, RequestResponseHandle}, + REQUEST_TIMEOUT, + }, + types::protocol::ProtocolName, + DEFAULT_CHANNEL_SIZE, }; use tokio::sync::mpsc::{channel, Receiver, Sender}; use std::{ - sync::{atomic::AtomicUsize, Arc}, - time::Duration, + sync::{atomic::AtomicUsize, Arc}, + time::Duration, }; /// Request-response protocol configuration. pub struct Config { - /// Protocol name. - pub(crate) protocol_name: ProtocolName, + /// Protocol name. + pub(crate) protocol_name: ProtocolName, - /// Fallback names for the main protocol name. - pub(crate) fallback_names: Vec, + /// Fallback names for the main protocol name. + pub(crate) fallback_names: Vec, - /// Timeout for outbound requests. - pub(crate) timeout: Duration, + /// Timeout for outbound requests. + pub(crate) timeout: Duration, - /// Codec used by the protocol. - pub(crate) codec: ProtocolCodec, + /// Codec used by the protocol. + pub(crate) codec: ProtocolCodec, - /// TX channel for sending events to the user protocol. - pub(super) event_tx: Sender, + /// TX channel for sending events to the user protocol. + pub(super) event_tx: Sender, - /// RX channel for receiving commands from the user protocol. - pub(crate) command_rx: Receiver, + /// RX channel for receiving commands from the user protocol. + pub(crate) command_rx: Receiver, - /// Next ephemeral request ID. - pub(crate) next_request_id: Arc, + /// Next ephemeral request ID. + pub(crate) next_request_id: Arc, - /// Maximum number of concurrent inbound requests. - pub(crate) max_concurrent_inbound_request: Option, + /// Maximum number of concurrent inbound requests. + pub(crate) max_concurrent_inbound_request: Option, } impl Config { - /// Create new [`Config`]. - pub fn new( - protocol_name: ProtocolName, - fallback_names: Vec, - max_message_size: usize, - timeout: Duration, - max_concurrent_inbound_request: Option, - ) -> (Self, RequestResponseHandle) { - let (event_tx, event_rx) = channel(DEFAULT_CHANNEL_SIZE); - let (command_tx, command_rx) = channel(DEFAULT_CHANNEL_SIZE); - let next_request_id = Default::default(); - let handle = RequestResponseHandle::new(event_rx, command_tx, Arc::clone(&next_request_id)); - - ( - Self { - event_tx, - command_rx, - protocol_name, - fallback_names, - next_request_id, - timeout, - max_concurrent_inbound_request, - codec: ProtocolCodec::UnsignedVarint(Some(max_message_size)), - }, - handle, - ) - } - - /// Get protocol name. - pub(crate) fn protocol_name(&self) -> &ProtocolName { - &self.protocol_name - } + /// Create new [`Config`]. + pub fn new( + protocol_name: ProtocolName, + fallback_names: Vec, + max_message_size: usize, + timeout: Duration, + max_concurrent_inbound_request: Option, + ) -> (Self, RequestResponseHandle) { + let (event_tx, event_rx) = channel(DEFAULT_CHANNEL_SIZE); + let (command_tx, command_rx) = channel(DEFAULT_CHANNEL_SIZE); + let next_request_id = Default::default(); + let handle = RequestResponseHandle::new(event_rx, command_tx, Arc::clone(&next_request_id)); + + ( + Self { + event_tx, + command_rx, + protocol_name, + fallback_names, + next_request_id, + timeout, + max_concurrent_inbound_request, + codec: ProtocolCodec::UnsignedVarint(Some(max_message_size)), + }, + handle, + ) + } + + /// Get protocol name. + pub(crate) fn protocol_name(&self) -> &ProtocolName { + &self.protocol_name + } } /// Builder for [`Config`]. pub struct ConfigBuilder { - /// Protocol name. - pub(crate) protocol_name: ProtocolName, + /// Protocol name. + pub(crate) protocol_name: ProtocolName, - /// Fallback names for the main protocol name. - pub(crate) fallback_names: Vec, + /// Fallback names for the main protocol name. + pub(crate) fallback_names: Vec, - /// Maximum message size. - max_message_size: Option, + /// Maximum message size. + max_message_size: Option, - /// Timeout for outbound requests. - timeout: Option, + /// Timeout for outbound requests. + timeout: Option, - /// Maximum number of concurrent inbound requests. - max_concurrent_inbound_request: Option, + /// Maximum number of concurrent inbound requests. + max_concurrent_inbound_request: Option, } impl ConfigBuilder { - /// Create new [`ConfigBuilder`]. - pub fn new(protocol_name: ProtocolName) -> Self { - Self { - protocol_name, - fallback_names: Vec::new(), - max_message_size: None, - timeout: Some(REQUEST_TIMEOUT), - max_concurrent_inbound_request: None, - } - } - - /// Set maximum message size. - pub fn with_max_size(mut self, max_message_size: usize) -> Self { - self.max_message_size = Some(max_message_size); - self - } - - /// Set fallback names. - pub fn with_fallback_names(mut self, fallback_names: Vec) -> Self { - self.fallback_names = fallback_names; - self - } - - /// Set timeout for outbound requests. - pub fn with_timeout(mut self, timeout: Duration) -> Self { - self.timeout = Some(timeout); - self - } - - /// Specify the maximum number of concurrent inbound requests. By default the number of inbound - /// requests is not limited. - /// - /// If a new request is received while the number of inbound requests is already at a maximum, - /// the request is dropped. - pub fn with_max_concurrent_inbound_requests( - mut self, - max_concurrent_inbound_requests: usize, - ) -> Self { - self.max_concurrent_inbound_request = Some(max_concurrent_inbound_requests); - self - } - - /// Build [`Config`]. - pub fn build(mut self) -> (Config, RequestResponseHandle) { - Config::new( - self.protocol_name, - self.fallback_names, - self.max_message_size.take().expect("maximum message size to be set"), - self.timeout.take().expect("timeout to exist"), - self.max_concurrent_inbound_request, - ) - } + /// Create new [`ConfigBuilder`]. + pub fn new(protocol_name: ProtocolName) -> Self { + Self { + protocol_name, + fallback_names: Vec::new(), + max_message_size: None, + timeout: Some(REQUEST_TIMEOUT), + max_concurrent_inbound_request: None, + } + } + + /// Set maximum message size. + pub fn with_max_size(mut self, max_message_size: usize) -> Self { + self.max_message_size = Some(max_message_size); + self + } + + /// Set fallback names. + pub fn with_fallback_names(mut self, fallback_names: Vec) -> Self { + self.fallback_names = fallback_names; + self + } + + /// Set timeout for outbound requests. + pub fn with_timeout(mut self, timeout: Duration) -> Self { + self.timeout = Some(timeout); + self + } + + /// Specify the maximum number of concurrent inbound requests. By default the number of inbound + /// requests is not limited. + /// + /// If a new request is received while the number of inbound requests is already at a maximum, + /// the request is dropped. + pub fn with_max_concurrent_inbound_requests( + mut self, + max_concurrent_inbound_requests: usize, + ) -> Self { + self.max_concurrent_inbound_request = Some(max_concurrent_inbound_requests); + self + } + + /// Build [`Config`]. + pub fn build(mut self) -> (Config, RequestResponseHandle) { + Config::new( + self.protocol_name, + self.fallback_names, + self.max_message_size.take().expect("maximum message size to be set"), + self.timeout.take().expect("timeout to exist"), + self.max_concurrent_inbound_request, + ) + } } diff --git a/src/protocol/request_response/handle.rs b/src/protocol/request_response/handle.rs index d33eb22c..f1802fae 100644 --- a/src/protocol/request_response/handle.rs +++ b/src/protocol/request_response/handle.rs @@ -19,24 +19,24 @@ // DEALINGS IN THE SOFTWARE. use crate::{ - types::{protocol::ProtocolName, RequestId}, - Error, PeerId, + types::{protocol::ProtocolName, RequestId}, + Error, PeerId, }; use futures::channel; use tokio::sync::{ - mpsc::{Receiver, Sender}, - oneshot, + mpsc::{Receiver, Sender}, + oneshot, }; use std::{ - collections::HashMap, - pin::Pin, - sync::{ - atomic::{AtomicUsize, Ordering}, - Arc, - }, - task::{Context, Poll}, + collections::HashMap, + pin::Pin, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }, + task::{Context, Poll}, }; /// Logging target for the file. @@ -45,441 +45,463 @@ const LOG_TARGET: &str = "litep2p::request-response::handle"; /// Request-response error. #[derive(Debug, Clone, PartialEq, Eq)] pub enum RequestResponseError { - /// Request was rejected. - Rejected, + /// Request was rejected. + Rejected, - /// Request was canceled by the local node. - Canceled, + /// Request was canceled by the local node. + Canceled, - /// Request timed out. - Timeout, + /// Request timed out. + Timeout, - /// Litep2p isn't connected to the peer. - NotConnected, + /// Litep2p isn't connected to the peer. + NotConnected, - /// Too large payload. - TooLargePayload, + /// Too large payload. + TooLargePayload, - /// Protocol not supported. - UnsupportedProtocol, + /// Protocol not supported. + UnsupportedProtocol, } /// Request-response events. pub(super) enum InnerRequestResponseEvent { - /// Request received from remote - RequestReceived { - /// Peer Id. - peer: PeerId, + /// Request received from remote + RequestReceived { + /// Peer Id. + peer: PeerId, - /// Fallback protocol, if the substream was negotiated using a fallback. - fallback: Option, + /// Fallback protocol, if the substream was negotiated using a fallback. + fallback: Option, - /// Request ID. - request_id: RequestId, + /// Request ID. + request_id: RequestId, - /// Received request. - request: Vec, + /// Received request. + request: Vec, - /// `oneshot::Sender` for response. - response_tx: oneshot::Sender<(Vec, Option>)>, - }, + /// `oneshot::Sender` for response. + response_tx: oneshot::Sender<(Vec, Option>)>, + }, - /// Response received. - ResponseReceived { - /// Peer Id. - peer: PeerId, + /// Response received. + ResponseReceived { + /// Peer Id. + peer: PeerId, - /// Fallback protocol, if the substream was negotiated using a fallback. - fallback: Option, + /// Fallback protocol, if the substream was negotiated using a fallback. + fallback: Option, - /// Request ID. - request_id: RequestId, + /// Request ID. + request_id: RequestId, - /// Received request. - response: Vec, - }, + /// Received request. + response: Vec, + }, - /// Request failed. - RequestFailed { - /// Peer Id. - peer: PeerId, + /// Request failed. + RequestFailed { + /// Peer Id. + peer: PeerId, - /// Request ID. - request_id: RequestId, + /// Request ID. + request_id: RequestId, - /// Request-response error. - error: RequestResponseError, - }, + /// Request-response error. + error: RequestResponseError, + }, } impl From for RequestResponseEvent { - fn from(event: InnerRequestResponseEvent) -> Self { - match event { - InnerRequestResponseEvent::ResponseReceived { - peer, - request_id, - response, - fallback, - } => RequestResponseEvent::ResponseReceived { peer, request_id, response, fallback }, - InnerRequestResponseEvent::RequestFailed { peer, request_id, error } => - RequestResponseEvent::RequestFailed { peer, request_id, error }, - _ => panic!("unhandled event"), - } - } + fn from(event: InnerRequestResponseEvent) -> Self { + match event { + InnerRequestResponseEvent::ResponseReceived { + peer, + request_id, + response, + fallback, + } => RequestResponseEvent::ResponseReceived { + peer, + request_id, + response, + fallback, + }, + InnerRequestResponseEvent::RequestFailed { + peer, + request_id, + error, + } => RequestResponseEvent::RequestFailed { + peer, + request_id, + error, + }, + _ => panic!("unhandled event"), + } + } } /// Request-response events. #[derive(Debug, Clone, PartialEq, Eq)] pub enum RequestResponseEvent { - /// Request received from remote - RequestReceived { - /// Peer Id. - peer: PeerId, - - /// Fallback protocol, if the substream was negotiated using a fallback. - fallback: Option, - - /// Request ID. - /// - /// While `request_id` is guaranteed to be unique for this protocols, the request IDs are - /// not unique across different request-response protocols, meaning two different - /// request-response protocols can both assign `RequestId(123)` for any given request. - request_id: RequestId, - - /// Received request. - request: Vec, - }, - - /// Response received. - ResponseReceived { - /// Peer Id. - peer: PeerId, - - /// Request ID. - request_id: RequestId, - - /// Fallback protocol, if the substream was negotiated using a fallback. - fallback: Option, - - /// Received request. - response: Vec, - }, - - /// Request failed. - RequestFailed { - /// Peer Id. - peer: PeerId, - - /// Request ID. - request_id: RequestId, - - /// Request-response error. - error: RequestResponseError, - }, + /// Request received from remote + RequestReceived { + /// Peer Id. + peer: PeerId, + + /// Fallback protocol, if the substream was negotiated using a fallback. + fallback: Option, + + /// Request ID. + /// + /// While `request_id` is guaranteed to be unique for this protocols, the request IDs are + /// not unique across different request-response protocols, meaning two different + /// request-response protocols can both assign `RequestId(123)` for any given request. + request_id: RequestId, + + /// Received request. + request: Vec, + }, + + /// Response received. + ResponseReceived { + /// Peer Id. + peer: PeerId, + + /// Request ID. + request_id: RequestId, + + /// Fallback protocol, if the substream was negotiated using a fallback. + fallback: Option, + + /// Received request. + response: Vec, + }, + + /// Request failed. + RequestFailed { + /// Peer Id. + peer: PeerId, + + /// Request ID. + request_id: RequestId, + + /// Request-response error. + error: RequestResponseError, + }, } /// Dial behavior when sending requests. #[derive(Debug)] pub enum DialOptions { - /// If the peer is not currently connected, attempt to dial them before sending a request. - /// - /// If the dial succeeds, the request is sent to the peer once the peer has been registered - /// to the protocol. - /// - /// If the dial fails, [`RequestResponseError::Rejected`] is returned. - Dial, - - /// If the peer is not connected, immediately reject the request and return - /// [`RequestResponseError::NotConnected`]. - Reject, + /// If the peer is not currently connected, attempt to dial them before sending a request. + /// + /// If the dial succeeds, the request is sent to the peer once the peer has been registered + /// to the protocol. + /// + /// If the dial fails, [`RequestResponseError::Rejected`] is returned. + Dial, + + /// If the peer is not connected, immediately reject the request and return + /// [`RequestResponseError::NotConnected`]. + Reject, } /// Request-response commands. pub(crate) enum RequestResponseCommand { - /// Send request to remote peer. - SendRequest { - /// Peer ID. - peer: PeerId, - - /// Request ID. - /// - /// When a response is received or the request fails, the event contains this ID that - /// the user protocol can associate with the correct request. - /// - /// If the user protocol only has one active request per peer, this ID can be safely - /// discarded. - request_id: RequestId, - - /// Request. - request: Vec, - - /// Dial options, see [`DialOptions`] for more details. - dial_options: DialOptions, - }, - - SendRequestWithFallback { - /// Peer ID. - peer: PeerId, - - /// Request ID. - request_id: RequestId, - - /// Request that is sent over the main protocol, if negotiated. - request: Vec, - - /// Request that is sent over the fallback protocol, if negotiated. - fallback: (ProtocolName, Vec), - - /// Dial options, see [`DialOptions`] for more details. - dial_options: DialOptions, - }, - - /// Cancel outbound request. - CancelRequest { - /// Request ID. - request_id: RequestId, - }, + /// Send request to remote peer. + SendRequest { + /// Peer ID. + peer: PeerId, + + /// Request ID. + /// + /// When a response is received or the request fails, the event contains this ID that + /// the user protocol can associate with the correct request. + /// + /// If the user protocol only has one active request per peer, this ID can be safely + /// discarded. + request_id: RequestId, + + /// Request. + request: Vec, + + /// Dial options, see [`DialOptions`] for more details. + dial_options: DialOptions, + }, + + SendRequestWithFallback { + /// Peer ID. + peer: PeerId, + + /// Request ID. + request_id: RequestId, + + /// Request that is sent over the main protocol, if negotiated. + request: Vec, + + /// Request that is sent over the fallback protocol, if negotiated. + fallback: (ProtocolName, Vec), + + /// Dial options, see [`DialOptions`] for more details. + dial_options: DialOptions, + }, + + /// Cancel outbound request. + CancelRequest { + /// Request ID. + request_id: RequestId, + }, } /// Handle given to the user protocol which allows it to interact with the request-response /// protocol. pub struct RequestResponseHandle { - /// TX channel for sending commands to the request-response protocol. - event_rx: Receiver, + /// TX channel for sending commands to the request-response protocol. + event_rx: Receiver, - /// RX channel for receiving events from the request-response protocol. - command_tx: Sender, + /// RX channel for receiving events from the request-response protocol. + command_tx: Sender, - /// Pending responses. - pending_responses: - HashMap, Option>)>>, + /// Pending responses. + pending_responses: + HashMap, Option>)>>, - /// Next ephemeral request ID. - next_request_id: Arc, + /// Next ephemeral request ID. + next_request_id: Arc, } impl RequestResponseHandle { - /// Create new [`RequestResponseHandle`]. - pub(super) fn new( - event_rx: Receiver, - command_tx: Sender, - next_request_id: Arc, - ) -> Self { - Self { event_rx, command_tx, next_request_id, pending_responses: HashMap::new() } - } - - /// Reject an inbound request. - /// - /// Reject request received from a remote peer. The substream is dropped which signals - /// to the remote peer that request was rejected. - pub fn reject_request(&mut self, request_id: RequestId) { - match self.pending_responses.remove(&request_id) { - None => { - tracing::debug!(target: LOG_TARGET, ?request_id, "rejected request doesn't exist") - }, - Some(sender) => { - tracing::debug!(target: LOG_TARGET, ?request_id, "reject request"); - drop(sender); - }, - } - } - - /// Cancel an outbound request. - /// - /// Allows canceling an in-flight request if the local node is not interested in the answer - /// anymore. If the request was canceled, no event is reported to the user as the cancelation - /// always succeeds and it's assumed that the user does the necessary state clean up in their - /// end after calling [`RequestResponseHandle::cancel_request()`]. - pub async fn cancel_request(&mut self, request_id: RequestId) { - tracing::trace!(target: LOG_TARGET, ?request_id, "cancel request"); - - let _ = self.command_tx.send(RequestResponseCommand::CancelRequest { request_id }).await; - } - - /// Get next request ID. - fn next_request_id(&self) -> RequestId { - let request_id = self.next_request_id.fetch_add(1usize, Ordering::Relaxed); - RequestId::from(request_id) - } - - /// Send request to remote peer. - /// - /// While the returned `RequestId` is guaranteed to be unique for this request-response - /// protocol, it's not unique across all installed request-response protocols. That is, - /// multiple request-response protocols can return the same `RequestId` and this must be - /// handled by the calling code correctly if the `RequestId`s are stored somewhere. - pub async fn send_request( - &mut self, - peer: PeerId, - request: Vec, - dial_options: DialOptions, - ) -> crate::Result { - tracing::trace!(target: LOG_TARGET, ?peer, "send request to peer"); - - let request_id = self.next_request_id(); - self.command_tx - .send(RequestResponseCommand::SendRequest { peer, request_id, request, dial_options }) - .await - .map(|_| request_id) - .map_err(From::from) - } - - /// Attempt to send request to peer and if the channel is clogged, return - /// `Error::ChannelClogged`. - /// - /// While the returned `RequestId` is guaranteed to be unique for this request-response - /// protocol, it's not unique across all installed request-response protocols. That is, - /// multiple request-response protocols can return the same `RequestId` and this must be - /// handled by the calling code correctly if the `RequestId`s are stored somewhere. - pub fn try_send_request( - &mut self, - peer: PeerId, - request: Vec, - dial_options: DialOptions, - ) -> crate::Result { - tracing::trace!(target: LOG_TARGET, ?peer, "send request to peer"); - - let request_id = self.next_request_id(); - self.command_tx - .try_send(RequestResponseCommand::SendRequest { - peer, - request_id, - request, - dial_options, - }) - .map(|_| request_id) - .map_err(|_| Error::ChannelClogged) - } - - /// Send request to remote peer with fallback. - pub async fn send_request_with_fallback( - &mut self, - peer: PeerId, - request: Vec, - fallback: (ProtocolName, Vec), - dial_options: DialOptions, - ) -> crate::Result { - tracing::trace!( - target: LOG_TARGET, - ?peer, - fallback = %fallback.0, - ?dial_options, - "send request with fallback to peer", - ); - - let request_id = self.next_request_id(); - self.command_tx - .send(RequestResponseCommand::SendRequestWithFallback { - peer, - request_id, - fallback, - request, - dial_options, - }) - .await - .map(|_| request_id) - .map_err(From::from) - } - - /// Attempt to send request to peer with fallback and if the channel is clogged, - /// return `Error::ChannelClogged`. - pub fn try_send_request_with_fallback( - &mut self, - peer: PeerId, - request: Vec, - fallback: (ProtocolName, Vec), - dial_options: DialOptions, - ) -> crate::Result { - tracing::trace!( - target: LOG_TARGET, - ?peer, - fallback = %fallback.0, - ?dial_options, - "send request with fallback to peer", - ); - - let request_id = self.next_request_id(); - self.command_tx - .try_send(RequestResponseCommand::SendRequestWithFallback { - peer, - request_id, - fallback, - request, - dial_options, - }) - .map(|_| request_id) - .map_err(|_| Error::ChannelClogged) - } - - /// Send response to remote peer. - pub fn send_response(&mut self, request_id: RequestId, response: Vec) { - match self.pending_responses.remove(&request_id) { - None => { - tracing::debug!(target: LOG_TARGET, ?request_id, "pending response doens't exist"); - }, - Some(response_tx) => { - tracing::trace!(target: LOG_TARGET, ?request_id, "send response to peer"); - - if let Err(_) = response_tx.send((response, None)) { - tracing::debug!(target: LOG_TARGET, ?request_id, "substream closed"); - } - }, - } - } - - /// Send response to remote peer with feedback. - /// - /// The feedback system is inherited from Polkadot SDK's `sc-network` and it's used to notify - /// the sender of the response whether it was sent successfully or not. Once the response has - /// been sent over the substream successfully, `()` will be sent over the feedback channel - /// to the sender to notify them about it. If the substream has been closed or the substream - /// failed while sending the response, the feedback channel will be dropped, notifying the - /// sender that sending the response failed. - pub fn send_response_with_feedback( - &mut self, - request_id: RequestId, - response: Vec, - feedback: channel::oneshot::Sender<()>, - ) { - match self.pending_responses.remove(&request_id) { - None => { - tracing::debug!(target: LOG_TARGET, ?request_id, "pending response doens't exist"); - }, - Some(response_tx) => { - tracing::trace!(target: LOG_TARGET, ?request_id, "send response to peer"); - - if let Err(_) = response_tx.send((response, Some(feedback))) { - tracing::debug!(target: LOG_TARGET, ?request_id, "substream closed"); - } - }, - } - } + /// Create new [`RequestResponseHandle`]. + pub(super) fn new( + event_rx: Receiver, + command_tx: Sender, + next_request_id: Arc, + ) -> Self { + Self { + event_rx, + command_tx, + next_request_id, + pending_responses: HashMap::new(), + } + } + + /// Reject an inbound request. + /// + /// Reject request received from a remote peer. The substream is dropped which signals + /// to the remote peer that request was rejected. + pub fn reject_request(&mut self, request_id: RequestId) { + match self.pending_responses.remove(&request_id) { + None => { + tracing::debug!(target: LOG_TARGET, ?request_id, "rejected request doesn't exist") + } + Some(sender) => { + tracing::debug!(target: LOG_TARGET, ?request_id, "reject request"); + drop(sender); + } + } + } + + /// Cancel an outbound request. + /// + /// Allows canceling an in-flight request if the local node is not interested in the answer + /// anymore. If the request was canceled, no event is reported to the user as the cancelation + /// always succeeds and it's assumed that the user does the necessary state clean up in their + /// end after calling [`RequestResponseHandle::cancel_request()`]. + pub async fn cancel_request(&mut self, request_id: RequestId) { + tracing::trace!(target: LOG_TARGET, ?request_id, "cancel request"); + + let _ = self.command_tx.send(RequestResponseCommand::CancelRequest { request_id }).await; + } + + /// Get next request ID. + fn next_request_id(&self) -> RequestId { + let request_id = self.next_request_id.fetch_add(1usize, Ordering::Relaxed); + RequestId::from(request_id) + } + + /// Send request to remote peer. + /// + /// While the returned `RequestId` is guaranteed to be unique for this request-response + /// protocol, it's not unique across all installed request-response protocols. That is, + /// multiple request-response protocols can return the same `RequestId` and this must be + /// handled by the calling code correctly if the `RequestId`s are stored somewhere. + pub async fn send_request( + &mut self, + peer: PeerId, + request: Vec, + dial_options: DialOptions, + ) -> crate::Result { + tracing::trace!(target: LOG_TARGET, ?peer, "send request to peer"); + + let request_id = self.next_request_id(); + self.command_tx + .send(RequestResponseCommand::SendRequest { + peer, + request_id, + request, + dial_options, + }) + .await + .map(|_| request_id) + .map_err(From::from) + } + + /// Attempt to send request to peer and if the channel is clogged, return + /// `Error::ChannelClogged`. + /// + /// While the returned `RequestId` is guaranteed to be unique for this request-response + /// protocol, it's not unique across all installed request-response protocols. That is, + /// multiple request-response protocols can return the same `RequestId` and this must be + /// handled by the calling code correctly if the `RequestId`s are stored somewhere. + pub fn try_send_request( + &mut self, + peer: PeerId, + request: Vec, + dial_options: DialOptions, + ) -> crate::Result { + tracing::trace!(target: LOG_TARGET, ?peer, "send request to peer"); + + let request_id = self.next_request_id(); + self.command_tx + .try_send(RequestResponseCommand::SendRequest { + peer, + request_id, + request, + dial_options, + }) + .map(|_| request_id) + .map_err(|_| Error::ChannelClogged) + } + + /// Send request to remote peer with fallback. + pub async fn send_request_with_fallback( + &mut self, + peer: PeerId, + request: Vec, + fallback: (ProtocolName, Vec), + dial_options: DialOptions, + ) -> crate::Result { + tracing::trace!( + target: LOG_TARGET, + ?peer, + fallback = %fallback.0, + ?dial_options, + "send request with fallback to peer", + ); + + let request_id = self.next_request_id(); + self.command_tx + .send(RequestResponseCommand::SendRequestWithFallback { + peer, + request_id, + fallback, + request, + dial_options, + }) + .await + .map(|_| request_id) + .map_err(From::from) + } + + /// Attempt to send request to peer with fallback and if the channel is clogged, + /// return `Error::ChannelClogged`. + pub fn try_send_request_with_fallback( + &mut self, + peer: PeerId, + request: Vec, + fallback: (ProtocolName, Vec), + dial_options: DialOptions, + ) -> crate::Result { + tracing::trace!( + target: LOG_TARGET, + ?peer, + fallback = %fallback.0, + ?dial_options, + "send request with fallback to peer", + ); + + let request_id = self.next_request_id(); + self.command_tx + .try_send(RequestResponseCommand::SendRequestWithFallback { + peer, + request_id, + fallback, + request, + dial_options, + }) + .map(|_| request_id) + .map_err(|_| Error::ChannelClogged) + } + + /// Send response to remote peer. + pub fn send_response(&mut self, request_id: RequestId, response: Vec) { + match self.pending_responses.remove(&request_id) { + None => { + tracing::debug!(target: LOG_TARGET, ?request_id, "pending response doens't exist"); + } + Some(response_tx) => { + tracing::trace!(target: LOG_TARGET, ?request_id, "send response to peer"); + + if let Err(_) = response_tx.send((response, None)) { + tracing::debug!(target: LOG_TARGET, ?request_id, "substream closed"); + } + } + } + } + + /// Send response to remote peer with feedback. + /// + /// The feedback system is inherited from Polkadot SDK's `sc-network` and it's used to notify + /// the sender of the response whether it was sent successfully or not. Once the response has + /// been sent over the substream successfully, `()` will be sent over the feedback channel + /// to the sender to notify them about it. If the substream has been closed or the substream + /// failed while sending the response, the feedback channel will be dropped, notifying the + /// sender that sending the response failed. + pub fn send_response_with_feedback( + &mut self, + request_id: RequestId, + response: Vec, + feedback: channel::oneshot::Sender<()>, + ) { + match self.pending_responses.remove(&request_id) { + None => { + tracing::debug!(target: LOG_TARGET, ?request_id, "pending response doens't exist"); + } + Some(response_tx) => { + tracing::trace!(target: LOG_TARGET, ?request_id, "send response to peer"); + + if let Err(_) = response_tx.send((response, Some(feedback))) { + tracing::debug!(target: LOG_TARGET, ?request_id, "substream closed"); + } + } + } + } } impl futures::Stream for RequestResponseHandle { - type Item = RequestResponseEvent; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match futures::ready!(self.event_rx.poll_recv(cx)) { - None => return Poll::Ready(None), - Some(event) => match event { - InnerRequestResponseEvent::RequestReceived { - peer, - fallback, - request_id, - request, - response_tx, - } => { - self.pending_responses.insert(request_id, response_tx); - Poll::Ready(Some(RequestResponseEvent::RequestReceived { - peer, - fallback, - request_id, - request, - })) - }, - event => Poll::Ready(Some(event.into())), - }, - } - } + type Item = RequestResponseEvent; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match futures::ready!(self.event_rx.poll_recv(cx)) { + None => return Poll::Ready(None), + Some(event) => match event { + InnerRequestResponseEvent::RequestReceived { + peer, + fallback, + request_id, + request, + response_tx, + } => { + self.pending_responses.insert(request_id, response_tx); + Poll::Ready(Some(RequestResponseEvent::RequestReceived { + peer, + fallback, + request_id, + request, + })) + } + event => Poll::Ready(Some(event.into())), + }, + } + } } diff --git a/src/protocol/request_response/mod.rs b/src/protocol/request_response/mod.rs index 65e1ad67..16b3c468 100644 --- a/src/protocol/request_response/mod.rs +++ b/src/protocol/request_response/mod.rs @@ -21,35 +21,35 @@ //! Request-response protocol implementation. use crate::{ - error::{Error, NegotiationError}, - multistream_select::NegotiationError::Failed as MultistreamFailed, - protocol::{ - request_response::handle::{InnerRequestResponseEvent, RequestResponseCommand}, - Direction, TransportEvent, TransportService, - }, - substream::{Substream, SubstreamSet}, - types::{protocol::ProtocolName, RequestId, SubstreamId}, - PeerId, + error::{Error, NegotiationError}, + multistream_select::NegotiationError::Failed as MultistreamFailed, + protocol::{ + request_response::handle::{InnerRequestResponseEvent, RequestResponseCommand}, + Direction, TransportEvent, TransportService, + }, + substream::{Substream, SubstreamSet}, + types::{protocol::ProtocolName, RequestId, SubstreamId}, + PeerId, }; use bytes::BytesMut; use futures::{channel, future::BoxFuture, stream::FuturesUnordered, StreamExt}; use tokio::{ - sync::{ - mpsc::{Receiver, Sender}, - oneshot, - }, - time::sleep, + sync::{ + mpsc::{Receiver, Sender}, + oneshot, + }, + time::sleep, }; use std::{ - collections::{hash_map::Entry, HashMap, HashSet}, - io::ErrorKind, - sync::{ - atomic::{AtomicUsize, Ordering}, - Arc, - }, - time::Duration, + collections::{hash_map::Entry, HashMap, HashSet}, + io::ErrorKind, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }, + time::Duration, }; pub use config::{Config, ConfigBuilder}; @@ -71,886 +71,928 @@ const LOG_TARGET: &str = "litep2p::request-response::protocol"; const REQUEST_TIMEOUT: Duration = Duration::from_secs(5); /// Pending request. -type PendingRequest = - (PeerId, RequestId, Option, Result, RequestResponseError>); +type PendingRequest = ( + PeerId, + RequestId, + Option, + Result, RequestResponseError>, +); /// Request context. struct RequestContext { - /// Peer ID. - peer: PeerId, + /// Peer ID. + peer: PeerId, - /// Request ID. - request_id: RequestId, + /// Request ID. + request_id: RequestId, - /// Request. - request: Vec, + /// Request. + request: Vec, - /// Fallback request. - fallback: Option<(ProtocolName, Vec)>, + /// Fallback request. + fallback: Option<(ProtocolName, Vec)>, } impl RequestContext { - /// Create new [`RequestContext`]. - fn new( - peer: PeerId, - request_id: RequestId, - request: Vec, - fallback: Option<(ProtocolName, Vec)>, - ) -> Self { - Self { peer, request_id, request, fallback } - } + /// Create new [`RequestContext`]. + fn new( + peer: PeerId, + request_id: RequestId, + request: Vec, + fallback: Option<(ProtocolName, Vec)>, + ) -> Self { + Self { + peer, + request_id, + request, + fallback, + } + } } /// Peer context. struct PeerContext { - /// Active requests. - active: HashSet, + /// Active requests. + active: HashSet, - /// Active inbound requests and their fallback names. - active_inbound: HashMap>, + /// Active inbound requests and their fallback names. + active_inbound: HashMap>, } impl PeerContext { - /// Create new [`PeerContext`]. - fn new() -> Self { - Self { active: HashSet::new(), active_inbound: HashMap::new() } - } + /// Create new [`PeerContext`]. + fn new() -> Self { + Self { + active: HashSet::new(), + active_inbound: HashMap::new(), + } + } } /// Request-response protocol. pub(crate) struct RequestResponseProtocol { - /// Transport service. - service: TransportService, + /// Transport service. + service: TransportService, - /// Protocol. - protocol: ProtocolName, + /// Protocol. + protocol: ProtocolName, - /// Connected peers. - peers: HashMap, + /// Connected peers. + peers: HashMap, - /// Pending outbound substreams, mapped from `SubstreamId` to `RequestId`. - pending_outbound: HashMap, + /// Pending outbound substreams, mapped from `SubstreamId` to `RequestId`. + pending_outbound: HashMap, - /// Pending outbound responses. - /// - /// The future listens to a `oneshot::Sender` which is given to `RequestResponseHandle`. - /// If the request is accepted by the local node, the response is sent over the channel to the - /// the future which sends it to remote peer and closes the substream. - /// - /// If the substream is rejected by the local node, the `oneshot::Sender` is dropped which - /// notifies the future that the request should be rejected by closing the substream. - pending_outbound_responses: FuturesUnordered>, + /// Pending outbound responses. + /// + /// The future listens to a `oneshot::Sender` which is given to `RequestResponseHandle`. + /// If the request is accepted by the local node, the response is sent over the channel to the + /// the future which sends it to remote peer and closes the substream. + /// + /// If the substream is rejected by the local node, the `oneshot::Sender` is dropped which + /// notifies the future that the request should be rejected by closing the substream. + pending_outbound_responses: FuturesUnordered>, - /// Pending inbound responses. - pending_inbound: FuturesUnordered>, + /// Pending inbound responses. + pending_inbound: FuturesUnordered>, - /// Pending outbound cancellation handles. - pending_outbound_cancels: HashMap>, + /// Pending outbound cancellation handles. + pending_outbound_cancels: HashMap>, - /// Pending inbound requests. - pending_inbound_requests: SubstreamSet<(PeerId, RequestId), Substream>, + /// Pending inbound requests. + pending_inbound_requests: SubstreamSet<(PeerId, RequestId), Substream>, - /// Pending dials for outbound requests. - pending_dials: HashMap, + /// Pending dials for outbound requests. + pending_dials: HashMap, - /// TX channel for sending events to the user protocol. - event_tx: Sender, + /// TX channel for sending events to the user protocol. + event_tx: Sender, - /// RX channel for receive commands from the `RequestResponseHandle`. - command_rx: Receiver, + /// RX channel for receive commands from the `RequestResponseHandle`. + command_rx: Receiver, - /// Next request ID. - /// - /// Inbound requests are assigned an ephemeral ID TODO: finish - next_request_id: Arc, + /// Next request ID. + /// + /// Inbound requests are assigned an ephemeral ID TODO: finish + next_request_id: Arc, - /// Timeout for outbound requests. - timeout: Duration, + /// Timeout for outbound requests. + timeout: Duration, - /// Maximum concurrent inbound requests, if specified. - max_concurrent_inbound_requests: Option, + /// Maximum concurrent inbound requests, if specified. + max_concurrent_inbound_requests: Option, } impl RequestResponseProtocol { - /// Create new [`RequestResponseProtocol`]. - pub(crate) fn new(service: TransportService, config: Config) -> Self { - Self { - service, - peers: HashMap::new(), - timeout: config.timeout, - next_request_id: config.next_request_id, - event_tx: config.event_tx, - command_rx: config.command_rx, - protocol: config.protocol_name, - pending_dials: HashMap::new(), - pending_outbound: HashMap::new(), - pending_inbound: FuturesUnordered::new(), - pending_outbound_cancels: HashMap::new(), - pending_inbound_requests: SubstreamSet::new(), - pending_outbound_responses: FuturesUnordered::new(), - max_concurrent_inbound_requests: config.max_concurrent_inbound_request, - } - } - - /// Get next ephemeral request ID. - fn next_request_id(&mut self) -> RequestId { - RequestId::from(self.next_request_id.fetch_add(1usize, Ordering::Relaxed)) - } - - /// Connection established to remote peer. - async fn on_connection_established(&mut self, peer: PeerId) -> crate::Result<()> { - tracing::debug!(target: LOG_TARGET, ?peer, protocol = %self.protocol, "connection established"); - - let Entry::Vacant(entry) = self.peers.entry(peer) else { - tracing::error!( - target: LOG_TARGET, - ?peer, - "state mismatch: peer already exists", - ); - debug_assert!(false); - return Err(Error::PeerAlreadyExists(peer)); - }; - - match self.pending_dials.remove(&peer) { - None => { - entry.insert(PeerContext::new()); - }, - Some(context) => match self.service.open_substream(peer) { - Ok(substream_id) => { - tracing::trace!( - target: LOG_TARGET, - ?peer, - protocol = %self.protocol, - request_id = ?context.request_id, - ?substream_id, - "dial succeeded, open substream", - ); - - entry.insert(PeerContext { - active: HashSet::from_iter([context.request_id]), - active_inbound: HashMap::new(), - }); - self.pending_outbound.insert( - substream_id, - RequestContext::new( - peer, - context.request_id, - context.request, - context.fallback, - ), - ); - }, - // only reason the substream would fail to open would be that the connection - // would've been reported to the protocol with enough delay that the keep-alive - // timeout had expired and no other protocol had opened a substream to it, causing - // the connection to be closed - Err(error) => { - tracing::warn!( - target: LOG_TARGET, - ?peer, - protocol = %self.protocol, - request_id = ?context.request_id, - ?error, - "failed to open substream", - ); - - return self - .report_request_failure( - peer, - context.request_id, - RequestResponseError::Rejected, - ) - .await; - }, - }, - } - - Ok(()) - } - - /// Connection closed to remote peer. - async fn on_connection_closed(&mut self, peer: PeerId) { - tracing::debug!(target: LOG_TARGET, ?peer, protocol = %self.protocol, "connection closed"); - - let Some(context) = self.peers.remove(&peer) else { - tracing::error!( - target: LOG_TARGET, - ?peer, - "state mismatch: peer doesn't exist", - ); - debug_assert!(false); - return; - }; - - // sent failure events for all pending outbound requests - for request_id in context.active { - let _ = self - .event_tx - .send(InnerRequestResponseEvent::RequestFailed { - peer, - request_id, - error: RequestResponseError::Rejected, - }) - .await; - } - - // remove all pending inbound requests - for (request_id, _) in context.active_inbound { - self.pending_inbound_requests.remove(&(peer, request_id)); - } - } - - /// Local node opened a substream to remote node. - async fn on_outbound_substream( - &mut self, - peer: PeerId, - substream_id: SubstreamId, - mut substream: Substream, - fallback_protocol: Option, - ) -> crate::Result<()> { - let Some(RequestContext { request_id, request, fallback, .. }) = - self.pending_outbound.remove(&substream_id) - else { - tracing::error!( - target: LOG_TARGET, - ?peer, - protocol = %self.protocol, - ?substream_id, - "pending outbound request does not exist", - ); - debug_assert!(false); - - return Err(Error::InvalidState); - }; - - tracing::trace!( - target: LOG_TARGET, - ?peer, - protocol = %self.protocol, - ?substream_id, - ?request_id, - "substream opened, send request", - ); - - let request = match (&fallback_protocol, fallback) { - (Some(protocol), Some((fallback_protocol, fallback_request))) - if protocol == &fallback_protocol => - fallback_request, - _ => request, - }; - - let request_timeout = self.timeout; - let protocol = self.protocol.clone(); - let (tx, rx) = oneshot::channel(); - self.pending_outbound_cancels.insert(request_id, tx); - - self.pending_inbound.push(Box::pin(async move { - match tokio::time::timeout(request_timeout, substream.send_framed(request.into())).await - { - Err(_) => (peer, request_id, fallback_protocol, Err(RequestResponseError::Timeout)), - Ok(Err(Error::IoError(ErrorKind::PermissionDenied))) => { - tracing::warn!( - target: LOG_TARGET, - ?peer, - %protocol, - "tried to send too large request", - ); - - ( - peer, - request_id, - fallback_protocol, - Err(RequestResponseError::TooLargePayload), - ) - }, - Ok(Err(_error)) => - (peer, request_id, fallback_protocol, Err(RequestResponseError::NotConnected)), - Ok(Ok(_)) => { - tokio::select! { - _ = rx => { - tracing::debug!( - target: LOG_TARGET, - ?peer, - %protocol, - ?request_id, - "request canceled", - ); - - let _ = substream.close().await; - ( - peer, - request_id, - fallback_protocol, - Err(RequestResponseError::Canceled)) - } - _ = sleep(request_timeout) => { - tracing::debug!( - target: LOG_TARGET, - ?peer, - %protocol, - ?request_id, - "request timed out", - ); - - let _ = substream.close().await; - (peer, request_id, fallback_protocol, Err(RequestResponseError::Timeout)) - } - event = substream.next() => match event { - Some(Ok(response)) => { - (peer, request_id, fallback_protocol, Ok(response.freeze().into())) - } - _ => (peer, request_id, fallback_protocol, Err(RequestResponseError::Rejected)), - } - } - }, - } - })); - - Ok(()) - } - - /// Handle pending inbound response. - async fn on_inbound_request( - &mut self, - peer: PeerId, - request_id: RequestId, - request: crate::Result, - ) -> crate::Result<()> { - let fallback = self - .peers - .get_mut(&peer) - .ok_or(Error::PeerDoesntExist(peer))? - .active_inbound - .remove(&request_id) - .ok_or_else(|| { - tracing::debug!( - target: LOG_TARGET, - ?peer, - protocol = %self.protocol, - ?request_id, - "no active inbound request", - ); - - Error::InvalidState - })?; - let mut substream = - self.pending_inbound_requests.remove(&(peer, request_id)).ok_or_else(|| { - tracing::debug!( - target: LOG_TARGET, - ?peer, - protocol = %self.protocol, - ?request_id, - "request doesn't exist in pending requests", - ); - - Error::InvalidState - })?; - let protocol = self.protocol.clone(); - - tracing::trace!( - target: LOG_TARGET, - ?peer, - %protocol, - ?request_id, - "inbound request", - ); - - let Ok(request) = request else { - tracing::debug!( - target: LOG_TARGET, - ?peer, - %protocol, - ?request_id, - ?request, - "failed to read request from substream", - ); - return Err(Error::InvalidData); - }; - - // once the request has been read from the substream, start a future which waits - // for an input from the user. - // - // the input is either a response (succes) or rejection (failure) which is communicated - // by sending the response over the `oneshot::Sender` or closing it, respectively. - let timeout = self.timeout; - let (response_tx, rx): ( - oneshot::Sender<(Vec, Option>)>, - _, - ) = oneshot::channel(); - - self.pending_outbound_responses.push(Box::pin(async move { - match rx.await { - Err(_) => { - tracing::debug!( - target: LOG_TARGET, - ?peer, - %protocol, - ?request_id, - "request rejected", - ); - let _ = substream.close().await; - }, - Ok((response, mut feedback)) => { - tracing::trace!( - target: LOG_TARGET, - ?peer, - %protocol, - ?request_id, - "send response", - ); - - match tokio::time::timeout(timeout, substream.send_framed(response.into())) - .await - { - Err(_) => tracing::debug!( - target: LOG_TARGET, - ?peer, - %protocol, - ?request_id, - "timed out while sending response", - ), - Ok(Ok(_)) => feedback.take().map_or((), |feedback| { - let _ = feedback.send(()); - }), - Ok(Err(error)) => tracing::trace!( - target: LOG_TARGET, - ?peer, - %protocol, - ?request_id, - ?error, - "failed to send request to peer", - ), - } - }, - } - })); - - self.event_tx - .send(InnerRequestResponseEvent::RequestReceived { - peer, - fallback, - request_id, - request: request.freeze().into(), - response_tx, - }) - .await - .map_err(From::from) - } - - /// Remote opened a substream to local node. - async fn on_inbound_substream( - &mut self, - peer: PeerId, - fallback: Option, - substream: Substream, - ) -> crate::Result<()> { - tracing::trace!(target: LOG_TARGET, ?peer, protocol = %self.protocol, "handle inbound substream"); - - if let Some(max_requests) = self.max_concurrent_inbound_requests { - let num_inbound_requests = - self.pending_inbound_requests.len() + self.pending_outbound_responses.len(); - - if max_requests <= num_inbound_requests { - tracing::debug!( - target: LOG_TARGET, - ?peer, - protocol = %self.protocol, - ?fallback, - ?max_requests, - "rejecting request as already at maximum", - ); - - let _ = substream.close().await; - return Ok(()); - } - } - - // allocate ephemeral id for the inbound request and return it to the user protocol - // - // when user responds to the request, this is used to associate the response with the - // correct substream. - let request_id = self.next_request_id(); - self.peers - .get_mut(&peer) - .ok_or(Error::PeerDoesntExist(peer))? - .active_inbound - .insert(request_id, fallback); - self.pending_inbound_requests.insert((peer, request_id), substream); - - Ok(()) - } - - async fn on_dial_failure(&mut self, peer: PeerId) { - if let Some(context) = self.pending_dials.remove(&peer) { - tracing::debug!(target: LOG_TARGET, ?peer, protocol = %self.protocol, "failed to dial peer"); - - let _ = self - .peers - .get_mut(&peer) - .map(|peer_context| peer_context.active.remove(&context.request_id)); - let _ = self - .report_request_failure(peer, context.request_id, RequestResponseError::Rejected) - .await; - } - } - - /// Failed to open substream to remote peer. - async fn on_substream_open_failure( - &mut self, - substream: SubstreamId, - error: Error, - ) -> crate::Result<()> { - let Some(RequestContext { request_id, peer, .. }) = - self.pending_outbound.remove(&substream) - else { - tracing::error!( - target: LOG_TARGET, - protocol = %self.protocol, - ?substream, - "pending outbound request does not exist", - ); - debug_assert!(false); - - return Err(Error::InvalidState); - }; - - tracing::debug!( - target: LOG_TARGET, - ?peer, - protocol = %self.protocol, - ?request_id, - ?substream, - ?error, - "failed to open substream", - ); - - let _ = self - .peers - .get_mut(&peer) - .map(|peer_context| peer_context.active.remove(&request_id)); - - self.event_tx - .send(InnerRequestResponseEvent::RequestFailed { - peer, - request_id, - error: match error { - Error::NegotiationError(NegotiationError::MultistreamSelectError( - MultistreamFailed, - )) => RequestResponseError::UnsupportedProtocol, - _ => RequestResponseError::Rejected, - }, - }) - .await - .map_err(From::from) - } - - /// Report request send failure to user. - async fn report_request_failure( - &mut self, - peer: PeerId, - request_id: RequestId, - error: RequestResponseError, - ) -> crate::Result<()> { - self.event_tx - .send(InnerRequestResponseEvent::RequestFailed { peer, request_id, error }) - .await - .map_err(From::from) - } - - /// Send request to remote peer. - async fn on_send_request( - &mut self, - peer: PeerId, - request_id: RequestId, - request: Vec, - dial_options: DialOptions, - fallback: Option<(ProtocolName, Vec)>, - ) -> crate::Result<()> { - tracing::trace!( - target: LOG_TARGET, - ?peer, - protocol = %self.protocol, - ?request_id, - ?dial_options, - "send request to remote peer", - ); - - let Some(context) = self.peers.get_mut(&peer) else { - match dial_options { - DialOptions::Reject => { - tracing::debug!( - target: LOG_TARGET, - ?peer, - protocol = %self.protocol, - ?request_id, - ?dial_options, - "peer not connected and should not dial", - ); - - return self - .report_request_failure( - peer, - request_id, - RequestResponseError::NotConnected, - ) - .await; - }, - DialOptions::Dial => match self.service.dial(&peer) { - Ok(_) => { - tracing::trace!( - target: LOG_TARGET, - ?peer, - protocol = %self.protocol, - ?request_id, - "started dialing peer", - ); - - self.pending_dials - .insert(peer, RequestContext::new(peer, request_id, request, fallback)); - return Ok(()); - }, - Err(error) => { - tracing::debug!( - target: LOG_TARGET, - ?peer, - protocol = %self.protocol, - ?error, - "failed to dial peer" - ); - - return self - .report_request_failure( - peer, - request_id, - RequestResponseError::Rejected, - ) - .await; - }, - }, - } - }; - - // open substream and push it pending outbound substreams - // once the substream is opened, send the request. - match self.service.open_substream(peer) { - Ok(substream_id) => { - let unique_request_id = context.active.insert(request_id); - debug_assert!(unique_request_id); - - self.pending_outbound - .insert(substream_id, RequestContext::new(peer, request_id, request, fallback)); - - Ok(()) - }, - Err(error) => { - tracing::debug!( - target: LOG_TARGET, - ?peer, - protocol = %self.protocol, - ?request_id, - ?error, - "failed to open substream", - ); - - self.report_request_failure(peer, request_id, RequestResponseError::Rejected) - .await - }, - } - } - - /// Handle substream event. - async fn on_substream_event( - &mut self, - peer: PeerId, - request_id: RequestId, - fallback: Option, - message: Result, RequestResponseError>, - ) -> crate::Result<()> { - if !self - .peers - .get_mut(&peer) - .ok_or(Error::PeerDoesntExist(peer))? - .active - .remove(&request_id) - { - tracing::warn!( - target: LOG_TARGET, - ?peer, - protocol = %self.protocol, - ?request_id, - "invalid state: received substream event but no active substream", - ); - return Err(Error::InvalidState); - } - - let event = match message { - Ok(response) => - InnerRequestResponseEvent::ResponseReceived { peer, request_id, response, fallback }, - Err(error) => match error { - RequestResponseError::Canceled => { - tracing::debug!( - target: LOG_TARGET, - ?peer, - protocol = %self.protocol, - ?request_id, - "request canceled by local node", - ); - return Ok(()); - }, - error => InnerRequestResponseEvent::RequestFailed { peer, request_id, error }, - }, - }; - - self.event_tx.send(event).await.map_err(From::from) - } - - /// Cancel outbound request. - async fn on_cancel_request(&mut self, request_id: RequestId) -> crate::Result<()> { - tracing::trace!(target: LOG_TARGET, protocol = %self.protocol, ?request_id, "cancel outbound request"); - - match self.pending_outbound_cancels.remove(&request_id) { - Some(tx) => tx.send(()).map_err(|_| Error::SubstreamDoesntExist), - None => { - tracing::debug!( - target: LOG_TARGET, - protocol = %self.protocol, - ?request_id, - "tried to cancel request which doesn't exist", - ); - - Ok(()) - }, - } - } - - /// Start [`RequestResponseProtocol`] event loop. - pub async fn run(mut self) { - tracing::debug!(target: LOG_TARGET, "starting request-response event loop"); - - loop { - tokio::select! { - // events coming from the network have higher priority than user commands as all user commands are - // responses to network behaviour so ensure that the commands operate on the most up to date information. - biased; - - event = self.service.next() => match event { - Some(TransportEvent::ConnectionEstablished { peer, .. }) => { - let _ = self.on_connection_established(peer).await; - } - Some(TransportEvent::ConnectionClosed { peer }) => { - self.on_connection_closed(peer).await; - } - Some(TransportEvent::SubstreamOpened { - peer, - substream, - direction, - fallback, - .. - }) => match direction { - Direction::Inbound => { - if let Err(error) = self.on_inbound_substream(peer, fallback, substream).await { - tracing::debug!( - target: LOG_TARGET, - ?peer, - protocol = %self.protocol, - ?error, - "failed to handle inbound substream", - ); - } - } - Direction::Outbound(substream_id) => { - let _ = self.on_outbound_substream(peer, substream_id, substream, fallback).await; - } - }, - Some(TransportEvent::SubstreamOpenFailure { substream, error }) => { - if let Err(error) = self.on_substream_open_failure(substream, error).await { - tracing::warn!( - target: LOG_TARGET, - protocol = %self.protocol, - ?error, - "failed to handle substream open failure", - ); - } - } - Some(TransportEvent::DialFailure { peer, .. }) => self.on_dial_failure(peer).await, - None => return, - }, - event = self.pending_inbound.select_next_some(), if !self.pending_inbound.is_empty() => { - let (peer, request_id, fallback, event) = event; - - if let Err(error) = self.on_substream_event(peer, request_id, fallback, event).await { - tracing::debug!( - target: LOG_TARGET, - ?peer, - protocol = %self.protocol, - ?request_id, - ?error, - "failed to handle substream event", - ); - } - - self.pending_outbound_cancels.remove(&request_id); - } - _ = self.pending_outbound_responses.next(), if !self.pending_outbound_responses.is_empty() => {} - event = self.pending_inbound_requests.next() => match event { - Some(((peer, request_id), message)) => { - if let Err(error) = self.on_inbound_request(peer, request_id, message).await { - tracing::debug!( - target: LOG_TARGET, - ?peer, - protocol = %self.protocol, - ?request_id, - ?error, - "failed to handle inbound request", - ); - } - } - None => return, - }, - command = self.command_rx.recv() => match command { - None => { - tracing::debug!(target: LOG_TARGET, protocol = %self.protocol, "user protocol has exited, exiting"); - return - } - Some(command) => match command { - RequestResponseCommand::SendRequest { peer, request_id, request, dial_options } => { - if let Err(error) = self.on_send_request(peer, request_id, request, dial_options, None).await { - tracing::debug!( - target: LOG_TARGET, - ?peer, - protocol = %self.protocol, - ?request_id, - ?error, - "failed to send request", - ); - } - } - RequestResponseCommand::CancelRequest { request_id } => { - if let Err(error) = self.on_cancel_request(request_id).await { - tracing::debug!( - target: LOG_TARGET, - protocol = %self.protocol, - ?request_id, - ?error, - "failed to cancel reqeuest", - ); - } - } - RequestResponseCommand::SendRequestWithFallback { peer, request_id, request, fallback, dial_options } => { - if let Err(error) = self.on_send_request(peer, request_id, request, dial_options, Some(fallback)).await { - tracing::debug!( - target: LOG_TARGET, - ?peer, - protocol = %self.protocol, - ?request_id, - ?error, - "failed to send request", - ); - } - } - } - }, - } - } - } + /// Create new [`RequestResponseProtocol`]. + pub(crate) fn new(service: TransportService, config: Config) -> Self { + Self { + service, + peers: HashMap::new(), + timeout: config.timeout, + next_request_id: config.next_request_id, + event_tx: config.event_tx, + command_rx: config.command_rx, + protocol: config.protocol_name, + pending_dials: HashMap::new(), + pending_outbound: HashMap::new(), + pending_inbound: FuturesUnordered::new(), + pending_outbound_cancels: HashMap::new(), + pending_inbound_requests: SubstreamSet::new(), + pending_outbound_responses: FuturesUnordered::new(), + max_concurrent_inbound_requests: config.max_concurrent_inbound_request, + } + } + + /// Get next ephemeral request ID. + fn next_request_id(&mut self) -> RequestId { + RequestId::from(self.next_request_id.fetch_add(1usize, Ordering::Relaxed)) + } + + /// Connection established to remote peer. + async fn on_connection_established(&mut self, peer: PeerId) -> crate::Result<()> { + tracing::debug!(target: LOG_TARGET, ?peer, protocol = %self.protocol, "connection established"); + + let Entry::Vacant(entry) = self.peers.entry(peer) else { + tracing::error!( + target: LOG_TARGET, + ?peer, + "state mismatch: peer already exists", + ); + debug_assert!(false); + return Err(Error::PeerAlreadyExists(peer)); + }; + + match self.pending_dials.remove(&peer) { + None => { + entry.insert(PeerContext::new()); + } + Some(context) => match self.service.open_substream(peer) { + Ok(substream_id) => { + tracing::trace!( + target: LOG_TARGET, + ?peer, + protocol = %self.protocol, + request_id = ?context.request_id, + ?substream_id, + "dial succeeded, open substream", + ); + + entry.insert(PeerContext { + active: HashSet::from_iter([context.request_id]), + active_inbound: HashMap::new(), + }); + self.pending_outbound.insert( + substream_id, + RequestContext::new( + peer, + context.request_id, + context.request, + context.fallback, + ), + ); + } + // only reason the substream would fail to open would be that the connection + // would've been reported to the protocol with enough delay that the keep-alive + // timeout had expired and no other protocol had opened a substream to it, causing + // the connection to be closed + Err(error) => { + tracing::warn!( + target: LOG_TARGET, + ?peer, + protocol = %self.protocol, + request_id = ?context.request_id, + ?error, + "failed to open substream", + ); + + return self + .report_request_failure( + peer, + context.request_id, + RequestResponseError::Rejected, + ) + .await; + } + }, + } + + Ok(()) + } + + /// Connection closed to remote peer. + async fn on_connection_closed(&mut self, peer: PeerId) { + tracing::debug!(target: LOG_TARGET, ?peer, protocol = %self.protocol, "connection closed"); + + let Some(context) = self.peers.remove(&peer) else { + tracing::error!( + target: LOG_TARGET, + ?peer, + "state mismatch: peer doesn't exist", + ); + debug_assert!(false); + return; + }; + + // sent failure events for all pending outbound requests + for request_id in context.active { + let _ = self + .event_tx + .send(InnerRequestResponseEvent::RequestFailed { + peer, + request_id, + error: RequestResponseError::Rejected, + }) + .await; + } + + // remove all pending inbound requests + for (request_id, _) in context.active_inbound { + self.pending_inbound_requests.remove(&(peer, request_id)); + } + } + + /// Local node opened a substream to remote node. + async fn on_outbound_substream( + &mut self, + peer: PeerId, + substream_id: SubstreamId, + mut substream: Substream, + fallback_protocol: Option, + ) -> crate::Result<()> { + let Some(RequestContext { + request_id, + request, + fallback, + .. + }) = self.pending_outbound.remove(&substream_id) + else { + tracing::error!( + target: LOG_TARGET, + ?peer, + protocol = %self.protocol, + ?substream_id, + "pending outbound request does not exist", + ); + debug_assert!(false); + + return Err(Error::InvalidState); + }; + + tracing::trace!( + target: LOG_TARGET, + ?peer, + protocol = %self.protocol, + ?substream_id, + ?request_id, + "substream opened, send request", + ); + + let request = match (&fallback_protocol, fallback) { + (Some(protocol), Some((fallback_protocol, fallback_request))) + if protocol == &fallback_protocol => + fallback_request, + _ => request, + }; + + let request_timeout = self.timeout; + let protocol = self.protocol.clone(); + let (tx, rx) = oneshot::channel(); + self.pending_outbound_cancels.insert(request_id, tx); + + self.pending_inbound.push(Box::pin(async move { + match tokio::time::timeout(request_timeout, substream.send_framed(request.into())).await + { + Err(_) => ( + peer, + request_id, + fallback_protocol, + Err(RequestResponseError::Timeout), + ), + Ok(Err(Error::IoError(ErrorKind::PermissionDenied))) => { + tracing::warn!( + target: LOG_TARGET, + ?peer, + %protocol, + "tried to send too large request", + ); + + ( + peer, + request_id, + fallback_protocol, + Err(RequestResponseError::TooLargePayload), + ) + } + Ok(Err(_error)) => ( + peer, + request_id, + fallback_protocol, + Err(RequestResponseError::NotConnected), + ), + Ok(Ok(_)) => { + tokio::select! { + _ = rx => { + tracing::debug!( + target: LOG_TARGET, + ?peer, + %protocol, + ?request_id, + "request canceled", + ); + + let _ = substream.close().await; + ( + peer, + request_id, + fallback_protocol, + Err(RequestResponseError::Canceled)) + } + _ = sleep(request_timeout) => { + tracing::debug!( + target: LOG_TARGET, + ?peer, + %protocol, + ?request_id, + "request timed out", + ); + + let _ = substream.close().await; + (peer, request_id, fallback_protocol, Err(RequestResponseError::Timeout)) + } + event = substream.next() => match event { + Some(Ok(response)) => { + (peer, request_id, fallback_protocol, Ok(response.freeze().into())) + } + _ => (peer, request_id, fallback_protocol, Err(RequestResponseError::Rejected)), + } + } + } + } + })); + + Ok(()) + } + + /// Handle pending inbound response. + async fn on_inbound_request( + &mut self, + peer: PeerId, + request_id: RequestId, + request: crate::Result, + ) -> crate::Result<()> { + let fallback = self + .peers + .get_mut(&peer) + .ok_or(Error::PeerDoesntExist(peer))? + .active_inbound + .remove(&request_id) + .ok_or_else(|| { + tracing::debug!( + target: LOG_TARGET, + ?peer, + protocol = %self.protocol, + ?request_id, + "no active inbound request", + ); + + Error::InvalidState + })?; + let mut substream = + self.pending_inbound_requests.remove(&(peer, request_id)).ok_or_else(|| { + tracing::debug!( + target: LOG_TARGET, + ?peer, + protocol = %self.protocol, + ?request_id, + "request doesn't exist in pending requests", + ); + + Error::InvalidState + })?; + let protocol = self.protocol.clone(); + + tracing::trace!( + target: LOG_TARGET, + ?peer, + %protocol, + ?request_id, + "inbound request", + ); + + let Ok(request) = request else { + tracing::debug!( + target: LOG_TARGET, + ?peer, + %protocol, + ?request_id, + ?request, + "failed to read request from substream", + ); + return Err(Error::InvalidData); + }; + + // once the request has been read from the substream, start a future which waits + // for an input from the user. + // + // the input is either a response (succes) or rejection (failure) which is communicated + // by sending the response over the `oneshot::Sender` or closing it, respectively. + let timeout = self.timeout; + let (response_tx, rx): ( + oneshot::Sender<(Vec, Option>)>, + _, + ) = oneshot::channel(); + + self.pending_outbound_responses.push(Box::pin(async move { + match rx.await { + Err(_) => { + tracing::debug!( + target: LOG_TARGET, + ?peer, + %protocol, + ?request_id, + "request rejected", + ); + let _ = substream.close().await; + } + Ok((response, mut feedback)) => { + tracing::trace!( + target: LOG_TARGET, + ?peer, + %protocol, + ?request_id, + "send response", + ); + + match tokio::time::timeout(timeout, substream.send_framed(response.into())) + .await + { + Err(_) => tracing::debug!( + target: LOG_TARGET, + ?peer, + %protocol, + ?request_id, + "timed out while sending response", + ), + Ok(Ok(_)) => feedback.take().map_or((), |feedback| { + let _ = feedback.send(()); + }), + Ok(Err(error)) => tracing::trace!( + target: LOG_TARGET, + ?peer, + %protocol, + ?request_id, + ?error, + "failed to send request to peer", + ), + } + } + } + })); + + self.event_tx + .send(InnerRequestResponseEvent::RequestReceived { + peer, + fallback, + request_id, + request: request.freeze().into(), + response_tx, + }) + .await + .map_err(From::from) + } + + /// Remote opened a substream to local node. + async fn on_inbound_substream( + &mut self, + peer: PeerId, + fallback: Option, + substream: Substream, + ) -> crate::Result<()> { + tracing::trace!(target: LOG_TARGET, ?peer, protocol = %self.protocol, "handle inbound substream"); + + if let Some(max_requests) = self.max_concurrent_inbound_requests { + let num_inbound_requests = + self.pending_inbound_requests.len() + self.pending_outbound_responses.len(); + + if max_requests <= num_inbound_requests { + tracing::debug!( + target: LOG_TARGET, + ?peer, + protocol = %self.protocol, + ?fallback, + ?max_requests, + "rejecting request as already at maximum", + ); + + let _ = substream.close().await; + return Ok(()); + } + } + + // allocate ephemeral id for the inbound request and return it to the user protocol + // + // when user responds to the request, this is used to associate the response with the + // correct substream. + let request_id = self.next_request_id(); + self.peers + .get_mut(&peer) + .ok_or(Error::PeerDoesntExist(peer))? + .active_inbound + .insert(request_id, fallback); + self.pending_inbound_requests.insert((peer, request_id), substream); + + Ok(()) + } + + async fn on_dial_failure(&mut self, peer: PeerId) { + if let Some(context) = self.pending_dials.remove(&peer) { + tracing::debug!(target: LOG_TARGET, ?peer, protocol = %self.protocol, "failed to dial peer"); + + let _ = self + .peers + .get_mut(&peer) + .map(|peer_context| peer_context.active.remove(&context.request_id)); + let _ = self + .report_request_failure(peer, context.request_id, RequestResponseError::Rejected) + .await; + } + } + + /// Failed to open substream to remote peer. + async fn on_substream_open_failure( + &mut self, + substream: SubstreamId, + error: Error, + ) -> crate::Result<()> { + let Some(RequestContext { + request_id, peer, .. + }) = self.pending_outbound.remove(&substream) + else { + tracing::error!( + target: LOG_TARGET, + protocol = %self.protocol, + ?substream, + "pending outbound request does not exist", + ); + debug_assert!(false); + + return Err(Error::InvalidState); + }; + + tracing::debug!( + target: LOG_TARGET, + ?peer, + protocol = %self.protocol, + ?request_id, + ?substream, + ?error, + "failed to open substream", + ); + + let _ = self + .peers + .get_mut(&peer) + .map(|peer_context| peer_context.active.remove(&request_id)); + + self.event_tx + .send(InnerRequestResponseEvent::RequestFailed { + peer, + request_id, + error: match error { + Error::NegotiationError(NegotiationError::MultistreamSelectError( + MultistreamFailed, + )) => RequestResponseError::UnsupportedProtocol, + _ => RequestResponseError::Rejected, + }, + }) + .await + .map_err(From::from) + } + + /// Report request send failure to user. + async fn report_request_failure( + &mut self, + peer: PeerId, + request_id: RequestId, + error: RequestResponseError, + ) -> crate::Result<()> { + self.event_tx + .send(InnerRequestResponseEvent::RequestFailed { + peer, + request_id, + error, + }) + .await + .map_err(From::from) + } + + /// Send request to remote peer. + async fn on_send_request( + &mut self, + peer: PeerId, + request_id: RequestId, + request: Vec, + dial_options: DialOptions, + fallback: Option<(ProtocolName, Vec)>, + ) -> crate::Result<()> { + tracing::trace!( + target: LOG_TARGET, + ?peer, + protocol = %self.protocol, + ?request_id, + ?dial_options, + "send request to remote peer", + ); + + let Some(context) = self.peers.get_mut(&peer) else { + match dial_options { + DialOptions::Reject => { + tracing::debug!( + target: LOG_TARGET, + ?peer, + protocol = %self.protocol, + ?request_id, + ?dial_options, + "peer not connected and should not dial", + ); + + return self + .report_request_failure( + peer, + request_id, + RequestResponseError::NotConnected, + ) + .await; + } + DialOptions::Dial => match self.service.dial(&peer) { + Ok(_) => { + tracing::trace!( + target: LOG_TARGET, + ?peer, + protocol = %self.protocol, + ?request_id, + "started dialing peer", + ); + + self.pending_dials.insert( + peer, + RequestContext::new(peer, request_id, request, fallback), + ); + return Ok(()); + } + Err(error) => { + tracing::debug!( + target: LOG_TARGET, + ?peer, + protocol = %self.protocol, + ?error, + "failed to dial peer" + ); + + return self + .report_request_failure( + peer, + request_id, + RequestResponseError::Rejected, + ) + .await; + } + }, + } + }; + + // open substream and push it pending outbound substreams + // once the substream is opened, send the request. + match self.service.open_substream(peer) { + Ok(substream_id) => { + let unique_request_id = context.active.insert(request_id); + debug_assert!(unique_request_id); + + self.pending_outbound.insert( + substream_id, + RequestContext::new(peer, request_id, request, fallback), + ); + + Ok(()) + } + Err(error) => { + tracing::debug!( + target: LOG_TARGET, + ?peer, + protocol = %self.protocol, + ?request_id, + ?error, + "failed to open substream", + ); + + self.report_request_failure(peer, request_id, RequestResponseError::Rejected) + .await + } + } + } + + /// Handle substream event. + async fn on_substream_event( + &mut self, + peer: PeerId, + request_id: RequestId, + fallback: Option, + message: Result, RequestResponseError>, + ) -> crate::Result<()> { + if !self + .peers + .get_mut(&peer) + .ok_or(Error::PeerDoesntExist(peer))? + .active + .remove(&request_id) + { + tracing::warn!( + target: LOG_TARGET, + ?peer, + protocol = %self.protocol, + ?request_id, + "invalid state: received substream event but no active substream", + ); + return Err(Error::InvalidState); + } + + let event = match message { + Ok(response) => InnerRequestResponseEvent::ResponseReceived { + peer, + request_id, + response, + fallback, + }, + Err(error) => match error { + RequestResponseError::Canceled => { + tracing::debug!( + target: LOG_TARGET, + ?peer, + protocol = %self.protocol, + ?request_id, + "request canceled by local node", + ); + return Ok(()); + } + error => InnerRequestResponseEvent::RequestFailed { + peer, + request_id, + error, + }, + }, + }; + + self.event_tx.send(event).await.map_err(From::from) + } + + /// Cancel outbound request. + async fn on_cancel_request(&mut self, request_id: RequestId) -> crate::Result<()> { + tracing::trace!(target: LOG_TARGET, protocol = %self.protocol, ?request_id, "cancel outbound request"); + + match self.pending_outbound_cancels.remove(&request_id) { + Some(tx) => tx.send(()).map_err(|_| Error::SubstreamDoesntExist), + None => { + tracing::debug!( + target: LOG_TARGET, + protocol = %self.protocol, + ?request_id, + "tried to cancel request which doesn't exist", + ); + + Ok(()) + } + } + } + + /// Start [`RequestResponseProtocol`] event loop. + pub async fn run(mut self) { + tracing::debug!(target: LOG_TARGET, "starting request-response event loop"); + + loop { + tokio::select! { + // events coming from the network have higher priority than user commands as all user commands are + // responses to network behaviour so ensure that the commands operate on the most up to date information. + biased; + + event = self.service.next() => match event { + Some(TransportEvent::ConnectionEstablished { peer, .. }) => { + let _ = self.on_connection_established(peer).await; + } + Some(TransportEvent::ConnectionClosed { peer }) => { + self.on_connection_closed(peer).await; + } + Some(TransportEvent::SubstreamOpened { + peer, + substream, + direction, + fallback, + .. + }) => match direction { + Direction::Inbound => { + if let Err(error) = self.on_inbound_substream(peer, fallback, substream).await { + tracing::debug!( + target: LOG_TARGET, + ?peer, + protocol = %self.protocol, + ?error, + "failed to handle inbound substream", + ); + } + } + Direction::Outbound(substream_id) => { + let _ = self.on_outbound_substream(peer, substream_id, substream, fallback).await; + } + }, + Some(TransportEvent::SubstreamOpenFailure { substream, error }) => { + if let Err(error) = self.on_substream_open_failure(substream, error).await { + tracing::warn!( + target: LOG_TARGET, + protocol = %self.protocol, + ?error, + "failed to handle substream open failure", + ); + } + } + Some(TransportEvent::DialFailure { peer, .. }) => self.on_dial_failure(peer).await, + None => return, + }, + event = self.pending_inbound.select_next_some(), if !self.pending_inbound.is_empty() => { + let (peer, request_id, fallback, event) = event; + + if let Err(error) = self.on_substream_event(peer, request_id, fallback, event).await { + tracing::debug!( + target: LOG_TARGET, + ?peer, + protocol = %self.protocol, + ?request_id, + ?error, + "failed to handle substream event", + ); + } + + self.pending_outbound_cancels.remove(&request_id); + } + _ = self.pending_outbound_responses.next(), if !self.pending_outbound_responses.is_empty() => {} + event = self.pending_inbound_requests.next() => match event { + Some(((peer, request_id), message)) => { + if let Err(error) = self.on_inbound_request(peer, request_id, message).await { + tracing::debug!( + target: LOG_TARGET, + ?peer, + protocol = %self.protocol, + ?request_id, + ?error, + "failed to handle inbound request", + ); + } + } + None => return, + }, + command = self.command_rx.recv() => match command { + None => { + tracing::debug!(target: LOG_TARGET, protocol = %self.protocol, "user protocol has exited, exiting"); + return + } + Some(command) => match command { + RequestResponseCommand::SendRequest { peer, request_id, request, dial_options } => { + if let Err(error) = self.on_send_request(peer, request_id, request, dial_options, None).await { + tracing::debug!( + target: LOG_TARGET, + ?peer, + protocol = %self.protocol, + ?request_id, + ?error, + "failed to send request", + ); + } + } + RequestResponseCommand::CancelRequest { request_id } => { + if let Err(error) = self.on_cancel_request(request_id).await { + tracing::debug!( + target: LOG_TARGET, + protocol = %self.protocol, + ?request_id, + ?error, + "failed to cancel reqeuest", + ); + } + } + RequestResponseCommand::SendRequestWithFallback { peer, request_id, request, fallback, dial_options } => { + if let Err(error) = self.on_send_request(peer, request_id, request, dial_options, Some(fallback)).await { + tracing::debug!( + target: LOG_TARGET, + ?peer, + protocol = %self.protocol, + ?request_id, + ?error, + "failed to send request", + ); + } + } + } + }, + } + } + } } diff --git a/src/protocol/request_response/tests.rs b/src/protocol/request_response/tests.rs index 6d4e5487..524b3b2d 100644 --- a/src/protocol/request_response/tests.rs +++ b/src/protocol/request_response/tests.rs @@ -19,19 +19,19 @@ // DEALINGS IN THE SOFTWARE. use crate::{ - crypto::ed25519::Keypair, - mock::substream::{DummySubstream, MockSubstream}, - protocol::{ - request_response::{ - ConfigBuilder, DialOptions, RequestResponseError, RequestResponseEvent, - RequestResponseHandle, RequestResponseProtocol, - }, - InnerTransportEvent, TransportService, - }, - substream::Substream, - transport::manager::TransportManager, - types::{RequestId, SubstreamId}, - BandwidthSink, Error, PeerId, ProtocolName, + crypto::ed25519::Keypair, + mock::substream::{DummySubstream, MockSubstream}, + protocol::{ + request_response::{ + ConfigBuilder, DialOptions, RequestResponseError, RequestResponseEvent, + RequestResponseHandle, RequestResponseProtocol, + }, + InnerTransportEvent, TransportService, + }, + substream::Substream, + transport::manager::TransportManager, + types::{RequestId, SubstreamId}, + BandwidthSink, Error, PeerId, ProtocolName, }; use futures::StreamExt; @@ -40,196 +40,216 @@ use tokio::sync::mpsc::Sender; use std::{collections::HashSet, task::Poll}; // create new protocol for testing -fn protocol( -) -> (RequestResponseProtocol, RequestResponseHandle, TransportManager, Sender) -{ - let (manager, handle) = - TransportManager::new(Keypair::generate(), HashSet::new(), BandwidthSink::new(), 8usize); - - let peer = PeerId::random(); - let (transport_service, tx) = TransportService::new( - peer, - ProtocolName::from("/notif/1"), - Vec::new(), - std::sync::Arc::new(Default::default()), - handle, - ); - let (config, handle) = - ConfigBuilder::new(ProtocolName::from("/req/1")).with_max_size(1024).build(); - - (RequestResponseProtocol::new(transport_service, config), handle, manager, tx) +fn protocol() -> ( + RequestResponseProtocol, + RequestResponseHandle, + TransportManager, + Sender, +) { + let (manager, handle) = TransportManager::new( + Keypair::generate(), + HashSet::new(), + BandwidthSink::new(), + 8usize, + ); + + let peer = PeerId::random(); + let (transport_service, tx) = TransportService::new( + peer, + ProtocolName::from("/notif/1"), + Vec::new(), + std::sync::Arc::new(Default::default()), + handle, + ); + let (config, handle) = + ConfigBuilder::new(ProtocolName::from("/req/1")).with_max_size(1024).build(); + + ( + RequestResponseProtocol::new(transport_service, config), + handle, + manager, + tx, + ) } #[tokio::test] #[cfg(debug_assertions)] #[should_panic] async fn connection_closed_twice() { - let (mut protocol, _handle, _manager, _tx) = protocol(); + let (mut protocol, _handle, _manager, _tx) = protocol(); - let peer = PeerId::random(); - protocol.on_connection_established(peer).await.unwrap(); - assert!(protocol.peers.contains_key(&peer)); + let peer = PeerId::random(); + protocol.on_connection_established(peer).await.unwrap(); + assert!(protocol.peers.contains_key(&peer)); - protocol.on_connection_established(peer).await.unwrap(); + protocol.on_connection_established(peer).await.unwrap(); } #[tokio::test] #[cfg(debug_assertions)] #[should_panic] async fn connection_established_twice() { - let (mut protocol, _handle, _manager, _tx) = protocol(); + let (mut protocol, _handle, _manager, _tx) = protocol(); - let peer = PeerId::random(); - protocol.on_connection_established(peer).await.unwrap(); - assert!(protocol.peers.contains_key(&peer)); + let peer = PeerId::random(); + protocol.on_connection_established(peer).await.unwrap(); + assert!(protocol.peers.contains_key(&peer)); - protocol.on_connection_closed(peer).await; - assert!(!protocol.peers.contains_key(&peer)); + protocol.on_connection_closed(peer).await; + assert!(!protocol.peers.contains_key(&peer)); - protocol.on_connection_closed(peer).await; + protocol.on_connection_closed(peer).await; } #[tokio::test] #[cfg(debug_assertions)] #[should_panic] async fn unknown_outbound_substream_opened() { - let (mut protocol, _handle, _manager, _tx) = protocol(); - let peer = PeerId::random(); - - match protocol - .on_outbound_substream( - peer, - SubstreamId::from(1337usize), - Substream::new_mock(peer, SubstreamId::from(0usize), Box::new(MockSubstream::new())), - None, - ) - .await - { - Err(Error::InvalidState) => {}, - _ => panic!("invalid return value"), - } + let (mut protocol, _handle, _manager, _tx) = protocol(); + let peer = PeerId::random(); + + match protocol + .on_outbound_substream( + peer, + SubstreamId::from(1337usize), + Substream::new_mock( + peer, + SubstreamId::from(0usize), + Box::new(MockSubstream::new()), + ), + None, + ) + .await + { + Err(Error::InvalidState) => {} + _ => panic!("invalid return value"), + } } #[tokio::test] #[cfg(debug_assertions)] #[should_panic] async fn unknown_substream_open_failure() { - let (mut protocol, _handle, _manager, _tx) = protocol(); - - match protocol - .on_substream_open_failure(SubstreamId::from(1338usize), Error::Unknown) - .await - { - Err(Error::InvalidState) => {}, - _ => panic!("invalid return value"), - } + let (mut protocol, _handle, _manager, _tx) = protocol(); + + match protocol + .on_substream_open_failure(SubstreamId::from(1338usize), Error::Unknown) + .await + { + Err(Error::InvalidState) => {} + _ => panic!("invalid return value"), + } } #[tokio::test] async fn cancel_unknown_request() { - let (mut protocol, _handle, _manager, _tx) = protocol(); + let (mut protocol, _handle, _manager, _tx) = protocol(); - let request_id = RequestId::from(1337usize); - assert!(!protocol.pending_outbound_cancels.contains_key(&request_id)); - assert!(protocol.on_cancel_request(request_id).await.is_ok()); + let request_id = RequestId::from(1337usize); + assert!(!protocol.pending_outbound_cancels.contains_key(&request_id)); + assert!(protocol.on_cancel_request(request_id).await.is_ok()); } #[tokio::test] async fn substream_event_for_unknown_peer() { - let (mut protocol, _handle, _manager, _tx) = protocol(); - - // register peer - let peer = PeerId::random(); - protocol.on_connection_established(peer).await.unwrap(); - assert!(protocol.peers.contains_key(&peer)); - - match protocol - .on_substream_event(peer, RequestId::from(1337usize), None, Ok(vec![13, 37])) - .await - { - Err(Error::InvalidState) => {}, - _ => panic!("invalid return value"), - } + let (mut protocol, _handle, _manager, _tx) = protocol(); + + // register peer + let peer = PeerId::random(); + protocol.on_connection_established(peer).await.unwrap(); + assert!(protocol.peers.contains_key(&peer)); + + match protocol + .on_substream_event(peer, RequestId::from(1337usize), None, Ok(vec![13, 37])) + .await + { + Err(Error::InvalidState) => {} + _ => panic!("invalid return value"), + } } #[tokio::test] async fn inbound_substream_error() { - let (mut protocol, _handle, _manager, _tx) = protocol(); - - // register peer - let peer = PeerId::random(); - protocol.on_connection_established(peer).await.unwrap(); - assert!(protocol.peers.contains_key(&peer)); - - let mut substream = MockSubstream::new(); - substream - .expect_poll_next() - .times(1) - .return_once(|_| Poll::Ready(Some(Err(Error::Unknown)))); - - // register inbound substream from peer - protocol - .on_inbound_substream( - peer, - None, - Substream::new_mock(peer, SubstreamId::from(0usize), Box::new(substream)), - ) - .await - .unwrap(); - - // verify the request has been registered for the peer - let request_id = *protocol.peers.get(&peer).unwrap().active_inbound.keys().next().unwrap(); - assert!(protocol.pending_inbound_requests.get_mut(&(peer, request_id)).is_some()); - - // poll the substream and get the failure event - let ((peer, request_id), event) = protocol.pending_inbound_requests.next().await.unwrap(); - - match protocol.on_inbound_request(peer, request_id, event).await { - Err(Error::InvalidData) => {}, - _ => panic!("invalid return value"), - } + let (mut protocol, _handle, _manager, _tx) = protocol(); + + // register peer + let peer = PeerId::random(); + protocol.on_connection_established(peer).await.unwrap(); + assert!(protocol.peers.contains_key(&peer)); + + let mut substream = MockSubstream::new(); + substream + .expect_poll_next() + .times(1) + .return_once(|_| Poll::Ready(Some(Err(Error::Unknown)))); + + // register inbound substream from peer + protocol + .on_inbound_substream( + peer, + None, + Substream::new_mock(peer, SubstreamId::from(0usize), Box::new(substream)), + ) + .await + .unwrap(); + + // verify the request has been registered for the peer + let request_id = *protocol.peers.get(&peer).unwrap().active_inbound.keys().next().unwrap(); + assert!(protocol.pending_inbound_requests.get_mut(&(peer, request_id)).is_some()); + + // poll the substream and get the failure event + let ((peer, request_id), event) = protocol.pending_inbound_requests.next().await.unwrap(); + + match protocol.on_inbound_request(peer, request_id, event).await { + Err(Error::InvalidData) => {} + _ => panic!("invalid return value"), + } } // when a peer who had an active inbound substream disconnects, verify that the substream is removed // from `pending_inbound_requests` so it doesn't generate new wake-up notifications #[tokio::test] async fn disconnect_peer_has_active_inbound_substream() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (mut protocol, mut handle, _manager, _tx) = protocol(); - - // register new peer - let peer = PeerId::random(); - protocol.on_connection_established(peer).await.unwrap(); - - // register inbound substream from peer - protocol - .on_inbound_substream( - peer, - None, - Substream::new_mock(peer, SubstreamId::from(0usize), Box::new(DummySubstream::new())), - ) - .await - .unwrap(); - - // verify the request has been registered for the peer - let request_id = *protocol.peers.get(&peer).unwrap().active_inbound.keys().next().unwrap(); - assert!(protocol.pending_inbound_requests.get_mut(&(peer, request_id)).is_some()); - - // disconnect the peer and verify that no events are read from the handle - // since no outbound request was initiated - protocol.on_connection_closed(peer).await; - - futures::future::poll_fn(|cx| match handle.poll_next_unpin(cx) { - Poll::Pending => Poll::Ready(()), - event => panic!("read an unexpected event from handle: {event:?}"), - }) - .await; - - // verify the substream has been removed from `pending_inbound_requests` - assert!(protocol.pending_inbound_requests.get_mut(&(peer, request_id)).is_none()); + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (mut protocol, mut handle, _manager, _tx) = protocol(); + + // register new peer + let peer = PeerId::random(); + protocol.on_connection_established(peer).await.unwrap(); + + // register inbound substream from peer + protocol + .on_inbound_substream( + peer, + None, + Substream::new_mock( + peer, + SubstreamId::from(0usize), + Box::new(DummySubstream::new()), + ), + ) + .await + .unwrap(); + + // verify the request has been registered for the peer + let request_id = *protocol.peers.get(&peer).unwrap().active_inbound.keys().next().unwrap(); + assert!(protocol.pending_inbound_requests.get_mut(&(peer, request_id)).is_some()); + + // disconnect the peer and verify that no events are read from the handle + // since no outbound request was initiated + protocol.on_connection_closed(peer).await; + + futures::future::poll_fn(|cx| match handle.poll_next_unpin(cx) { + Poll::Pending => Poll::Ready(()), + event => panic!("read an unexpected event from handle: {event:?}"), + }) + .await; + + // verify the substream has been removed from `pending_inbound_requests` + assert!(protocol.pending_inbound_requests.get_mut(&(peer, request_id)).is_none()); } // when user initiates an outbound request and `RequestResponseProtocol` tries to open an outbound @@ -237,46 +257,50 @@ async fn disconnect_peer_has_active_inbound_substream() { // later disconnects, this failure should not be reported again. #[tokio::test] async fn request_failure_reported_once() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (mut protocol, mut handle, _manager, _tx) = protocol(); - - // register new peer - let peer = PeerId::random(); - protocol.on_connection_established(peer).await.unwrap(); - - // initiate outbound request - // - // since the peer wasn't properly registered, opening substream to them will fail - protocol - .on_send_request( - peer, - RequestId::from(1337usize), - vec![1, 2, 3, 4], - DialOptions::Reject, - None, - ) - .await - .unwrap(); - - match handle.next().await { - Some(RequestResponseEvent::RequestFailed { peer: request_peer, request_id, error }) => { - assert_eq!(request_peer, peer); - assert_eq!(request_id, RequestId::from(1337usize)); - assert_eq!(error, RequestResponseError::Rejected); - }, - event => panic!("unexpected event: {event:?}"), - } - - // disconnect the peer and verify that no events are read from the handle - // since the outbound request failure was already reported - protocol.on_connection_closed(peer).await; - - futures::future::poll_fn(|cx| match handle.poll_next_unpin(cx) { - Poll::Pending => Poll::Ready(()), - event => panic!("read an unexpected event from handle: {event:?}"), - }) - .await; + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (mut protocol, mut handle, _manager, _tx) = protocol(); + + // register new peer + let peer = PeerId::random(); + protocol.on_connection_established(peer).await.unwrap(); + + // initiate outbound request + // + // since the peer wasn't properly registered, opening substream to them will fail + protocol + .on_send_request( + peer, + RequestId::from(1337usize), + vec![1, 2, 3, 4], + DialOptions::Reject, + None, + ) + .await + .unwrap(); + + match handle.next().await { + Some(RequestResponseEvent::RequestFailed { + peer: request_peer, + request_id, + error, + }) => { + assert_eq!(request_peer, peer); + assert_eq!(request_id, RequestId::from(1337usize)); + assert_eq!(error, RequestResponseError::Rejected); + } + event => panic!("unexpected event: {event:?}"), + } + + // disconnect the peer and verify that no events are read from the handle + // since the outbound request failure was already reported + protocol.on_connection_closed(peer).await; + + futures::future::poll_fn(|cx| match handle.poll_next_unpin(cx) { + Poll::Pending => Poll::Ready(()), + event => panic!("read an unexpected event from handle: {event:?}"), + }) + .await; } diff --git a/src/protocol/transport_service.rs b/src/protocol/transport_service.rs index 5d75a092..ae8aabb2 100644 --- a/src/protocol/transport_service.rs +++ b/src/protocol/transport_service.rs @@ -19,11 +19,11 @@ // DEALINGS IN THE SOFTWARE. use crate::{ - error::Error, - protocol::{connection::ConnectionHandle, InnerTransportEvent, TransportEvent}, - transport::{manager::TransportManagerHandle, Endpoint}, - types::{protocol::ProtocolName, ConnectionId, SubstreamId}, - PeerId, DEFAULT_CHANNEL_SIZE, + error::Error, + protocol::{connection::ConnectionHandle, InnerTransportEvent, TransportEvent}, + transport::{manager::TransportManagerHandle, Endpoint}, + types::{protocol::ProtocolName, ConnectionId, SubstreamId}, + PeerId, DEFAULT_CHANNEL_SIZE, }; use futures::{future::BoxFuture, stream::FuturesUnordered, Stream, StreamExt}; @@ -32,15 +32,15 @@ use multihash::Multihash; use tokio::sync::mpsc::{channel, Receiver, Sender}; use std::{ - collections::{HashMap, HashSet}, - fmt::Debug, - pin::Pin, - sync::{ - atomic::{AtomicUsize, Ordering}, - Arc, - }, - task::{Context, Poll}, - time::Duration, + collections::{HashMap, HashSet}, + fmt::Debug, + pin::Pin, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }, + task::{Context, Poll}, + time::Duration, }; /// Logging target for the file. @@ -56,782 +56,808 @@ const LOG_TARGET: &str = "litep2p::transport-service"; /// while the secondary connections remains open. #[derive(Debug)] struct ConnectionContext { - /// Primary connection. - primary: ConnectionHandle, + /// Primary connection. + primary: ConnectionHandle, - /// Secondary connection, if it exists. - secondary: Option, + /// Secondary connection, if it exists. + secondary: Option, } impl ConnectionContext { - /// Create new [`ConnectionContext`]. - fn new(primary: ConnectionHandle) -> Self { - Self { primary, secondary: None } - } - - /// Downgrade connection to non-active which means it will be closed - /// if there are no substreams open over it. - fn downgrade(&mut self, connection_id: &ConnectionId) { - if self.primary.connection_id() == connection_id { - self.primary.close(); - return; - } - - if let Some(handle) = &mut self.secondary { - if handle.connection_id() == connection_id { - handle.close(); - return; - } - } - - tracing::debug!( - target: LOG_TARGET, - primary = ?self.primary.connection_id(), - secondary = ?self.secondary.as_ref().map(|handle| handle.connection_id()), - ?connection_id, - "connection doesn't exist, cannot downgrade", - ); - } + /// Create new [`ConnectionContext`]. + fn new(primary: ConnectionHandle) -> Self { + Self { + primary, + secondary: None, + } + } + + /// Downgrade connection to non-active which means it will be closed + /// if there are no substreams open over it. + fn downgrade(&mut self, connection_id: &ConnectionId) { + if self.primary.connection_id() == connection_id { + self.primary.close(); + return; + } + + if let Some(handle) = &mut self.secondary { + if handle.connection_id() == connection_id { + handle.close(); + return; + } + } + + tracing::debug!( + target: LOG_TARGET, + primary = ?self.primary.connection_id(), + secondary = ?self.secondary.as_ref().map(|handle| handle.connection_id()), + ?connection_id, + "connection doesn't exist, cannot downgrade", + ); + } } /// Provides an interfaces for [`Litep2p`](crate::Litep2p) protocols to interact /// with the underlying transport protocols. #[derive(Debug)] pub struct TransportService { - /// Local peer ID. - pub(crate) local_peer_id: PeerId, + /// Local peer ID. + pub(crate) local_peer_id: PeerId, - /// Protocol. - protocol: ProtocolName, + /// Protocol. + protocol: ProtocolName, - /// Fallback names for the protocol. - fallback_names: Vec, + /// Fallback names for the protocol. + fallback_names: Vec, - /// Open connections. - connections: HashMap, + /// Open connections. + connections: HashMap, - /// Transport handle. - transport_handle: TransportManagerHandle, + /// Transport handle. + transport_handle: TransportManagerHandle, - /// RX channel for receiving events from tranports and connections. - rx: Receiver, + /// RX channel for receiving events from tranports and connections. + rx: Receiver, - /// Next substream ID. - next_substream_id: Arc, + /// Next substream ID. + next_substream_id: Arc, - /// Pending keep-alive timeouts. - keep_alive_timeouts: FuturesUnordered>, + /// Pending keep-alive timeouts. + keep_alive_timeouts: FuturesUnordered>, } impl TransportService { - /// Create new [`TransportService`]. - pub(crate) fn new( - local_peer_id: PeerId, - protocol: ProtocolName, - fallback_names: Vec, - next_substream_id: Arc, - transport_handle: TransportManagerHandle, - ) -> (Self, Sender) { - let (tx, rx) = channel(DEFAULT_CHANNEL_SIZE); - - ( - Self { - rx, - protocol, - local_peer_id, - fallback_names, - transport_handle, - next_substream_id, - connections: HashMap::new(), - keep_alive_timeouts: FuturesUnordered::new(), - }, - tx, - ) - } - - /// Handle connection established event. - fn on_connection_established( - &mut self, - peer: PeerId, - endpoint: Endpoint, - connection_id: ConnectionId, - handle: ConnectionHandle, - ) -> Option { - tracing::debug!( - target: LOG_TARGET, - ?peer, - protocol = %self.protocol, - ?endpoint, - ?connection_id, - "connection established", - ); - - match self.connections.get_mut(&peer) { - Some(context) => match context.secondary { - Some(_) => { - tracing::debug!( - target: LOG_TARGET, - ?peer, - ?connection_id, - ?endpoint, - "ignoring third connection", - ); - None - }, - None => { - self.keep_alive_timeouts.push(Box::pin(async move { - tokio::time::sleep(Duration::from_secs(5)).await; - (peer, connection_id) - })); - context.secondary = Some(handle); - - None - }, - }, - None => { - self.connections.insert(peer, ConnectionContext::new(handle)); - self.keep_alive_timeouts.push(Box::pin(async move { - tokio::time::sleep(Duration::from_secs(5)).await; - (peer, connection_id) - })); - - Some(TransportEvent::ConnectionEstablished { peer, endpoint }) - }, - } - } - - /// Handle connection closed event. - fn on_connection_closed( - &mut self, - peer: PeerId, - connection_id: ConnectionId, - ) -> Option { - let Some(context) = self.connections.get_mut(&peer) else { - tracing::warn!( - target: LOG_TARGET, - ?peer, - ?connection_id, - "connection closed to a non-existent peer", - ); - - debug_assert!(false); - return None; - }; - - // if the primary connection was closed, check if there exist a secondary connection - // and if it does, convert the secondary connection a primary connection - if context.primary.connection_id() == &connection_id { - tracing::trace!(target: LOG_TARGET, ?peer, ?connection_id, "primary connection closed"); - - match context.secondary.take() { - None => { - self.connections.remove(&peer); - return Some(TransportEvent::ConnectionClosed { peer }); - }, - Some(handle) => { - tracing::debug!( - target: LOG_TARGET, - ?peer, - ?connection_id, - "switch to secondary connection", - ); - - context.primary = handle; - return None; - }, - } - } - - match context.secondary.take() { - Some(handle) if handle.connection_id() == &connection_id => { - tracing::trace!( - target: LOG_TARGET, - ?peer, - ?connection_id, - "secondary connection closed", - ); - - return None; - }, - connection_state => { - tracing::debug!( - target: LOG_TARGET, - ?peer, - ?connection_id, - ?connection_state, - "connection closed but it doesn't exist", - ); - - return None; - }, - } - } - - /// Dial `peer` using `PeerId`. - /// - /// Call fails if `Litep2p` doesn't have a known address for the peer. - pub fn dial(&mut self, peer: &PeerId) -> crate::Result<()> { - self.transport_handle.dial(peer) - } - - /// Dial peer using a `Multiaddr`. - /// - /// Call fails if the address is not in correct format or it contains an unsupported/disabled - /// transport. - /// - /// Calling this function is only necessary for those addresses that are discovered out-of-band - /// since `Litep2p` internally keeps track of all peer addresses it has learned through user - /// calling this function, Kademlia peer discoveries and `Identify` responses. - pub fn dial_address(&mut self, address: Multiaddr) -> crate::Result<()> { - self.transport_handle.dial_address(address) - } - - /// Add one or more addresses for `peer`. - /// - /// The list is filtered for duplicates and unsupported transports. - pub fn add_known_address(&mut self, peer: &PeerId, addresses: impl Iterator) { - let addresses: HashSet = addresses - .filter_map(|address| { - if !std::matches!(address.iter().last(), Some(Protocol::P2p(_))) { - Some(address.with(Protocol::P2p(Multihash::from_bytes(&peer.to_bytes()).ok()?))) - } else { - Some(address) - } - }) - .collect(); - - self.transport_handle.add_known_address(peer, addresses.into_iter()); - } - - /// Open substream to `peer`. - /// - /// Call fails if there is no connection open to `peer` or the channel towards - /// the connection is clogged. - pub fn open_substream(&mut self, peer: PeerId) -> crate::Result { - // always prefer the primary connection - let connection = - &mut self.connections.get_mut(&peer).ok_or(Error::PeerDoesntExist(peer))?.primary; - - let permit = connection.try_get_permit().ok_or(Error::ConnectionClosed)?; - let substream_id = - SubstreamId::from(self.next_substream_id.fetch_add(1usize, Ordering::Relaxed)); - - tracing::trace!( - target: LOG_TARGET, - ?peer, - protocol = %self.protocol, - ?substream_id, - "open substream", - ); - - connection - .open_substream( - self.protocol.clone(), - self.fallback_names.clone(), - substream_id, - permit, - ) - .map(|_| substream_id) - } - - /// Forcibly close the connection, even if other protocols have substreams open over it. - pub fn force_close(&mut self, peer: PeerId) -> crate::Result<()> { - let connection = - &mut self.connections.get_mut(&peer).ok_or(Error::PeerDoesntExist(peer))?; - - tracing::debug!( - target: LOG_TARGET, - ?peer, - protocol = %self.protocol, - secondary = ?connection.secondary, - "forcibly closing the connection", - ); - - if let Some(ref mut connection) = connection.secondary { - let _ = connection.force_close(); - } - - connection.primary.force_close() - } + /// Create new [`TransportService`]. + pub(crate) fn new( + local_peer_id: PeerId, + protocol: ProtocolName, + fallback_names: Vec, + next_substream_id: Arc, + transport_handle: TransportManagerHandle, + ) -> (Self, Sender) { + let (tx, rx) = channel(DEFAULT_CHANNEL_SIZE); + + ( + Self { + rx, + protocol, + local_peer_id, + fallback_names, + transport_handle, + next_substream_id, + connections: HashMap::new(), + keep_alive_timeouts: FuturesUnordered::new(), + }, + tx, + ) + } + + /// Handle connection established event. + fn on_connection_established( + &mut self, + peer: PeerId, + endpoint: Endpoint, + connection_id: ConnectionId, + handle: ConnectionHandle, + ) -> Option { + tracing::debug!( + target: LOG_TARGET, + ?peer, + protocol = %self.protocol, + ?endpoint, + ?connection_id, + "connection established", + ); + + match self.connections.get_mut(&peer) { + Some(context) => match context.secondary { + Some(_) => { + tracing::debug!( + target: LOG_TARGET, + ?peer, + ?connection_id, + ?endpoint, + "ignoring third connection", + ); + None + } + None => { + self.keep_alive_timeouts.push(Box::pin(async move { + tokio::time::sleep(Duration::from_secs(5)).await; + (peer, connection_id) + })); + context.secondary = Some(handle); + + None + } + }, + None => { + self.connections.insert(peer, ConnectionContext::new(handle)); + self.keep_alive_timeouts.push(Box::pin(async move { + tokio::time::sleep(Duration::from_secs(5)).await; + (peer, connection_id) + })); + + Some(TransportEvent::ConnectionEstablished { peer, endpoint }) + } + } + } + + /// Handle connection closed event. + fn on_connection_closed( + &mut self, + peer: PeerId, + connection_id: ConnectionId, + ) -> Option { + let Some(context) = self.connections.get_mut(&peer) else { + tracing::warn!( + target: LOG_TARGET, + ?peer, + ?connection_id, + "connection closed to a non-existent peer", + ); + + debug_assert!(false); + return None; + }; + + // if the primary connection was closed, check if there exist a secondary connection + // and if it does, convert the secondary connection a primary connection + if context.primary.connection_id() == &connection_id { + tracing::trace!(target: LOG_TARGET, ?peer, ?connection_id, "primary connection closed"); + + match context.secondary.take() { + None => { + self.connections.remove(&peer); + return Some(TransportEvent::ConnectionClosed { peer }); + } + Some(handle) => { + tracing::debug!( + target: LOG_TARGET, + ?peer, + ?connection_id, + "switch to secondary connection", + ); + + context.primary = handle; + return None; + } + } + } + + match context.secondary.take() { + Some(handle) if handle.connection_id() == &connection_id => { + tracing::trace!( + target: LOG_TARGET, + ?peer, + ?connection_id, + "secondary connection closed", + ); + + return None; + } + connection_state => { + tracing::debug!( + target: LOG_TARGET, + ?peer, + ?connection_id, + ?connection_state, + "connection closed but it doesn't exist", + ); + + return None; + } + } + } + + /// Dial `peer` using `PeerId`. + /// + /// Call fails if `Litep2p` doesn't have a known address for the peer. + pub fn dial(&mut self, peer: &PeerId) -> crate::Result<()> { + self.transport_handle.dial(peer) + } + + /// Dial peer using a `Multiaddr`. + /// + /// Call fails if the address is not in correct format or it contains an unsupported/disabled + /// transport. + /// + /// Calling this function is only necessary for those addresses that are discovered out-of-band + /// since `Litep2p` internally keeps track of all peer addresses it has learned through user + /// calling this function, Kademlia peer discoveries and `Identify` responses. + pub fn dial_address(&mut self, address: Multiaddr) -> crate::Result<()> { + self.transport_handle.dial_address(address) + } + + /// Add one or more addresses for `peer`. + /// + /// The list is filtered for duplicates and unsupported transports. + pub fn add_known_address(&mut self, peer: &PeerId, addresses: impl Iterator) { + let addresses: HashSet = addresses + .filter_map(|address| { + if !std::matches!(address.iter().last(), Some(Protocol::P2p(_))) { + Some(address.with(Protocol::P2p(Multihash::from_bytes(&peer.to_bytes()).ok()?))) + } else { + Some(address) + } + }) + .collect(); + + self.transport_handle.add_known_address(peer, addresses.into_iter()); + } + + /// Open substream to `peer`. + /// + /// Call fails if there is no connection open to `peer` or the channel towards + /// the connection is clogged. + pub fn open_substream(&mut self, peer: PeerId) -> crate::Result { + // always prefer the primary connection + let connection = + &mut self.connections.get_mut(&peer).ok_or(Error::PeerDoesntExist(peer))?.primary; + + let permit = connection.try_get_permit().ok_or(Error::ConnectionClosed)?; + let substream_id = + SubstreamId::from(self.next_substream_id.fetch_add(1usize, Ordering::Relaxed)); + + tracing::trace!( + target: LOG_TARGET, + ?peer, + protocol = %self.protocol, + ?substream_id, + "open substream", + ); + + connection + .open_substream( + self.protocol.clone(), + self.fallback_names.clone(), + substream_id, + permit, + ) + .map(|_| substream_id) + } + + /// Forcibly close the connection, even if other protocols have substreams open over it. + pub fn force_close(&mut self, peer: PeerId) -> crate::Result<()> { + let connection = + &mut self.connections.get_mut(&peer).ok_or(Error::PeerDoesntExist(peer))?; + + tracing::debug!( + target: LOG_TARGET, + ?peer, + protocol = %self.protocol, + secondary = ?connection.secondary, + "forcibly closing the connection", + ); + + if let Some(ref mut connection) = connection.secondary { + let _ = connection.force_close(); + } + + connection.primary.force_close() + } } impl Stream for TransportService { - type Item = TransportEvent; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - while let Poll::Ready(event) = self.rx.poll_recv(cx) { - match event { - None => return Poll::Ready(None), - Some(InnerTransportEvent::ConnectionEstablished { - peer, - endpoint, - sender, - connection, - }) => { - if let Some(event) = - self.on_connection_established(peer, endpoint, connection, sender) - { - return Poll::Ready(Some(event)); - } - }, - Some(InnerTransportEvent::ConnectionClosed { peer, connection }) => { - if let Some(event) = self.on_connection_closed(peer, connection) { - return Poll::Ready(Some(event)); - } - }, - Some(event) => return Poll::Ready(Some(event.into())), - } - } - - while let Poll::Ready(Some((peer, connection_id))) = - self.keep_alive_timeouts.poll_next_unpin(cx) - { - if let Some(context) = self.connections.get_mut(&peer) { - tracing::trace!( - target: LOG_TARGET, - ?peer, - ?connection_id, - "keep-alive timeout over, downgrade connection", - ); - - context.downgrade(&connection_id); - } - } - - Poll::Pending - } + type Item = TransportEvent; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + while let Poll::Ready(event) = self.rx.poll_recv(cx) { + match event { + None => return Poll::Ready(None), + Some(InnerTransportEvent::ConnectionEstablished { + peer, + endpoint, + sender, + connection, + }) => { + if let Some(event) = + self.on_connection_established(peer, endpoint, connection, sender) + { + return Poll::Ready(Some(event)); + } + } + Some(InnerTransportEvent::ConnectionClosed { peer, connection }) => { + if let Some(event) = self.on_connection_closed(peer, connection) { + return Poll::Ready(Some(event)); + } + } + Some(event) => return Poll::Ready(Some(event.into())), + } + } + + while let Poll::Ready(Some((peer, connection_id))) = + self.keep_alive_timeouts.poll_next_unpin(cx) + { + if let Some(context) = self.connections.get_mut(&peer) { + tracing::trace!( + target: LOG_TARGET, + ?peer, + ?connection_id, + "keep-alive timeout over, downgrade connection", + ); + + context.downgrade(&connection_id); + } + } + + Poll::Pending + } } #[cfg(test)] mod tests { - use super::*; - use crate::{ - protocol::TransportService, - transport::manager::{handle::InnerTransportManagerCommand, TransportManagerHandle}, - }; - use futures::StreamExt; - use parking_lot::RwLock; - use std::collections::HashSet; - - /// Create new `TransportService` - fn transport_service( - ) -> (TransportService, Sender, Receiver) { - let (cmd_tx, cmd_rx) = channel(64); - let peer = PeerId::random(); - - let handle = TransportManagerHandle::new( - peer, - Arc::new(RwLock::new(HashMap::new())), - cmd_tx, - HashSet::new(), - Default::default(), - ); - - let (service, sender) = TransportService::new( - peer, - ProtocolName::from("/notif/1"), - Vec::new(), - Arc::new(AtomicUsize::new(0usize)), - handle, - ); - - (service, sender, cmd_rx) - } - - #[tokio::test] - async fn secondary_connection_stored() { - let (mut service, sender, _) = transport_service(); - let peer = PeerId::random(); - - // register first connection - let (cmd_tx1, _cmd_rx1) = channel(64); - sender - .send(InnerTransportEvent::ConnectionEstablished { - peer, - connection: ConnectionId::from(0usize), - endpoint: Endpoint::listener(Multiaddr::empty(), ConnectionId::from(0usize)), - sender: ConnectionHandle::new(ConnectionId::from(0usize), cmd_tx1), - }) - .await - .unwrap(); - - if let Some(TransportEvent::ConnectionEstablished { peer: connected_peer, endpoint }) = - service.next().await - { - assert_eq!(connected_peer, peer); - assert_eq!(endpoint.address(), &Multiaddr::empty()); - } else { - panic!("expected event from `TransportService`"); - }; - - // register secondary connection - let (cmd_tx2, _cmd_rx2) = channel(64); - sender - .send(InnerTransportEvent::ConnectionEstablished { - peer, - connection: ConnectionId::from(1usize), - endpoint: Endpoint::listener(Multiaddr::empty(), ConnectionId::from(1usize)), - sender: ConnectionHandle::new(ConnectionId::from(1usize), cmd_tx2), - }) - .await - .unwrap(); - - futures::future::poll_fn(|cx| match service.poll_next_unpin(cx) { - std::task::Poll::Ready(_) => panic!("didn't expect event from `TransportService`"), - std::task::Poll::Pending => std::task::Poll::Ready(()), - }) - .await; - - let context = service.connections.get(&peer).unwrap(); - assert_eq!(context.primary.connection_id(), &ConnectionId::from(0usize)); - assert_eq!( - context.secondary.as_ref().unwrap().connection_id(), - &ConnectionId::from(1usize) - ); - } - - #[tokio::test] - async fn tertiary_connection_ignored() { - let (mut service, sender, _) = transport_service(); - let peer = PeerId::random(); - - // register first connection - let (cmd_tx1, _cmd_rx1) = channel(64); - sender - .send(InnerTransportEvent::ConnectionEstablished { - peer, - connection: ConnectionId::from(0usize), - endpoint: Endpoint::dialer(Multiaddr::empty(), ConnectionId::from(0usize)), - sender: ConnectionHandle::new(ConnectionId::from(0usize), cmd_tx1), - }) - .await - .unwrap(); - - if let Some(TransportEvent::ConnectionEstablished { peer: connected_peer, endpoint }) = - service.next().await - { - assert_eq!(connected_peer, peer); - assert_eq!(endpoint.address(), &Multiaddr::empty()); - } else { - panic!("expected event from `TransportService`"); - }; - - // register secondary connection - let (cmd_tx2, _cmd_rx2) = channel(64); - sender - .send(InnerTransportEvent::ConnectionEstablished { - peer, - connection: ConnectionId::from(1usize), - endpoint: Endpoint::dialer(Multiaddr::empty(), ConnectionId::from(1usize)), - sender: ConnectionHandle::new(ConnectionId::from(1usize), cmd_tx2), - }) - .await - .unwrap(); - - futures::future::poll_fn(|cx| match service.poll_next_unpin(cx) { - std::task::Poll::Ready(_) => panic!("didn't expect event from `TransportService`"), - std::task::Poll::Pending => std::task::Poll::Ready(()), - }) - .await; - - let context = service.connections.get(&peer).unwrap(); - assert_eq!(context.primary.connection_id(), &ConnectionId::from(0usize)); - assert_eq!( - context.secondary.as_ref().unwrap().connection_id(), - &ConnectionId::from(1usize) - ); - - // try to register tertiary connection and verify it's ignored - let (cmd_tx3, mut cmd_rx3) = channel(64); - sender - .send(InnerTransportEvent::ConnectionEstablished { - peer, - connection: ConnectionId::from(2usize), - endpoint: Endpoint::listener(Multiaddr::empty(), ConnectionId::from(2usize)), - sender: ConnectionHandle::new(ConnectionId::from(2usize), cmd_tx3), - }) - .await - .unwrap(); - - futures::future::poll_fn(|cx| match service.poll_next_unpin(cx) { - std::task::Poll::Ready(_) => panic!("didn't expect event from `TransportService`"), - std::task::Poll::Pending => std::task::Poll::Ready(()), - }) - .await; - - let context = service.connections.get(&peer).unwrap(); - assert_eq!(context.primary.connection_id(), &ConnectionId::from(0usize)); - assert_eq!( - context.secondary.as_ref().unwrap().connection_id(), - &ConnectionId::from(1usize) - ); - assert!(cmd_rx3.try_recv().is_err()); - } - - #[tokio::test] - async fn secondary_closing_doesnt_emit_event() { - let (mut service, sender, _) = transport_service(); - let peer = PeerId::random(); - - // register first connection - let (cmd_tx1, _cmd_rx1) = channel(64); - sender - .send(InnerTransportEvent::ConnectionEstablished { - peer, - connection: ConnectionId::from(0usize), - endpoint: Endpoint::dialer(Multiaddr::empty(), ConnectionId::from(0usize)), - sender: ConnectionHandle::new(ConnectionId::from(0usize), cmd_tx1), - }) - .await - .unwrap(); - - if let Some(TransportEvent::ConnectionEstablished { peer: connected_peer, endpoint }) = - service.next().await - { - assert_eq!(connected_peer, peer); - assert_eq!(endpoint.address(), &Multiaddr::empty()); - } else { - panic!("expected event from `TransportService`"); - }; - - // register secondary connection - let (cmd_tx2, _cmd_rx2) = channel(64); - sender - .send(InnerTransportEvent::ConnectionEstablished { - peer, - connection: ConnectionId::from(1usize), - endpoint: Endpoint::dialer(Multiaddr::empty(), ConnectionId::from(1usize)), - sender: ConnectionHandle::new(ConnectionId::from(1usize), cmd_tx2), - }) - .await - .unwrap(); - - futures::future::poll_fn(|cx| match service.poll_next_unpin(cx) { - std::task::Poll::Ready(_) => panic!("didn't expect event from `TransportService`"), - std::task::Poll::Pending => std::task::Poll::Ready(()), - }) - .await; - - let context = service.connections.get(&peer).unwrap(); - assert_eq!(context.primary.connection_id(), &ConnectionId::from(0usize)); - assert_eq!( - context.secondary.as_ref().unwrap().connection_id(), - &ConnectionId::from(1usize) - ); - - // close the secondary connection - sender - .send(InnerTransportEvent::ConnectionClosed { - peer, - connection: ConnectionId::from(1usize), - }) - .await - .unwrap(); - - // verify that the protocol is not notified - futures::future::poll_fn(|cx| match service.poll_next_unpin(cx) { - std::task::Poll::Ready(_) => panic!("didn't expect event from `TransportService`"), - std::task::Poll::Pending => std::task::Poll::Ready(()), - }) - .await; - - // verify that the secondary connection doesn't exist anymore - let context = service.connections.get(&peer).unwrap(); - assert_eq!(context.primary.connection_id(), &ConnectionId::from(0usize)); - assert!(context.secondary.is_none()); - } - - #[tokio::test] - async fn convert_secondary_to_primary() { - let (mut service, sender, _) = transport_service(); - let peer = PeerId::random(); - - // register first connection - let (cmd_tx1, mut cmd_rx1) = channel(64); - sender - .send(InnerTransportEvent::ConnectionEstablished { - peer, - connection: ConnectionId::from(0usize), - endpoint: Endpoint::dialer(Multiaddr::empty(), ConnectionId::from(0usize)), - sender: ConnectionHandle::new(ConnectionId::from(0usize), cmd_tx1), - }) - .await - .unwrap(); - - if let Some(TransportEvent::ConnectionEstablished { peer: connected_peer, endpoint }) = - service.next().await - { - assert_eq!(connected_peer, peer); - assert_eq!(endpoint.address(), &Multiaddr::empty()); - } else { - panic!("expected event from `TransportService`"); - }; - - // register secondary connection - let (cmd_tx2, mut cmd_rx2) = channel(64); - sender - .send(InnerTransportEvent::ConnectionEstablished { - peer, - connection: ConnectionId::from(1usize), - endpoint: Endpoint::listener(Multiaddr::empty(), ConnectionId::from(1usize)), - sender: ConnectionHandle::new(ConnectionId::from(1usize), cmd_tx2), - }) - .await - .unwrap(); - - futures::future::poll_fn(|cx| match service.poll_next_unpin(cx) { - std::task::Poll::Ready(_) => panic!("didn't expect event from `TransportService`"), - std::task::Poll::Pending => std::task::Poll::Ready(()), - }) - .await; - - let context = service.connections.get(&peer).unwrap(); - assert_eq!(context.primary.connection_id(), &ConnectionId::from(0usize)); - assert_eq!( - context.secondary.as_ref().unwrap().connection_id(), - &ConnectionId::from(1usize) - ); - - // close the primary connection - sender - .send(InnerTransportEvent::ConnectionClosed { - peer, - connection: ConnectionId::from(0usize), - }) - .await - .unwrap(); - - // verify that the protocol is not notified - futures::future::poll_fn(|cx| match service.poll_next_unpin(cx) { - std::task::Poll::Ready(_) => panic!("didn't expect event from `TransportService`"), - std::task::Poll::Pending => std::task::Poll::Ready(()), - }) - .await; - - // verify that the primary connection has been replaced - let context = service.connections.get(&peer).unwrap(); - assert_eq!(context.primary.connection_id(), &ConnectionId::from(1usize)); - assert!(context.secondary.is_none()); - assert!(cmd_rx1.try_recv().is_err()); - - // close the secondary connection as well - sender - .send(InnerTransportEvent::ConnectionClosed { - peer, - connection: ConnectionId::from(1usize), - }) - .await - .unwrap(); - - if let Some(TransportEvent::ConnectionClosed { peer: disconnected_peer }) = - service.next().await - { - assert_eq!(disconnected_peer, peer); - } else { - panic!("expected event from `TransportService`"); - }; - - // verify that the primary connection has been replaced - assert!(service.connections.get(&peer).is_none()); - assert!(cmd_rx2.try_recv().is_err()); - } - - #[tokio::test] - async fn keep_alive_timeout_expires_for_a_stale_connection() { - let (mut service, sender, _) = transport_service(); - let peer = PeerId::random(); - - // register first connection - let (cmd_tx1, _cmd_rx1) = channel(64); - sender - .send(InnerTransportEvent::ConnectionEstablished { - peer, - connection: ConnectionId::from(1337usize), - endpoint: Endpoint::dialer(Multiaddr::empty(), ConnectionId::from(1337usize)), - sender: ConnectionHandle::new(ConnectionId::from(1337usize), cmd_tx1), - }) - .await - .unwrap(); - - if let Some(TransportEvent::ConnectionEstablished { peer: connected_peer, endpoint }) = - service.next().await - { - assert_eq!(connected_peer, peer); - assert_eq!(endpoint.address(), &Multiaddr::empty()); - } else { - panic!("expected event from `TransportService`"); - }; - - // verify the first connection state is correct - assert_eq!(service.keep_alive_timeouts.len(), 1); - match service.connections.get(&peer) { - Some(context) => { - assert_eq!(context.primary.connection_id(), &ConnectionId::from(1337usize)); - assert!(context.secondary.is_none()); - }, - None => panic!("expected {peer} to exist"), - } - - // close the primary connection - sender - .send(InnerTransportEvent::ConnectionClosed { - peer, - connection: ConnectionId::from(1337usize), - }) - .await - .unwrap(); - - // verify that the protocols are notified of the connection closing as well - if let Some(TransportEvent::ConnectionClosed { peer: connected_peer }) = - service.next().await - { - assert_eq!(connected_peer, peer); - } else { - panic!("expected event from `TransportService`"); - } - - // verify that the keep-alive timeout still exists for the peer but the peer itself - // doesn't exist anymore - // - // the peer is removed because there is no connection to them - assert_eq!(service.keep_alive_timeouts.len(), 1); - assert!(service.connections.get(&peer).is_none()); - - // register new primary connection but verify that there are now two pending keep-alive - // timeouts - let (cmd_tx1, _cmd_rx1) = channel(64); - sender - .send(InnerTransportEvent::ConnectionEstablished { - peer, - connection: ConnectionId::from(1338usize), - endpoint: Endpoint::listener(Multiaddr::empty(), ConnectionId::from(1338usize)), - sender: ConnectionHandle::new(ConnectionId::from(1338usize), cmd_tx1), - }) - .await - .unwrap(); - - if let Some(TransportEvent::ConnectionEstablished { peer: connected_peer, endpoint }) = - service.next().await - { - assert_eq!(connected_peer, peer); - assert_eq!(endpoint.address(), &Multiaddr::empty()); - } else { - panic!("expected event from `TransportService`"); - }; - - // verify the first connection state is correct - assert_eq!(service.keep_alive_timeouts.len(), 2); - match service.connections.get(&peer) { - Some(context) => { - assert_eq!(context.primary.connection_id(), &ConnectionId::from(1338usize)); - assert!(context.secondary.is_none()); - }, - None => panic!("expected {peer} to exist"), - } - - match tokio::time::timeout(Duration::from_secs(10), service.next()).await { - Ok(event) => panic!("didn't expect an event: {event:?}"), - Err(_) => {}, - } - } + use super::*; + use crate::{ + protocol::TransportService, + transport::manager::{handle::InnerTransportManagerCommand, TransportManagerHandle}, + }; + use futures::StreamExt; + use parking_lot::RwLock; + use std::collections::HashSet; + + /// Create new `TransportService` + fn transport_service() -> ( + TransportService, + Sender, + Receiver, + ) { + let (cmd_tx, cmd_rx) = channel(64); + let peer = PeerId::random(); + + let handle = TransportManagerHandle::new( + peer, + Arc::new(RwLock::new(HashMap::new())), + cmd_tx, + HashSet::new(), + Default::default(), + ); + + let (service, sender) = TransportService::new( + peer, + ProtocolName::from("/notif/1"), + Vec::new(), + Arc::new(AtomicUsize::new(0usize)), + handle, + ); + + (service, sender, cmd_rx) + } + + #[tokio::test] + async fn secondary_connection_stored() { + let (mut service, sender, _) = transport_service(); + let peer = PeerId::random(); + + // register first connection + let (cmd_tx1, _cmd_rx1) = channel(64); + sender + .send(InnerTransportEvent::ConnectionEstablished { + peer, + connection: ConnectionId::from(0usize), + endpoint: Endpoint::listener(Multiaddr::empty(), ConnectionId::from(0usize)), + sender: ConnectionHandle::new(ConnectionId::from(0usize), cmd_tx1), + }) + .await + .unwrap(); + + if let Some(TransportEvent::ConnectionEstablished { + peer: connected_peer, + endpoint, + }) = service.next().await + { + assert_eq!(connected_peer, peer); + assert_eq!(endpoint.address(), &Multiaddr::empty()); + } else { + panic!("expected event from `TransportService`"); + }; + + // register secondary connection + let (cmd_tx2, _cmd_rx2) = channel(64); + sender + .send(InnerTransportEvent::ConnectionEstablished { + peer, + connection: ConnectionId::from(1usize), + endpoint: Endpoint::listener(Multiaddr::empty(), ConnectionId::from(1usize)), + sender: ConnectionHandle::new(ConnectionId::from(1usize), cmd_tx2), + }) + .await + .unwrap(); + + futures::future::poll_fn(|cx| match service.poll_next_unpin(cx) { + std::task::Poll::Ready(_) => panic!("didn't expect event from `TransportService`"), + std::task::Poll::Pending => std::task::Poll::Ready(()), + }) + .await; + + let context = service.connections.get(&peer).unwrap(); + assert_eq!(context.primary.connection_id(), &ConnectionId::from(0usize)); + assert_eq!( + context.secondary.as_ref().unwrap().connection_id(), + &ConnectionId::from(1usize) + ); + } + + #[tokio::test] + async fn tertiary_connection_ignored() { + let (mut service, sender, _) = transport_service(); + let peer = PeerId::random(); + + // register first connection + let (cmd_tx1, _cmd_rx1) = channel(64); + sender + .send(InnerTransportEvent::ConnectionEstablished { + peer, + connection: ConnectionId::from(0usize), + endpoint: Endpoint::dialer(Multiaddr::empty(), ConnectionId::from(0usize)), + sender: ConnectionHandle::new(ConnectionId::from(0usize), cmd_tx1), + }) + .await + .unwrap(); + + if let Some(TransportEvent::ConnectionEstablished { + peer: connected_peer, + endpoint, + }) = service.next().await + { + assert_eq!(connected_peer, peer); + assert_eq!(endpoint.address(), &Multiaddr::empty()); + } else { + panic!("expected event from `TransportService`"); + }; + + // register secondary connection + let (cmd_tx2, _cmd_rx2) = channel(64); + sender + .send(InnerTransportEvent::ConnectionEstablished { + peer, + connection: ConnectionId::from(1usize), + endpoint: Endpoint::dialer(Multiaddr::empty(), ConnectionId::from(1usize)), + sender: ConnectionHandle::new(ConnectionId::from(1usize), cmd_tx2), + }) + .await + .unwrap(); + + futures::future::poll_fn(|cx| match service.poll_next_unpin(cx) { + std::task::Poll::Ready(_) => panic!("didn't expect event from `TransportService`"), + std::task::Poll::Pending => std::task::Poll::Ready(()), + }) + .await; + + let context = service.connections.get(&peer).unwrap(); + assert_eq!(context.primary.connection_id(), &ConnectionId::from(0usize)); + assert_eq!( + context.secondary.as_ref().unwrap().connection_id(), + &ConnectionId::from(1usize) + ); + + // try to register tertiary connection and verify it's ignored + let (cmd_tx3, mut cmd_rx3) = channel(64); + sender + .send(InnerTransportEvent::ConnectionEstablished { + peer, + connection: ConnectionId::from(2usize), + endpoint: Endpoint::listener(Multiaddr::empty(), ConnectionId::from(2usize)), + sender: ConnectionHandle::new(ConnectionId::from(2usize), cmd_tx3), + }) + .await + .unwrap(); + + futures::future::poll_fn(|cx| match service.poll_next_unpin(cx) { + std::task::Poll::Ready(_) => panic!("didn't expect event from `TransportService`"), + std::task::Poll::Pending => std::task::Poll::Ready(()), + }) + .await; + + let context = service.connections.get(&peer).unwrap(); + assert_eq!(context.primary.connection_id(), &ConnectionId::from(0usize)); + assert_eq!( + context.secondary.as_ref().unwrap().connection_id(), + &ConnectionId::from(1usize) + ); + assert!(cmd_rx3.try_recv().is_err()); + } + + #[tokio::test] + async fn secondary_closing_doesnt_emit_event() { + let (mut service, sender, _) = transport_service(); + let peer = PeerId::random(); + + // register first connection + let (cmd_tx1, _cmd_rx1) = channel(64); + sender + .send(InnerTransportEvent::ConnectionEstablished { + peer, + connection: ConnectionId::from(0usize), + endpoint: Endpoint::dialer(Multiaddr::empty(), ConnectionId::from(0usize)), + sender: ConnectionHandle::new(ConnectionId::from(0usize), cmd_tx1), + }) + .await + .unwrap(); + + if let Some(TransportEvent::ConnectionEstablished { + peer: connected_peer, + endpoint, + }) = service.next().await + { + assert_eq!(connected_peer, peer); + assert_eq!(endpoint.address(), &Multiaddr::empty()); + } else { + panic!("expected event from `TransportService`"); + }; + + // register secondary connection + let (cmd_tx2, _cmd_rx2) = channel(64); + sender + .send(InnerTransportEvent::ConnectionEstablished { + peer, + connection: ConnectionId::from(1usize), + endpoint: Endpoint::dialer(Multiaddr::empty(), ConnectionId::from(1usize)), + sender: ConnectionHandle::new(ConnectionId::from(1usize), cmd_tx2), + }) + .await + .unwrap(); + + futures::future::poll_fn(|cx| match service.poll_next_unpin(cx) { + std::task::Poll::Ready(_) => panic!("didn't expect event from `TransportService`"), + std::task::Poll::Pending => std::task::Poll::Ready(()), + }) + .await; + + let context = service.connections.get(&peer).unwrap(); + assert_eq!(context.primary.connection_id(), &ConnectionId::from(0usize)); + assert_eq!( + context.secondary.as_ref().unwrap().connection_id(), + &ConnectionId::from(1usize) + ); + + // close the secondary connection + sender + .send(InnerTransportEvent::ConnectionClosed { + peer, + connection: ConnectionId::from(1usize), + }) + .await + .unwrap(); + + // verify that the protocol is not notified + futures::future::poll_fn(|cx| match service.poll_next_unpin(cx) { + std::task::Poll::Ready(_) => panic!("didn't expect event from `TransportService`"), + std::task::Poll::Pending => std::task::Poll::Ready(()), + }) + .await; + + // verify that the secondary connection doesn't exist anymore + let context = service.connections.get(&peer).unwrap(); + assert_eq!(context.primary.connection_id(), &ConnectionId::from(0usize)); + assert!(context.secondary.is_none()); + } + + #[tokio::test] + async fn convert_secondary_to_primary() { + let (mut service, sender, _) = transport_service(); + let peer = PeerId::random(); + + // register first connection + let (cmd_tx1, mut cmd_rx1) = channel(64); + sender + .send(InnerTransportEvent::ConnectionEstablished { + peer, + connection: ConnectionId::from(0usize), + endpoint: Endpoint::dialer(Multiaddr::empty(), ConnectionId::from(0usize)), + sender: ConnectionHandle::new(ConnectionId::from(0usize), cmd_tx1), + }) + .await + .unwrap(); + + if let Some(TransportEvent::ConnectionEstablished { + peer: connected_peer, + endpoint, + }) = service.next().await + { + assert_eq!(connected_peer, peer); + assert_eq!(endpoint.address(), &Multiaddr::empty()); + } else { + panic!("expected event from `TransportService`"); + }; + + // register secondary connection + let (cmd_tx2, mut cmd_rx2) = channel(64); + sender + .send(InnerTransportEvent::ConnectionEstablished { + peer, + connection: ConnectionId::from(1usize), + endpoint: Endpoint::listener(Multiaddr::empty(), ConnectionId::from(1usize)), + sender: ConnectionHandle::new(ConnectionId::from(1usize), cmd_tx2), + }) + .await + .unwrap(); + + futures::future::poll_fn(|cx| match service.poll_next_unpin(cx) { + std::task::Poll::Ready(_) => panic!("didn't expect event from `TransportService`"), + std::task::Poll::Pending => std::task::Poll::Ready(()), + }) + .await; + + let context = service.connections.get(&peer).unwrap(); + assert_eq!(context.primary.connection_id(), &ConnectionId::from(0usize)); + assert_eq!( + context.secondary.as_ref().unwrap().connection_id(), + &ConnectionId::from(1usize) + ); + + // close the primary connection + sender + .send(InnerTransportEvent::ConnectionClosed { + peer, + connection: ConnectionId::from(0usize), + }) + .await + .unwrap(); + + // verify that the protocol is not notified + futures::future::poll_fn(|cx| match service.poll_next_unpin(cx) { + std::task::Poll::Ready(_) => panic!("didn't expect event from `TransportService`"), + std::task::Poll::Pending => std::task::Poll::Ready(()), + }) + .await; + + // verify that the primary connection has been replaced + let context = service.connections.get(&peer).unwrap(); + assert_eq!(context.primary.connection_id(), &ConnectionId::from(1usize)); + assert!(context.secondary.is_none()); + assert!(cmd_rx1.try_recv().is_err()); + + // close the secondary connection as well + sender + .send(InnerTransportEvent::ConnectionClosed { + peer, + connection: ConnectionId::from(1usize), + }) + .await + .unwrap(); + + if let Some(TransportEvent::ConnectionClosed { + peer: disconnected_peer, + }) = service.next().await + { + assert_eq!(disconnected_peer, peer); + } else { + panic!("expected event from `TransportService`"); + }; + + // verify that the primary connection has been replaced + assert!(service.connections.get(&peer).is_none()); + assert!(cmd_rx2.try_recv().is_err()); + } + + #[tokio::test] + async fn keep_alive_timeout_expires_for_a_stale_connection() { + let (mut service, sender, _) = transport_service(); + let peer = PeerId::random(); + + // register first connection + let (cmd_tx1, _cmd_rx1) = channel(64); + sender + .send(InnerTransportEvent::ConnectionEstablished { + peer, + connection: ConnectionId::from(1337usize), + endpoint: Endpoint::dialer(Multiaddr::empty(), ConnectionId::from(1337usize)), + sender: ConnectionHandle::new(ConnectionId::from(1337usize), cmd_tx1), + }) + .await + .unwrap(); + + if let Some(TransportEvent::ConnectionEstablished { + peer: connected_peer, + endpoint, + }) = service.next().await + { + assert_eq!(connected_peer, peer); + assert_eq!(endpoint.address(), &Multiaddr::empty()); + } else { + panic!("expected event from `TransportService`"); + }; + + // verify the first connection state is correct + assert_eq!(service.keep_alive_timeouts.len(), 1); + match service.connections.get(&peer) { + Some(context) => { + assert_eq!( + context.primary.connection_id(), + &ConnectionId::from(1337usize) + ); + assert!(context.secondary.is_none()); + } + None => panic!("expected {peer} to exist"), + } + + // close the primary connection + sender + .send(InnerTransportEvent::ConnectionClosed { + peer, + connection: ConnectionId::from(1337usize), + }) + .await + .unwrap(); + + // verify that the protocols are notified of the connection closing as well + if let Some(TransportEvent::ConnectionClosed { + peer: connected_peer, + }) = service.next().await + { + assert_eq!(connected_peer, peer); + } else { + panic!("expected event from `TransportService`"); + } + + // verify that the keep-alive timeout still exists for the peer but the peer itself + // doesn't exist anymore + // + // the peer is removed because there is no connection to them + assert_eq!(service.keep_alive_timeouts.len(), 1); + assert!(service.connections.get(&peer).is_none()); + + // register new primary connection but verify that there are now two pending keep-alive + // timeouts + let (cmd_tx1, _cmd_rx1) = channel(64); + sender + .send(InnerTransportEvent::ConnectionEstablished { + peer, + connection: ConnectionId::from(1338usize), + endpoint: Endpoint::listener(Multiaddr::empty(), ConnectionId::from(1338usize)), + sender: ConnectionHandle::new(ConnectionId::from(1338usize), cmd_tx1), + }) + .await + .unwrap(); + + if let Some(TransportEvent::ConnectionEstablished { + peer: connected_peer, + endpoint, + }) = service.next().await + { + assert_eq!(connected_peer, peer); + assert_eq!(endpoint.address(), &Multiaddr::empty()); + } else { + panic!("expected event from `TransportService`"); + }; + + // verify the first connection state is correct + assert_eq!(service.keep_alive_timeouts.len(), 2); + match service.connections.get(&peer) { + Some(context) => { + assert_eq!( + context.primary.connection_id(), + &ConnectionId::from(1338usize) + ); + assert!(context.secondary.is_none()); + } + None => panic!("expected {peer} to exist"), + } + + match tokio::time::timeout(Duration::from_secs(10), service.next()).await { + Ok(event) => panic!("didn't expect an event: {event:?}"), + Err(_) => {} + } + } } diff --git a/src/substream/mod.rs b/src/substream/mod.rs index 872ef025..514b0c02 100644 --- a/src/substream/mod.rs +++ b/src/substream/mod.rs @@ -22,11 +22,11 @@ //! Substream-related helper code. use crate::{ - codec::ProtocolCodec, - error::{Error, SubstreamError}, - transport::{quic, tcp, websocket}, - types::SubstreamId, - PeerId, + codec::ProtocolCodec, + error::{Error, SubstreamError}, + transport::{quic, tcp, websocket}, + types::SubstreamId, + PeerId, }; use bytes::{Buf, Bytes, BytesMut}; @@ -35,133 +35,133 @@ use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf}; use unsigned_varint::{decode, encode}; use std::{ - collections::{hash_map::Entry, HashMap, VecDeque}, - fmt, - hash::Hash, - io::ErrorKind, - pin::Pin, - task::{Context, Poll}, + collections::{hash_map::Entry, HashMap, VecDeque}, + fmt, + hash::Hash, + io::ErrorKind, + pin::Pin, + task::{Context, Poll}, }; /// Logging target for the file. const LOG_TARGET: &str = "substream"; macro_rules! poll_flush { - ($substream:expr, $cx:ident) => {{ - match $substream { - SubstreamType::Tcp(substream) => Pin::new(substream).poll_flush($cx), - SubstreamType::WebSocket(substream) => Pin::new(substream).poll_flush($cx), - SubstreamType::Quic(substream) => Pin::new(substream).poll_flush($cx), - #[cfg(test)] - SubstreamType::Mock(_) => unreachable!(), - } - }}; + ($substream:expr, $cx:ident) => {{ + match $substream { + SubstreamType::Tcp(substream) => Pin::new(substream).poll_flush($cx), + SubstreamType::WebSocket(substream) => Pin::new(substream).poll_flush($cx), + SubstreamType::Quic(substream) => Pin::new(substream).poll_flush($cx), + #[cfg(test)] + SubstreamType::Mock(_) => unreachable!(), + } + }}; } macro_rules! poll_write { - ($substream:expr, $cx:ident, $frame:expr) => {{ - match $substream { - SubstreamType::Tcp(substream) => Pin::new(substream).poll_write($cx, $frame), - SubstreamType::WebSocket(substream) => Pin::new(substream).poll_write($cx, $frame), - SubstreamType::Quic(substream) => Pin::new(substream).poll_write($cx, $frame), - #[cfg(test)] - SubstreamType::Mock(_) => unreachable!(), - } - }}; + ($substream:expr, $cx:ident, $frame:expr) => {{ + match $substream { + SubstreamType::Tcp(substream) => Pin::new(substream).poll_write($cx, $frame), + SubstreamType::WebSocket(substream) => Pin::new(substream).poll_write($cx, $frame), + SubstreamType::Quic(substream) => Pin::new(substream).poll_write($cx, $frame), + #[cfg(test)] + SubstreamType::Mock(_) => unreachable!(), + } + }}; } macro_rules! poll_read { - ($substream:expr, $cx:ident, $buffer:expr) => {{ - match $substream { - SubstreamType::Tcp(substream) => Pin::new(substream).poll_read($cx, $buffer), - SubstreamType::WebSocket(substream) => Pin::new(substream).poll_read($cx, $buffer), - SubstreamType::Quic(substream) => Pin::new(substream).poll_read($cx, $buffer), - #[cfg(test)] - SubstreamType::Mock(_) => unreachable!(), - } - }}; + ($substream:expr, $cx:ident, $buffer:expr) => {{ + match $substream { + SubstreamType::Tcp(substream) => Pin::new(substream).poll_read($cx, $buffer), + SubstreamType::WebSocket(substream) => Pin::new(substream).poll_read($cx, $buffer), + SubstreamType::Quic(substream) => Pin::new(substream).poll_read($cx, $buffer), + #[cfg(test)] + SubstreamType::Mock(_) => unreachable!(), + } + }}; } macro_rules! poll_shutdown { - ($substream:expr, $cx:ident) => {{ - match $substream { - SubstreamType::Tcp(substream) => Pin::new(substream).poll_shutdown($cx), - SubstreamType::WebSocket(substream) => Pin::new(substream).poll_shutdown($cx), - SubstreamType::Quic(substream) => Pin::new(substream).poll_shutdown($cx), - #[cfg(test)] - SubstreamType::Mock(substream) => { - let _ = Pin::new(substream).poll_close($cx); - todo!(); - }, - } - }}; + ($substream:expr, $cx:ident) => {{ + match $substream { + SubstreamType::Tcp(substream) => Pin::new(substream).poll_shutdown($cx), + SubstreamType::WebSocket(substream) => Pin::new(substream).poll_shutdown($cx), + SubstreamType::Quic(substream) => Pin::new(substream).poll_shutdown($cx), + #[cfg(test)] + SubstreamType::Mock(substream) => { + let _ = Pin::new(substream).poll_close($cx); + todo!(); + } + } + }}; } macro_rules! delegate_poll_next { - ($substream:expr, $cx:ident) => {{ - #[cfg(test)] - if let SubstreamType::Mock(inner) = $substream { - return Pin::new(inner).poll_next($cx); - } - }}; + ($substream:expr, $cx:ident) => {{ + #[cfg(test)] + if let SubstreamType::Mock(inner) = $substream { + return Pin::new(inner).poll_next($cx); + } + }}; } macro_rules! delegate_poll_ready { - ($substream:expr, $cx:ident) => {{ - #[cfg(test)] - if let SubstreamType::Mock(inner) = $substream { - return Pin::new(inner).poll_ready($cx); - } - }}; + ($substream:expr, $cx:ident) => {{ + #[cfg(test)] + if let SubstreamType::Mock(inner) = $substream { + return Pin::new(inner).poll_ready($cx); + } + }}; } macro_rules! delegate_start_send { - ($substream:expr, $item:ident) => {{ - #[cfg(test)] - if let SubstreamType::Mock(inner) = $substream { - return Pin::new(inner).start_send($item); - } - }}; + ($substream:expr, $item:ident) => {{ + #[cfg(test)] + if let SubstreamType::Mock(inner) = $substream { + return Pin::new(inner).start_send($item); + } + }}; } macro_rules! delegate_poll_flush { - ($substream:expr, $cx:ident) => {{ - #[cfg(test)] - if let SubstreamType::Mock(inner) = $substream { - return Pin::new(inner).poll_flush($cx); - } - }}; + ($substream:expr, $cx:ident) => {{ + #[cfg(test)] + if let SubstreamType::Mock(inner) = $substream { + return Pin::new(inner).poll_flush($cx); + } + }}; } macro_rules! check_size { - ($max_size:expr, $size:expr) => {{ - if let Some(max_size) = $max_size { - if $size > max_size { - return Err(Error::IoError(ErrorKind::PermissionDenied)); - } - } - }}; + ($max_size:expr, $size:expr) => {{ + if let Some(max_size) = $max_size { + if $size > max_size { + return Err(Error::IoError(ErrorKind::PermissionDenied)); + } + } + }}; } /// Substream type. enum SubstreamType { - Tcp(tcp::Substream), - WebSocket(websocket::Substream), - Quic(quic::Substream), - #[cfg(test)] - Mock(Box), + Tcp(tcp::Substream), + WebSocket(websocket::Substream), + Quic(quic::Substream), + #[cfg(test)] + Mock(Box), } impl fmt::Debug for SubstreamType { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Self::Tcp(_) => write!(f, "Tcp"), - Self::WebSocket(_) => write!(f, "WebSocket"), - Self::Quic(_) => write!(f, "Quic"), - #[cfg(test)] - Self::Mock(_) => write!(f, "Mock"), - } - } + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Tcp(_) => write!(f, "Tcp"), + Self::WebSocket(_) => write!(f, "WebSocket"), + Self::Quic(_) => write!(f, "Quic"), + #[cfg(test)] + Self::Mock(_) => write!(f, "Mock"), + } + } } /// Backpressure boundary for `Sink`. @@ -176,507 +176,518 @@ const BACKPRESSURE_BOUNDARY: usize = 65536; /// [`Sink::send()`](futures::Sink)/[`Stream::next()`](futures::Stream) are also provided which /// implement the necessary framing to read/write codec-encoded messages from the underlying socket. pub struct Substream { - /// Remote peer ID. - peer: PeerId, + /// Remote peer ID. + peer: PeerId, - // Inner substream. - substream: SubstreamType, + // Inner substream. + substream: SubstreamType, - /// Substream ID. - substream_id: SubstreamId, + /// Substream ID. + substream_id: SubstreamId, - /// Protocol codec. - codec: ProtocolCodec, + /// Protocol codec. + codec: ProtocolCodec, - pending_out_frames: VecDeque, - pending_out_bytes: usize, - pending_out_frame: Option, + pending_out_frames: VecDeque, + pending_out_bytes: usize, + pending_out_frame: Option, - read_buffer: BytesMut, - offset: usize, - pending_frames: VecDeque, - current_frame_size: Option, + read_buffer: BytesMut, + offset: usize, + pending_frames: VecDeque, + current_frame_size: Option, - size_vec: BytesMut, + size_vec: BytesMut, } impl fmt::Debug for Substream { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("Substream") - .field("peer", &self.peer) - .field("substream_id", &self.substream_id) - .field("codec", &self.codec) - .field("protocol", &self.substream) - .finish() - } + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Substream") + .field("peer", &self.peer) + .field("substream_id", &self.substream_id) + .field("codec", &self.codec) + .field("protocol", &self.substream) + .finish() + } } impl Substream { - /// Create new [`Substream`]. - fn new( - peer: PeerId, - substream_id: SubstreamId, - substream: SubstreamType, - codec: ProtocolCodec, - ) -> Self { - Self { - peer, - substream, - codec, - substream_id, - read_buffer: BytesMut::zeroed(1024), - offset: 0usize, - pending_frames: VecDeque::new(), - current_frame_size: None, - pending_out_bytes: 0usize, - pending_out_frames: VecDeque::new(), - pending_out_frame: None, - size_vec: BytesMut::zeroed(10), - } - } - - /// Create new [`Substream`] for TCP. - pub(crate) fn new_tcp( - peer: PeerId, - substream_id: SubstreamId, - substream: tcp::Substream, - codec: ProtocolCodec, - ) -> Self { - tracing::trace!(target: LOG_TARGET, ?peer, ?codec, "create new substream for tcp"); - - Self::new(peer, substream_id, SubstreamType::Tcp(substream), codec) - } - - /// Create new [`Substream`] for WebSocket. - pub(crate) fn new_websocket( - peer: PeerId, - substream_id: SubstreamId, - substream: websocket::Substream, - codec: ProtocolCodec, - ) -> Self { - tracing::trace!(target: LOG_TARGET, ?peer, ?codec, "create new substream for websocket"); - - Self::new(peer, substream_id, SubstreamType::WebSocket(substream), codec) - } - - /// Create new [`Substream`] for QUIC. - pub(crate) fn new_quic( - peer: PeerId, - substream_id: SubstreamId, - substream: quic::Substream, - codec: ProtocolCodec, - ) -> Self { - tracing::trace!(target: LOG_TARGET, ?peer, ?codec, "create new substream for quic"); - - Self::new(peer, substream_id, SubstreamType::Quic(substream), codec) - } - - /// Create new [`Substream`] for mocking. - #[cfg(test)] - pub(crate) fn new_mock( - peer: PeerId, - substream_id: SubstreamId, - substream: Box, - ) -> Self { - tracing::trace!(target: LOG_TARGET, ?peer, "create new substream for mocking"); - - Self::new(peer, substream_id, SubstreamType::Mock(substream), ProtocolCodec::Unspecified) - } - - /// Close the substream. - pub async fn close(self) { - let _ = match self.substream { - SubstreamType::Tcp(mut substream) => substream.shutdown().await, - SubstreamType::WebSocket(mut substream) => substream.shutdown().await, - SubstreamType::Quic(mut substream) => substream.shutdown().await, - #[cfg(test)] - SubstreamType::Mock(mut substream) => { - let _ = futures::SinkExt::close(&mut substream).await; - Ok(()) - }, - }; - } - - /// Send identity payload to remote peer. - async fn send_identity_payload( - io: &mut T, - payload_size: usize, - payload: Bytes, - ) -> crate::Result<()> { - if payload.len() != payload_size { - return Err(Error::IoError(ErrorKind::PermissionDenied)); - } - - io.write_all(&payload) - .await - .map_err(|_| Error::SubstreamError(SubstreamError::ConnectionClosed)) - } - - /// Send framed data to remote peer. - /// - /// This function may be faster than the provided [`futures::Sink`] implementation for - /// [`Substream`] as it has direct access to the API of the underlying socket as opposed - /// to going through [`tokio::io::AsyncWrite`]. - /// - /// # Cancel safety - /// - /// This method is not cancellation safe. If that is required, use the provided - /// [`futures::Sink`] implementation. - /// - /// # Panics - /// - /// Panics if no codec is provided. - pub async fn send_framed(&mut self, mut bytes: Bytes) -> crate::Result<()> { - tracing::trace!( - target: LOG_TARGET, - peer = ?self.peer, - codec = ?self.codec, - frame_len = ?bytes.len(), - "send framed" - ); - - match &mut self.substream { - #[cfg(test)] - SubstreamType::Mock(ref mut substream) => futures::SinkExt::send(substream, bytes).await, - SubstreamType::Tcp(ref mut substream) => match self.codec { - ProtocolCodec::Unspecified => panic!("codec is unspecified"), - ProtocolCodec::Identity(payload_size) => - Self::send_identity_payload(substream, payload_size, bytes).await, - ProtocolCodec::UnsignedVarint(max_size) => { - check_size!(max_size, bytes.len()); - - let mut buffer = [0u8; 10]; - let len = unsigned_varint::encode::usize(bytes.len(), &mut buffer); - let mut offset = 0; - - while offset < len.len() { - offset += substream.write(&len[offset..]).await?; - } - - while bytes.has_remaining() { - let nwritten = substream.write(&bytes).await?; - bytes.advance(nwritten); - } - - substream.flush().await.map_err(From::from) - }, - }, - SubstreamType::WebSocket(ref mut substream) => match self.codec { - ProtocolCodec::Unspecified => panic!("codec is unspecified"), - ProtocolCodec::Identity(payload_size) => - Self::send_identity_payload(substream, payload_size, bytes).await, - ProtocolCodec::UnsignedVarint(max_size) => { - check_size!(max_size, bytes.len()); - - let mut buffer = [0u8; 10]; - let len = unsigned_varint::encode::usize(bytes.len(), &mut buffer); - let mut offset = 0; - - while offset < len.len() { - offset += substream.write(&len[offset..]).await?; - } - - while bytes.has_remaining() { - let nwritten = substream.write(&bytes).await?; - bytes.advance(nwritten); - } - - substream.flush().await.map_err(From::from) - }, - }, - SubstreamType::Quic(ref mut substream) => match self.codec { - ProtocolCodec::Unspecified => panic!("codec is unspecified"), - ProtocolCodec::Identity(payload_size) => - Self::send_identity_payload(substream, payload_size, bytes).await, - ProtocolCodec::UnsignedVarint(max_size) => { - check_size!(max_size, bytes.len()); - - let mut buffer = [0u8; 10]; - let len = unsigned_varint::encode::usize(bytes.len(), &mut buffer); - let len = BytesMut::from(len); - - substream.write_all_chunks(&mut [len.freeze(), bytes]).await - }, - }, - } - } + /// Create new [`Substream`]. + fn new( + peer: PeerId, + substream_id: SubstreamId, + substream: SubstreamType, + codec: ProtocolCodec, + ) -> Self { + Self { + peer, + substream, + codec, + substream_id, + read_buffer: BytesMut::zeroed(1024), + offset: 0usize, + pending_frames: VecDeque::new(), + current_frame_size: None, + pending_out_bytes: 0usize, + pending_out_frames: VecDeque::new(), + pending_out_frame: None, + size_vec: BytesMut::zeroed(10), + } + } + + /// Create new [`Substream`] for TCP. + pub(crate) fn new_tcp( + peer: PeerId, + substream_id: SubstreamId, + substream: tcp::Substream, + codec: ProtocolCodec, + ) -> Self { + tracing::trace!(target: LOG_TARGET, ?peer, ?codec, "create new substream for tcp"); + + Self::new(peer, substream_id, SubstreamType::Tcp(substream), codec) + } + + /// Create new [`Substream`] for WebSocket. + pub(crate) fn new_websocket( + peer: PeerId, + substream_id: SubstreamId, + substream: websocket::Substream, + codec: ProtocolCodec, + ) -> Self { + tracing::trace!(target: LOG_TARGET, ?peer, ?codec, "create new substream for websocket"); + + Self::new( + peer, + substream_id, + SubstreamType::WebSocket(substream), + codec, + ) + } + + /// Create new [`Substream`] for QUIC. + pub(crate) fn new_quic( + peer: PeerId, + substream_id: SubstreamId, + substream: quic::Substream, + codec: ProtocolCodec, + ) -> Self { + tracing::trace!(target: LOG_TARGET, ?peer, ?codec, "create new substream for quic"); + + Self::new(peer, substream_id, SubstreamType::Quic(substream), codec) + } + + /// Create new [`Substream`] for mocking. + #[cfg(test)] + pub(crate) fn new_mock( + peer: PeerId, + substream_id: SubstreamId, + substream: Box, + ) -> Self { + tracing::trace!(target: LOG_TARGET, ?peer, "create new substream for mocking"); + + Self::new( + peer, + substream_id, + SubstreamType::Mock(substream), + ProtocolCodec::Unspecified, + ) + } + + /// Close the substream. + pub async fn close(self) { + let _ = match self.substream { + SubstreamType::Tcp(mut substream) => substream.shutdown().await, + SubstreamType::WebSocket(mut substream) => substream.shutdown().await, + SubstreamType::Quic(mut substream) => substream.shutdown().await, + #[cfg(test)] + SubstreamType::Mock(mut substream) => { + let _ = futures::SinkExt::close(&mut substream).await; + Ok(()) + } + }; + } + + /// Send identity payload to remote peer. + async fn send_identity_payload( + io: &mut T, + payload_size: usize, + payload: Bytes, + ) -> crate::Result<()> { + if payload.len() != payload_size { + return Err(Error::IoError(ErrorKind::PermissionDenied)); + } + + io.write_all(&payload) + .await + .map_err(|_| Error::SubstreamError(SubstreamError::ConnectionClosed)) + } + + /// Send framed data to remote peer. + /// + /// This function may be faster than the provided [`futures::Sink`] implementation for + /// [`Substream`] as it has direct access to the API of the underlying socket as opposed + /// to going through [`tokio::io::AsyncWrite`]. + /// + /// # Cancel safety + /// + /// This method is not cancellation safe. If that is required, use the provided + /// [`futures::Sink`] implementation. + /// + /// # Panics + /// + /// Panics if no codec is provided. + pub async fn send_framed(&mut self, mut bytes: Bytes) -> crate::Result<()> { + tracing::trace!( + target: LOG_TARGET, + peer = ?self.peer, + codec = ?self.codec, + frame_len = ?bytes.len(), + "send framed" + ); + + match &mut self.substream { + #[cfg(test)] + SubstreamType::Mock(ref mut substream) => + futures::SinkExt::send(substream, bytes).await, + SubstreamType::Tcp(ref mut substream) => match self.codec { + ProtocolCodec::Unspecified => panic!("codec is unspecified"), + ProtocolCodec::Identity(payload_size) => + Self::send_identity_payload(substream, payload_size, bytes).await, + ProtocolCodec::UnsignedVarint(max_size) => { + check_size!(max_size, bytes.len()); + + let mut buffer = [0u8; 10]; + let len = unsigned_varint::encode::usize(bytes.len(), &mut buffer); + let mut offset = 0; + + while offset < len.len() { + offset += substream.write(&len[offset..]).await?; + } + + while bytes.has_remaining() { + let nwritten = substream.write(&bytes).await?; + bytes.advance(nwritten); + } + + substream.flush().await.map_err(From::from) + } + }, + SubstreamType::WebSocket(ref mut substream) => match self.codec { + ProtocolCodec::Unspecified => panic!("codec is unspecified"), + ProtocolCodec::Identity(payload_size) => + Self::send_identity_payload(substream, payload_size, bytes).await, + ProtocolCodec::UnsignedVarint(max_size) => { + check_size!(max_size, bytes.len()); + + let mut buffer = [0u8; 10]; + let len = unsigned_varint::encode::usize(bytes.len(), &mut buffer); + let mut offset = 0; + + while offset < len.len() { + offset += substream.write(&len[offset..]).await?; + } + + while bytes.has_remaining() { + let nwritten = substream.write(&bytes).await?; + bytes.advance(nwritten); + } + + substream.flush().await.map_err(From::from) + } + }, + SubstreamType::Quic(ref mut substream) => match self.codec { + ProtocolCodec::Unspecified => panic!("codec is unspecified"), + ProtocolCodec::Identity(payload_size) => + Self::send_identity_payload(substream, payload_size, bytes).await, + ProtocolCodec::UnsignedVarint(max_size) => { + check_size!(max_size, bytes.len()); + + let mut buffer = [0u8; 10]; + let len = unsigned_varint::encode::usize(bytes.len(), &mut buffer); + let len = BytesMut::from(len); + + substream.write_all_chunks(&mut [len.freeze(), bytes]).await + } + }, + } + } } impl tokio::io::AsyncRead for Substream { - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut tokio::io::ReadBuf<'_>, - ) -> Poll> { - poll_read!(&mut self.substream, cx, buf) - } + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll> { + poll_read!(&mut self.substream, cx, buf) + } } impl tokio::io::AsyncWrite for Substream { - fn poll_write( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - poll_write!(&mut self.substream, cx, buf) - } - - fn poll_flush( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - poll_flush!(&mut self.substream, cx) - } - - fn poll_shutdown( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - poll_shutdown!(&mut self.substream, cx) - } + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + poll_write!(&mut self.substream, cx, buf) + } + + fn poll_flush( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + poll_flush!(&mut self.substream, cx) + } + + fn poll_shutdown( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + poll_shutdown!(&mut self.substream, cx) + } } enum ReadError { - Overflow, - NotEnoughBytes, - DecodeError, + Overflow, + NotEnoughBytes, + DecodeError, } // Return the payload size and the number of bytes it took to encode it fn read_payload_size(buffer: &[u8]) -> Result<(usize, usize), ReadError> { - let max_len = encode::usize_buffer().len(); - - for i in 0..std::cmp::min(buffer.len(), max_len) { - if decode::is_last(buffer[i]) { - match decode::usize(&buffer[..=i]) { - Err(_) => return Err(ReadError::DecodeError), - Ok(size) => return Ok((size.0, i + 1)), - } - } - } - - match buffer.len() < max_len { - true => Err(ReadError::NotEnoughBytes), - false => Err(ReadError::Overflow), - } + let max_len = encode::usize_buffer().len(); + + for i in 0..std::cmp::min(buffer.len(), max_len) { + if decode::is_last(buffer[i]) { + match decode::usize(&buffer[..=i]) { + Err(_) => return Err(ReadError::DecodeError), + Ok(size) => return Ok((size.0, i + 1)), + } + } + } + + match buffer.len() < max_len { + true => Err(ReadError::NotEnoughBytes), + false => Err(ReadError::Overflow), + } } impl Stream for Substream { - type Item = crate::Result; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = Pin::into_inner(self); - - // `MockSubstream` implements `Stream` so calls to `poll_next()` must be delegated - delegate_poll_next!(&mut this.substream, cx); - - loop { - match this.codec { - ProtocolCodec::Identity(payload_size) => { - let mut read_buf = - ReadBuf::new(&mut this.read_buffer[this.offset..payload_size]); - - match futures::ready!(poll_read!(&mut this.substream, cx, &mut read_buf)) { - Ok(_) => { - let nread = read_buf.filled().len(); - if nread == 0 { - tracing::trace!( - target: LOG_TARGET, - peer = ?this.peer, - "read zero bytes, substream closed" - ); - return Poll::Ready(None); - } - - if nread == payload_size { - let mut payload = std::mem::replace( - &mut this.read_buffer, - BytesMut::zeroed(payload_size), - ); - payload.truncate(payload_size); - this.offset = 0usize; - - return Poll::Ready(Some(Ok(payload))); - } else { - this.offset += read_buf.filled().len(); - } - }, - Err(error) => return Poll::Ready(Some(Err(error.into()))), - } - }, - ProtocolCodec::UnsignedVarint(max_size) => { - loop { - // return all pending frames first - if let Some(frame) = this.pending_frames.pop_front() { - return Poll::Ready(Some(Ok(frame))); - } - - match this.current_frame_size.take() { - Some(frame_size) => { - let mut read_buf = - ReadBuf::new(&mut this.read_buffer[this.offset..]); - this.current_frame_size = Some(frame_size); - - match futures::ready!(poll_read!( - &mut this.substream, - cx, - &mut read_buf - )) { - Err(_error) => return Poll::Ready(None), - Ok(_) => { - let nread = match read_buf.filled().len() { - 0 => return Poll::Ready(None), - nread => nread, - }; - - this.offset += nread; - - if this.offset == frame_size { - let out_frame = std::mem::replace( - &mut this.read_buffer, - BytesMut::new(), - ); - this.offset = 0; - this.current_frame_size = None; - - return Poll::Ready(Some(Ok(out_frame))); - } else { - this.current_frame_size = Some(frame_size); - continue; - } - }, - } - }, - None => { - let mut read_buf = - ReadBuf::new(&mut this.size_vec[this.offset..this.offset + 1]); - - match futures::ready!(poll_read!( - &mut this.substream, - cx, - &mut read_buf - )) { - Err(_error) => return Poll::Ready(None), - Ok(_) => { - if read_buf.filled().is_empty() { - return Poll::Ready(None); - } - this.offset += 1; - - match read_payload_size(&this.size_vec[..this.offset]) { - Err(ReadError::NotEnoughBytes) => continue, - Err(_) => - return Poll::Ready(Some(Err(Error::InvalidData))), - Ok((size, num_bytes)) => { - debug_assert_eq!(num_bytes, this.offset); - - if let Some(max_size) = max_size { - if size > max_size { - return Poll::Ready(Some(Err( - Error::InvalidData, - ))); - } - } - - this.offset = 0; - this.current_frame_size = Some(size); - this.read_buffer = BytesMut::zeroed(size); - }, - } - }, - } - }, - } - } - }, - ProtocolCodec::Unspecified => panic!("codec is unspecified"), - } - } - } + type Item = crate::Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = Pin::into_inner(self); + + // `MockSubstream` implements `Stream` so calls to `poll_next()` must be delegated + delegate_poll_next!(&mut this.substream, cx); + + loop { + match this.codec { + ProtocolCodec::Identity(payload_size) => { + let mut read_buf = + ReadBuf::new(&mut this.read_buffer[this.offset..payload_size]); + + match futures::ready!(poll_read!(&mut this.substream, cx, &mut read_buf)) { + Ok(_) => { + let nread = read_buf.filled().len(); + if nread == 0 { + tracing::trace!( + target: LOG_TARGET, + peer = ?this.peer, + "read zero bytes, substream closed" + ); + return Poll::Ready(None); + } + + if nread == payload_size { + let mut payload = std::mem::replace( + &mut this.read_buffer, + BytesMut::zeroed(payload_size), + ); + payload.truncate(payload_size); + this.offset = 0usize; + + return Poll::Ready(Some(Ok(payload))); + } else { + this.offset += read_buf.filled().len(); + } + } + Err(error) => return Poll::Ready(Some(Err(error.into()))), + } + } + ProtocolCodec::UnsignedVarint(max_size) => { + loop { + // return all pending frames first + if let Some(frame) = this.pending_frames.pop_front() { + return Poll::Ready(Some(Ok(frame))); + } + + match this.current_frame_size.take() { + Some(frame_size) => { + let mut read_buf = + ReadBuf::new(&mut this.read_buffer[this.offset..]); + this.current_frame_size = Some(frame_size); + + match futures::ready!(poll_read!( + &mut this.substream, + cx, + &mut read_buf + )) { + Err(_error) => return Poll::Ready(None), + Ok(_) => { + let nread = match read_buf.filled().len() { + 0 => return Poll::Ready(None), + nread => nread, + }; + + this.offset += nread; + + if this.offset == frame_size { + let out_frame = std::mem::replace( + &mut this.read_buffer, + BytesMut::new(), + ); + this.offset = 0; + this.current_frame_size = None; + + return Poll::Ready(Some(Ok(out_frame))); + } else { + this.current_frame_size = Some(frame_size); + continue; + } + } + } + } + None => { + let mut read_buf = + ReadBuf::new(&mut this.size_vec[this.offset..this.offset + 1]); + + match futures::ready!(poll_read!( + &mut this.substream, + cx, + &mut read_buf + )) { + Err(_error) => return Poll::Ready(None), + Ok(_) => { + if read_buf.filled().is_empty() { + return Poll::Ready(None); + } + this.offset += 1; + + match read_payload_size(&this.size_vec[..this.offset]) { + Err(ReadError::NotEnoughBytes) => continue, + Err(_) => + return Poll::Ready(Some(Err(Error::InvalidData))), + Ok((size, num_bytes)) => { + debug_assert_eq!(num_bytes, this.offset); + + if let Some(max_size) = max_size { + if size > max_size { + return Poll::Ready(Some(Err( + Error::InvalidData, + ))); + } + } + + this.offset = 0; + this.current_frame_size = Some(size); + this.read_buffer = BytesMut::zeroed(size); + } + } + } + } + } + } + } + } + ProtocolCodec::Unspecified => panic!("codec is unspecified"), + } + } + } } // TODO: this code can definitely be optimized impl Sink for Substream { - type Error = Error; - - fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - // `MockSubstream` implements `Sink` so calls to `poll_ready()` must be delegated - delegate_poll_ready!(&mut self.substream, cx); - - if self.pending_out_bytes >= BACKPRESSURE_BOUNDARY { - return poll_flush!(&mut self.substream, cx).map_err(From::from); - } - - Poll::Ready(Ok(())) - } - - fn start_send(mut self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> { - // `MockSubstream` implements `Sink` so calls to `start_send()` must be delegated - delegate_start_send!(&mut self.substream, item); - - match self.codec { - ProtocolCodec::Identity(payload_size) => { - if item.len() != payload_size { - return Err(Error::IoError(ErrorKind::PermissionDenied)); - } - - self.pending_out_bytes += item.len(); - self.pending_out_frames.push_back(item); - }, - ProtocolCodec::UnsignedVarint(max_size) => { - check_size!(max_size, item.len()); - - let len = { - let mut buffer = [0u8; 10]; - let len = unsigned_varint::encode::usize(item.len(), &mut buffer); - BytesMut::from(len) - }; - - self.pending_out_bytes += len.len() + item.len(); - self.pending_out_frames.push_back(len.freeze()); - self.pending_out_frames.push_back(item); - }, - ProtocolCodec::Unspecified => panic!("codec is unspecified"), - } - - return Ok(()); - } - - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - // `MockSubstream` implements `Sink` so calls to `poll_flush()` must be delegated - delegate_poll_flush!(&mut self.substream, cx); - - loop { - let mut pending_frame = match self.pending_out_frame.take() { - Some(frame) => frame, - None => match self.pending_out_frames.pop_front() { - Some(frame) => frame, - None => break, - }, - }; - - match poll_write!(&mut self.substream, cx, &pending_frame) { - Poll::Ready(Err(error)) => return Poll::Ready(Err(error.into())), - Poll::Pending => { - self.pending_out_frame = Some(pending_frame); - break; - }, - Poll::Ready(Ok(nwritten)) => { - pending_frame.advance(nwritten); - - if !pending_frame.is_empty() { - self.pending_out_frame = Some(pending_frame); - } - }, - } - } - - poll_flush!(&mut self.substream, cx).map_err(From::from) - } - - fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - poll_shutdown!(&mut self.substream, cx).map_err(From::from) - } + type Error = Error; + + fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + // `MockSubstream` implements `Sink` so calls to `poll_ready()` must be delegated + delegate_poll_ready!(&mut self.substream, cx); + + if self.pending_out_bytes >= BACKPRESSURE_BOUNDARY { + return poll_flush!(&mut self.substream, cx).map_err(From::from); + } + + Poll::Ready(Ok(())) + } + + fn start_send(mut self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> { + // `MockSubstream` implements `Sink` so calls to `start_send()` must be delegated + delegate_start_send!(&mut self.substream, item); + + match self.codec { + ProtocolCodec::Identity(payload_size) => { + if item.len() != payload_size { + return Err(Error::IoError(ErrorKind::PermissionDenied)); + } + + self.pending_out_bytes += item.len(); + self.pending_out_frames.push_back(item); + } + ProtocolCodec::UnsignedVarint(max_size) => { + check_size!(max_size, item.len()); + + let len = { + let mut buffer = [0u8; 10]; + let len = unsigned_varint::encode::usize(item.len(), &mut buffer); + BytesMut::from(len) + }; + + self.pending_out_bytes += len.len() + item.len(); + self.pending_out_frames.push_back(len.freeze()); + self.pending_out_frames.push_back(item); + } + ProtocolCodec::Unspecified => panic!("codec is unspecified"), + } + + return Ok(()); + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + // `MockSubstream` implements `Sink` so calls to `poll_flush()` must be delegated + delegate_poll_flush!(&mut self.substream, cx); + + loop { + let mut pending_frame = match self.pending_out_frame.take() { + Some(frame) => frame, + None => match self.pending_out_frames.pop_front() { + Some(frame) => frame, + None => break, + }, + }; + + match poll_write!(&mut self.substream, cx, &pending_frame) { + Poll::Ready(Err(error)) => return Poll::Ready(Err(error.into())), + Poll::Pending => { + self.pending_out_frame = Some(pending_frame); + break; + } + Poll::Ready(Ok(nwritten)) => { + pending_frame.advance(nwritten); + + if !pending_frame.is_empty() { + self.pending_out_frame = Some(pending_frame); + } + } + } + } + + poll_flush!(&mut self.substream, cx).map_err(From::from) + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + poll_shutdown!(&mut self.substream, cx).map_err(From::from) + } } /// Substream set key. @@ -688,305 +699,307 @@ impl SubstreamSetKey for K #[derive(Debug, Default)] pub struct SubstreamSet where - K: SubstreamSetKey, - S: Stream> + Unpin, + K: SubstreamSetKey, + S: Stream> + Unpin, { - substreams: HashMap, + substreams: HashMap, } impl SubstreamSet where - K: SubstreamSetKey, - S: Stream> + Unpin, + K: SubstreamSetKey, + S: Stream> + Unpin, { - /// Create new [`SubstreamSet`]. - pub fn new() -> Self { - Self { substreams: HashMap::new() } - } - - /// Add new substream to the set. - pub fn insert(&mut self, key: K, substream: S) { - match self.substreams.entry(key) { - Entry::Vacant(entry) => { - entry.insert(substream); - }, - Entry::Occupied(_) => { - tracing::error!(?key, "substream already exists"); - debug_assert!(false); - }, - } - } - - /// Remove substream from the set. - pub fn remove(&mut self, key: &K) -> Option { - self.substreams.remove(key) - } - - /// Get mutable reference to stored substream. - #[cfg(test)] - pub fn get_mut(&mut self, key: &K) -> Option<&mut S> { - self.substreams.get_mut(key) - } - - /// Get size of [`SubstreamSet`]. - pub fn len(&self) -> usize { - self.substreams.len() - } + /// Create new [`SubstreamSet`]. + pub fn new() -> Self { + Self { + substreams: HashMap::new(), + } + } + + /// Add new substream to the set. + pub fn insert(&mut self, key: K, substream: S) { + match self.substreams.entry(key) { + Entry::Vacant(entry) => { + entry.insert(substream); + } + Entry::Occupied(_) => { + tracing::error!(?key, "substream already exists"); + debug_assert!(false); + } + } + } + + /// Remove substream from the set. + pub fn remove(&mut self, key: &K) -> Option { + self.substreams.remove(key) + } + + /// Get mutable reference to stored substream. + #[cfg(test)] + pub fn get_mut(&mut self, key: &K) -> Option<&mut S> { + self.substreams.get_mut(key) + } + + /// Get size of [`SubstreamSet`]. + pub fn len(&self) -> usize { + self.substreams.len() + } } impl Stream for SubstreamSet where - K: SubstreamSetKey, - S: Stream> + Unpin, + K: SubstreamSetKey, + S: Stream> + Unpin, { - type Item = (K, ::Item); - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let inner = Pin::into_inner(self); - - // TODO: poll the streams more randomly - for (key, mut substream) in inner.substreams.iter_mut() { - match Pin::new(&mut substream).poll_next(cx) { - Poll::Pending => continue, - Poll::Ready(Some(data)) => return Poll::Ready(Some((*key, data))), - Poll::Ready(None) => - return Poll::Ready(Some(( - *key, - Err(Error::SubstreamError(SubstreamError::ConnectionClosed)), - ))), - } - } - - Poll::Pending - } + type Item = (K, ::Item); + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let inner = Pin::into_inner(self); + + // TODO: poll the streams more randomly + for (key, mut substream) in inner.substreams.iter_mut() { + match Pin::new(&mut substream).poll_next(cx) { + Poll::Pending => continue, + Poll::Ready(Some(data)) => return Poll::Ready(Some((*key, data))), + Poll::Ready(None) => + return Poll::Ready(Some(( + *key, + Err(Error::SubstreamError(SubstreamError::ConnectionClosed)), + ))), + } + } + + Poll::Pending + } } #[cfg(test)] mod tests { - use super::*; - use crate::{mock::substream::MockSubstream, PeerId}; - use futures::{SinkExt, StreamExt}; - - #[test] - fn add_substream() { - let mut set = SubstreamSet::::new(); - - let peer = PeerId::random(); - let substream = MockSubstream::new(); - set.insert(peer, substream); - - let peer = PeerId::random(); - let substream = MockSubstream::new(); - set.insert(peer, substream); - } - - #[test] - #[should_panic] - #[cfg(debug_assertions)] - fn add_same_peer_twice() { - let mut set = SubstreamSet::::new(); - - let peer = PeerId::random(); - let substream1 = MockSubstream::new(); - let substream2 = MockSubstream::new(); - - set.insert(peer, substream1); - set.insert(peer, substream2); - } - - #[test] - fn remove_substream() { - let mut set = SubstreamSet::::new(); - - let peer1 = PeerId::random(); - let substream1 = MockSubstream::new(); - set.insert(peer1, substream1); - - let peer2 = PeerId::random(); - let substream2 = MockSubstream::new(); - set.insert(peer2, substream2); - - assert!(set.remove(&peer1).is_some()); - assert!(set.remove(&peer2).is_some()); - assert!(set.remove(&PeerId::random()).is_none()); - } - - #[tokio::test] - async fn poll_data_from_substream() { - let mut set = SubstreamSet::::new(); - - let peer = PeerId::random(); - let mut substream = MockSubstream::new(); - substream - .expect_poll_next() - .times(1) - .return_once(|_| Poll::Ready(Some(Ok(BytesMut::from(&b"hello"[..]))))); - substream - .expect_poll_next() - .times(1) - .return_once(|_| Poll::Ready(Some(Ok(BytesMut::from(&b"world"[..]))))); - substream.expect_poll_next().returning(|_| Poll::Pending); - set.insert(peer, substream); - - let value = set.next().await.unwrap(); - assert_eq!(value.0, peer); - assert_eq!(value.1.unwrap(), BytesMut::from(&b"hello"[..])); - - let value = set.next().await.unwrap(); - assert_eq!(value.0, peer); - assert_eq!(value.1.unwrap(), BytesMut::from(&b"world"[..])); - - assert!(futures::poll!(set.next()).is_pending()); - } - - #[tokio::test] - async fn substream_closed() { - let mut set = SubstreamSet::::new(); - - let peer = PeerId::random(); - let mut substream = MockSubstream::new(); - substream - .expect_poll_next() - .times(1) - .return_once(|_| Poll::Ready(Some(Ok(BytesMut::from(&b"hello"[..]))))); - substream.expect_poll_next().times(1).return_once(|_| Poll::Ready(None)); - substream.expect_poll_next().returning(|_| Poll::Pending); - set.insert(peer, substream); - - let value = set.next().await.unwrap(); - assert_eq!(value.0, peer); - assert_eq!(value.1.unwrap(), BytesMut::from(&b"hello"[..])); - - match set.next().await { - Some((exited_peer, Err(Error::SubstreamError(SubstreamError::ConnectionClosed)))) => { - assert_eq!(peer, exited_peer); - }, - _ => panic!("inavlid event received"), - } - } - - #[tokio::test] - async fn get_mut_substream() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let mut set = SubstreamSet::::new(); - - let peer = PeerId::random(); - let mut substream = MockSubstream::new(); - substream - .expect_poll_next() - .times(1) - .return_once(|_| Poll::Ready(Some(Ok(BytesMut::from(&b"hello"[..]))))); - substream.expect_poll_ready().times(1).return_once(|_| Poll::Ready(Ok(()))); - substream.expect_start_send().times(1).return_once(|_| Ok(())); - substream.expect_poll_flush().times(1).return_once(|_| Poll::Ready(Ok(()))); - substream - .expect_poll_next() - .times(1) - .return_once(|_| Poll::Ready(Some(Ok(BytesMut::from(&b"world"[..]))))); - substream.expect_poll_next().returning(|_| Poll::Pending); - set.insert(peer, substream); - - let value = set.next().await.unwrap(); - assert_eq!(value.0, peer); - assert_eq!(value.1.unwrap(), BytesMut::from(&b"hello"[..])); - - let substream = set.get_mut(&peer).unwrap(); - substream.send(vec![1, 2, 3, 4].into()).await.unwrap(); - - let value = set.next().await.unwrap(); - assert_eq!(value.0, peer); - assert_eq!(value.1.unwrap(), BytesMut::from(&b"world"[..])); - - // try to get non-existent substream - assert!(set.get_mut(&PeerId::random()).is_none()); - } - - #[tokio::test] - async fn poll_data_from_two_substreams() { - let mut set = SubstreamSet::::new(); - - // prepare first substream - let peer1 = PeerId::random(); - let mut substream1 = MockSubstream::new(); - substream1 - .expect_poll_next() - .times(1) - .return_once(|_| Poll::Ready(Some(Ok(BytesMut::from(&b"hello"[..]))))); - substream1 - .expect_poll_next() - .times(1) - .return_once(|_| Poll::Ready(Some(Ok(BytesMut::from(&b"world"[..]))))); - substream1.expect_poll_next().returning(|_| Poll::Pending); - set.insert(peer1, substream1); - - // prepare second substream - let peer2 = PeerId::random(); - let mut substream2 = MockSubstream::new(); - substream2 - .expect_poll_next() - .times(1) - .return_once(|_| Poll::Ready(Some(Ok(BytesMut::from(&b"siip"[..]))))); - substream2 - .expect_poll_next() - .times(1) - .return_once(|_| Poll::Ready(Some(Ok(BytesMut::from(&b"huup"[..]))))); - substream2.expect_poll_next().returning(|_| Poll::Pending); - set.insert(peer2, substream2); - - let expected: Vec> = vec![ - vec![ - (peer1, BytesMut::from(&b"hello"[..])), - (peer1, BytesMut::from(&b"world"[..])), - (peer2, BytesMut::from(&b"siip"[..])), - (peer2, BytesMut::from(&b"huup"[..])), - ], - vec![ - (peer1, BytesMut::from(&b"hello"[..])), - (peer2, BytesMut::from(&b"siip"[..])), - (peer1, BytesMut::from(&b"world"[..])), - (peer2, BytesMut::from(&b"huup"[..])), - ], - vec![ - (peer2, BytesMut::from(&b"siip"[..])), - (peer2, BytesMut::from(&b"huup"[..])), - (peer1, BytesMut::from(&b"hello"[..])), - (peer1, BytesMut::from(&b"world"[..])), - ], - vec![ - (peer1, BytesMut::from(&b"hello"[..])), - (peer2, BytesMut::from(&b"siip"[..])), - (peer2, BytesMut::from(&b"huup"[..])), - (peer1, BytesMut::from(&b"world"[..])), - ], - ]; - - // poll values - let mut values = Vec::new(); - - for _ in 0..4 { - let value = set.next().await.unwrap(); - values.push((value.0, value.1.unwrap())); - } - - let mut correct_found = false; - - for set in expected { - if values == set { - correct_found = true; - break; - } - } - - if !correct_found { - panic!("invalid set generated"); - } - - // rest of the calls return `Poll::Pending` - for _ in 0..10 { - assert!(futures::poll!(set.next()).is_pending()); - } - } + use super::*; + use crate::{mock::substream::MockSubstream, PeerId}; + use futures::{SinkExt, StreamExt}; + + #[test] + fn add_substream() { + let mut set = SubstreamSet::::new(); + + let peer = PeerId::random(); + let substream = MockSubstream::new(); + set.insert(peer, substream); + + let peer = PeerId::random(); + let substream = MockSubstream::new(); + set.insert(peer, substream); + } + + #[test] + #[should_panic] + #[cfg(debug_assertions)] + fn add_same_peer_twice() { + let mut set = SubstreamSet::::new(); + + let peer = PeerId::random(); + let substream1 = MockSubstream::new(); + let substream2 = MockSubstream::new(); + + set.insert(peer, substream1); + set.insert(peer, substream2); + } + + #[test] + fn remove_substream() { + let mut set = SubstreamSet::::new(); + + let peer1 = PeerId::random(); + let substream1 = MockSubstream::new(); + set.insert(peer1, substream1); + + let peer2 = PeerId::random(); + let substream2 = MockSubstream::new(); + set.insert(peer2, substream2); + + assert!(set.remove(&peer1).is_some()); + assert!(set.remove(&peer2).is_some()); + assert!(set.remove(&PeerId::random()).is_none()); + } + + #[tokio::test] + async fn poll_data_from_substream() { + let mut set = SubstreamSet::::new(); + + let peer = PeerId::random(); + let mut substream = MockSubstream::new(); + substream + .expect_poll_next() + .times(1) + .return_once(|_| Poll::Ready(Some(Ok(BytesMut::from(&b"hello"[..]))))); + substream + .expect_poll_next() + .times(1) + .return_once(|_| Poll::Ready(Some(Ok(BytesMut::from(&b"world"[..]))))); + substream.expect_poll_next().returning(|_| Poll::Pending); + set.insert(peer, substream); + + let value = set.next().await.unwrap(); + assert_eq!(value.0, peer); + assert_eq!(value.1.unwrap(), BytesMut::from(&b"hello"[..])); + + let value = set.next().await.unwrap(); + assert_eq!(value.0, peer); + assert_eq!(value.1.unwrap(), BytesMut::from(&b"world"[..])); + + assert!(futures::poll!(set.next()).is_pending()); + } + + #[tokio::test] + async fn substream_closed() { + let mut set = SubstreamSet::::new(); + + let peer = PeerId::random(); + let mut substream = MockSubstream::new(); + substream + .expect_poll_next() + .times(1) + .return_once(|_| Poll::Ready(Some(Ok(BytesMut::from(&b"hello"[..]))))); + substream.expect_poll_next().times(1).return_once(|_| Poll::Ready(None)); + substream.expect_poll_next().returning(|_| Poll::Pending); + set.insert(peer, substream); + + let value = set.next().await.unwrap(); + assert_eq!(value.0, peer); + assert_eq!(value.1.unwrap(), BytesMut::from(&b"hello"[..])); + + match set.next().await { + Some((exited_peer, Err(Error::SubstreamError(SubstreamError::ConnectionClosed)))) => { + assert_eq!(peer, exited_peer); + } + _ => panic!("inavlid event received"), + } + } + + #[tokio::test] + async fn get_mut_substream() { + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let mut set = SubstreamSet::::new(); + + let peer = PeerId::random(); + let mut substream = MockSubstream::new(); + substream + .expect_poll_next() + .times(1) + .return_once(|_| Poll::Ready(Some(Ok(BytesMut::from(&b"hello"[..]))))); + substream.expect_poll_ready().times(1).return_once(|_| Poll::Ready(Ok(()))); + substream.expect_start_send().times(1).return_once(|_| Ok(())); + substream.expect_poll_flush().times(1).return_once(|_| Poll::Ready(Ok(()))); + substream + .expect_poll_next() + .times(1) + .return_once(|_| Poll::Ready(Some(Ok(BytesMut::from(&b"world"[..]))))); + substream.expect_poll_next().returning(|_| Poll::Pending); + set.insert(peer, substream); + + let value = set.next().await.unwrap(); + assert_eq!(value.0, peer); + assert_eq!(value.1.unwrap(), BytesMut::from(&b"hello"[..])); + + let substream = set.get_mut(&peer).unwrap(); + substream.send(vec![1, 2, 3, 4].into()).await.unwrap(); + + let value = set.next().await.unwrap(); + assert_eq!(value.0, peer); + assert_eq!(value.1.unwrap(), BytesMut::from(&b"world"[..])); + + // try to get non-existent substream + assert!(set.get_mut(&PeerId::random()).is_none()); + } + + #[tokio::test] + async fn poll_data_from_two_substreams() { + let mut set = SubstreamSet::::new(); + + // prepare first substream + let peer1 = PeerId::random(); + let mut substream1 = MockSubstream::new(); + substream1 + .expect_poll_next() + .times(1) + .return_once(|_| Poll::Ready(Some(Ok(BytesMut::from(&b"hello"[..]))))); + substream1 + .expect_poll_next() + .times(1) + .return_once(|_| Poll::Ready(Some(Ok(BytesMut::from(&b"world"[..]))))); + substream1.expect_poll_next().returning(|_| Poll::Pending); + set.insert(peer1, substream1); + + // prepare second substream + let peer2 = PeerId::random(); + let mut substream2 = MockSubstream::new(); + substream2 + .expect_poll_next() + .times(1) + .return_once(|_| Poll::Ready(Some(Ok(BytesMut::from(&b"siip"[..]))))); + substream2 + .expect_poll_next() + .times(1) + .return_once(|_| Poll::Ready(Some(Ok(BytesMut::from(&b"huup"[..]))))); + substream2.expect_poll_next().returning(|_| Poll::Pending); + set.insert(peer2, substream2); + + let expected: Vec> = vec![ + vec![ + (peer1, BytesMut::from(&b"hello"[..])), + (peer1, BytesMut::from(&b"world"[..])), + (peer2, BytesMut::from(&b"siip"[..])), + (peer2, BytesMut::from(&b"huup"[..])), + ], + vec![ + (peer1, BytesMut::from(&b"hello"[..])), + (peer2, BytesMut::from(&b"siip"[..])), + (peer1, BytesMut::from(&b"world"[..])), + (peer2, BytesMut::from(&b"huup"[..])), + ], + vec![ + (peer2, BytesMut::from(&b"siip"[..])), + (peer2, BytesMut::from(&b"huup"[..])), + (peer1, BytesMut::from(&b"hello"[..])), + (peer1, BytesMut::from(&b"world"[..])), + ], + vec![ + (peer1, BytesMut::from(&b"hello"[..])), + (peer2, BytesMut::from(&b"siip"[..])), + (peer2, BytesMut::from(&b"huup"[..])), + (peer1, BytesMut::from(&b"world"[..])), + ], + ]; + + // poll values + let mut values = Vec::new(); + + for _ in 0..4 { + let value = set.next().await.unwrap(); + values.push((value.0, value.1.unwrap())); + } + + let mut correct_found = false; + + for set in expected { + if values == set { + correct_found = true; + break; + } + } + + if !correct_found { + panic!("invalid set generated"); + } + + // rest of the calls return `Poll::Pending` + for _ in 0..10 { + assert!(futures::poll!(set.next()).is_pending()); + } + } } diff --git a/src/transport/dummy.rs b/src/transport/dummy.rs index f2e4ec0a..b7fd0aa1 100644 --- a/src/transport/dummy.rs +++ b/src/transport/dummy.rs @@ -21,133 +21,139 @@ //! Dummy transport. use crate::{ - transport::{Transport, TransportEvent}, - types::ConnectionId, + transport::{Transport, TransportEvent}, + types::ConnectionId, }; use futures::Stream; use multiaddr::Multiaddr; use std::{ - collections::VecDeque, - pin::Pin, - task::{Context, Poll}, + collections::VecDeque, + pin::Pin, + task::{Context, Poll}, }; /// Dummy transport. pub(crate) struct DummyTransport { - /// Events. - events: VecDeque, + /// Events. + events: VecDeque, } impl DummyTransport { - /// Create new [`DummyTransport`]. - #[cfg(test)] - pub(crate) fn new() -> Self { - Self { events: VecDeque::new() } - } - - /// Inject event into `DummyTransport`. - #[cfg(test)] - pub(crate) fn inject_event(&mut self, event: TransportEvent) { - self.events.push_back(event); - } + /// Create new [`DummyTransport`]. + #[cfg(test)] + pub(crate) fn new() -> Self { + Self { + events: VecDeque::new(), + } + } + + /// Inject event into `DummyTransport`. + #[cfg(test)] + pub(crate) fn inject_event(&mut self, event: TransportEvent) { + self.events.push_back(event); + } } impl Stream for DummyTransport { - type Item = TransportEvent; + type Item = TransportEvent; - fn poll_next(mut self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { - if self.events.is_empty() { - return Poll::Pending; - } + fn poll_next(mut self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + if self.events.is_empty() { + return Poll::Pending; + } - Poll::Ready(self.events.pop_front()) - } + Poll::Ready(self.events.pop_front()) + } } impl Transport for DummyTransport { - fn dial(&mut self, _: ConnectionId, _: Multiaddr) -> crate::Result<()> { - Ok(()) - } + fn dial(&mut self, _: ConnectionId, _: Multiaddr) -> crate::Result<()> { + Ok(()) + } - fn accept(&mut self, _: ConnectionId) -> crate::Result<()> { - Ok(()) - } + fn accept(&mut self, _: ConnectionId) -> crate::Result<()> { + Ok(()) + } - fn reject(&mut self, _: ConnectionId) -> crate::Result<()> { - Ok(()) - } + fn reject(&mut self, _: ConnectionId) -> crate::Result<()> { + Ok(()) + } - fn open(&mut self, _: ConnectionId, _: Vec) -> crate::Result<()> { - Ok(()) - } + fn open(&mut self, _: ConnectionId, _: Vec) -> crate::Result<()> { + Ok(()) + } - fn negotiate(&mut self, _: ConnectionId) -> crate::Result<()> { - Ok(()) - } + fn negotiate(&mut self, _: ConnectionId) -> crate::Result<()> { + Ok(()) + } - /// Cancel opening connections. - fn cancel(&mut self, _: ConnectionId) {} + /// Cancel opening connections. + fn cancel(&mut self, _: ConnectionId) {} } #[cfg(test)] mod tests { - use super::*; - use crate::{transport::Endpoint, Error, PeerId}; - use futures::StreamExt; - - #[tokio::test] - async fn pending_event() { - let mut transport = DummyTransport::new(); - - transport.inject_event(TransportEvent::DialFailure { - connection_id: ConnectionId::from(1338usize), - address: Multiaddr::empty(), - error: Error::Unknown, - }); - - let peer = PeerId::random(); - let endpoint = Endpoint::listener(Multiaddr::empty(), ConnectionId::from(1337usize)); - - transport.inject_event(TransportEvent::ConnectionEstablished { - peer, - endpoint: endpoint.clone(), - }); - - match transport.next().await.unwrap() { - TransportEvent::DialFailure { connection_id, address, .. } => { - assert_eq!(connection_id, ConnectionId::from(1338usize)); - assert_eq!(address, Multiaddr::empty()); - }, - _ => panic!("invalid event"), - } - - match transport.next().await.unwrap() { - TransportEvent::ConnectionEstablished { - peer: event_peer, - endpoint: event_endpoint, - } => { - assert_eq!(peer, event_peer); - assert_eq!(endpoint, event_endpoint); - }, - _ => panic!("invalid event"), - } - - futures::future::poll_fn(|cx| match transport.poll_next_unpin(cx) { - Poll::Pending => Poll::Ready(()), - _ => panic!("invalid event"), - }) - .await; - } - - #[test] - fn dummy_handle_connection_states() { - let mut transport = DummyTransport::new(); - - assert!(transport.reject(ConnectionId::new()).is_ok()); - assert!(transport.open(ConnectionId::new(), Vec::new()).is_ok()); - assert!(transport.negotiate(ConnectionId::new()).is_ok()); - transport.cancel(ConnectionId::new()); - } + use super::*; + use crate::{transport::Endpoint, Error, PeerId}; + use futures::StreamExt; + + #[tokio::test] + async fn pending_event() { + let mut transport = DummyTransport::new(); + + transport.inject_event(TransportEvent::DialFailure { + connection_id: ConnectionId::from(1338usize), + address: Multiaddr::empty(), + error: Error::Unknown, + }); + + let peer = PeerId::random(); + let endpoint = Endpoint::listener(Multiaddr::empty(), ConnectionId::from(1337usize)); + + transport.inject_event(TransportEvent::ConnectionEstablished { + peer, + endpoint: endpoint.clone(), + }); + + match transport.next().await.unwrap() { + TransportEvent::DialFailure { + connection_id, + address, + .. + } => { + assert_eq!(connection_id, ConnectionId::from(1338usize)); + assert_eq!(address, Multiaddr::empty()); + } + _ => panic!("invalid event"), + } + + match transport.next().await.unwrap() { + TransportEvent::ConnectionEstablished { + peer: event_peer, + endpoint: event_endpoint, + } => { + assert_eq!(peer, event_peer); + assert_eq!(endpoint, event_endpoint); + } + _ => panic!("invalid event"), + } + + futures::future::poll_fn(|cx| match transport.poll_next_unpin(cx) { + Poll::Pending => Poll::Ready(()), + _ => panic!("invalid event"), + }) + .await; + } + + #[test] + fn dummy_handle_connection_states() { + let mut transport = DummyTransport::new(); + + assert!(transport.reject(ConnectionId::new()).is_ok()); + assert!(transport.open(ConnectionId::new(), Vec::new()).is_ok()); + assert!(transport.negotiate(ConnectionId::new()).is_ok()); + transport.cancel(ConnectionId::new()); + } } diff --git a/src/transport/manager/address.rs b/src/transport/manager/address.rs index 5565d109..5ff527a3 100644 --- a/src/transport/manager/address.rs +++ b/src/transport/manager/address.rs @@ -27,405 +27,416 @@ use std::collections::{BinaryHeap, HashSet}; #[derive(Debug, Clone, Hash)] pub struct AddressRecord { - /// Address score. - score: i32, + /// Address score. + score: i32, - /// Address. - address: Multiaddr, + /// Address. + address: Multiaddr, - /// Connection ID, if specifed. - connection_id: Option, + /// Connection ID, if specifed. + connection_id: Option, } impl AsRef for AddressRecord { - fn as_ref(&self) -> &Multiaddr { - &self.address - } + fn as_ref(&self) -> &Multiaddr { + &self.address + } } impl AddressRecord { - /// Create new `AddressRecord` and if `address` doesn't contain `P2p`, - /// append the provided `PeerId` to the address. - pub fn new( - peer: &PeerId, - address: Multiaddr, - score: i32, - connection_id: Option, - ) -> Self { - let address = if !std::matches!(address.iter().last(), Some(Protocol::P2p(_))) { - address.with(Protocol::P2p( - Multihash::from_bytes(&peer.to_bytes()).ok().expect("valid peer id"), - )) - } else { - address - }; - - Self { address, score, connection_id } - } - - /// Create `AddressRecord` from `Multiaddr`. - /// - /// If `address` doesn't contain `PeerId`, return `None` to indicate that this - /// an invalid `Multiaddr` from the perspective of the `TransportManager`. - pub fn from_multiaddr(address: Multiaddr) -> Option { - if !std::matches!(address.iter().last(), Some(Protocol::P2p(_))) { - return None; - } - - Some(AddressRecord { address, score: 0i32, connection_id: None }) - } - - /// Get address score. - #[cfg(test)] - pub fn score(&self) -> i32 { - self.score - } - - /// Get address. - pub fn address(&self) -> &Multiaddr { - &self.address - } - - /// Get connection ID. - pub fn connection_id(&self) -> &Option { - &self.connection_id - } - - /// Update score of an address. - pub fn update_score(&mut self, score: i32) { - self.score += score; - } - - /// Set `ConnectionId` for the [`AddressRecord`]. - pub fn set_connection_id(&mut self, connection_id: ConnectionId) { - self.connection_id = Some(connection_id); - } + /// Create new `AddressRecord` and if `address` doesn't contain `P2p`, + /// append the provided `PeerId` to the address. + pub fn new( + peer: &PeerId, + address: Multiaddr, + score: i32, + connection_id: Option, + ) -> Self { + let address = if !std::matches!(address.iter().last(), Some(Protocol::P2p(_))) { + address.with(Protocol::P2p( + Multihash::from_bytes(&peer.to_bytes()).ok().expect("valid peer id"), + )) + } else { + address + }; + + Self { + address, + score, + connection_id, + } + } + + /// Create `AddressRecord` from `Multiaddr`. + /// + /// If `address` doesn't contain `PeerId`, return `None` to indicate that this + /// an invalid `Multiaddr` from the perspective of the `TransportManager`. + pub fn from_multiaddr(address: Multiaddr) -> Option { + if !std::matches!(address.iter().last(), Some(Protocol::P2p(_))) { + return None; + } + + Some(AddressRecord { + address, + score: 0i32, + connection_id: None, + }) + } + + /// Get address score. + #[cfg(test)] + pub fn score(&self) -> i32 { + self.score + } + + /// Get address. + pub fn address(&self) -> &Multiaddr { + &self.address + } + + /// Get connection ID. + pub fn connection_id(&self) -> &Option { + &self.connection_id + } + + /// Update score of an address. + pub fn update_score(&mut self, score: i32) { + self.score += score; + } + + /// Set `ConnectionId` for the [`AddressRecord`]. + pub fn set_connection_id(&mut self, connection_id: ConnectionId) { + self.connection_id = Some(connection_id); + } } impl PartialEq for AddressRecord { - fn eq(&self, other: &Self) -> bool { - self.score.eq(&other.score) - } + fn eq(&self, other: &Self) -> bool { + self.score.eq(&other.score) + } } impl Eq for AddressRecord {} impl PartialOrd for AddressRecord { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.score.cmp(&other.score)) - } + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.score.cmp(&other.score)) + } } impl Ord for AddressRecord { - fn cmp(&self, other: &Self) -> std::cmp::Ordering { - self.score.cmp(&other.score) - } + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.score.cmp(&other.score) + } } /// Store for peer addresses. #[derive(Debug)] pub struct AddressStore { - //// Addresses sorted by score. - pub by_score: BinaryHeap, + //// Addresses sorted by score. + pub by_score: BinaryHeap, - /// Addresses queryable by hashing them for faster lookup. - pub by_address: HashSet, + /// Addresses queryable by hashing them for faster lookup. + pub by_address: HashSet, } impl FromIterator for AddressStore { - fn from_iter>(iter: T) -> Self { - let mut store = AddressStore::new(); - for address in iter { - if let Some(address) = AddressRecord::from_multiaddr(address) { - store.insert(address.into()); - } - } - - store - } + fn from_iter>(iter: T) -> Self { + let mut store = AddressStore::new(); + for address in iter { + if let Some(address) = AddressRecord::from_multiaddr(address) { + store.insert(address.into()); + } + } + + store + } } impl FromIterator for AddressStore { - fn from_iter>(iter: T) -> Self { - let mut store = AddressStore::new(); - for record in iter { - store.by_address.insert(record.address.clone()); - store.by_score.push(record); - } - - store - } + fn from_iter>(iter: T) -> Self { + let mut store = AddressStore::new(); + for record in iter { + store.by_address.insert(record.address.clone()); + store.by_score.push(record); + } + + store + } } impl Extend for AddressStore { - fn extend>(&mut self, iter: T) { - for record in iter { - self.insert(record) - } - } + fn extend>(&mut self, iter: T) { + for record in iter { + self.insert(record) + } + } } impl<'a> Extend<&'a AddressRecord> for AddressStore { - fn extend>(&mut self, iter: T) { - for record in iter { - self.insert(record.clone()) - } - } + fn extend>(&mut self, iter: T) { + for record in iter { + self.insert(record.clone()) + } + } } impl AddressStore { - /// Create new [`AddressStore`]. - pub fn new() -> Self { - Self { by_score: BinaryHeap::new(), by_address: HashSet::new() } - } - - /// Check if [`AddressStore`] is empty. - pub fn is_empty(&self) -> bool { - self.by_score.is_empty() - } - - /// Check if address is already in the a - pub fn contains(&self, address: &Multiaddr) -> bool { - self.by_address.contains(address) - } - - /// Insert new address record into [`AddressStore`] with default address score. - pub fn insert(&mut self, mut record: AddressRecord) { - if self.by_address.contains(record.address()) { - return; - } - - record.connection_id = None; - self.by_address.insert(record.address.clone()); - self.by_score.push(record); - } - - /// Pop address with the highest score from [`AddressStore`]. - pub fn pop(&mut self) -> Option { - self.by_score.pop().map(|record| { - self.by_address.remove(&record.address); - record - }) - } - - /// Take at most `limit` `AddressRecord`s from [`AddressStore`]. - pub fn take(&mut self, limit: usize) -> Vec { - let mut records = Vec::new(); - - for _ in 0..limit { - match self.pop() { - Some(record) => records.push(record), - None => break, - } - } - - records - } + /// Create new [`AddressStore`]. + pub fn new() -> Self { + Self { + by_score: BinaryHeap::new(), + by_address: HashSet::new(), + } + } + + /// Check if [`AddressStore`] is empty. + pub fn is_empty(&self) -> bool { + self.by_score.is_empty() + } + + /// Check if address is already in the a + pub fn contains(&self, address: &Multiaddr) -> bool { + self.by_address.contains(address) + } + + /// Insert new address record into [`AddressStore`] with default address score. + pub fn insert(&mut self, mut record: AddressRecord) { + if self.by_address.contains(record.address()) { + return; + } + + record.connection_id = None; + self.by_address.insert(record.address.clone()); + self.by_score.push(record); + } + + /// Pop address with the highest score from [`AddressStore`]. + pub fn pop(&mut self) -> Option { + self.by_score.pop().map(|record| { + self.by_address.remove(&record.address); + record + }) + } + + /// Take at most `limit` `AddressRecord`s from [`AddressStore`]. + pub fn take(&mut self, limit: usize) -> Vec { + let mut records = Vec::new(); + + for _ in 0..limit { + match self.pop() { + Some(record) => records.push(record), + None => break, + } + } + + records + } } #[cfg(test)] mod tests { - use std::{ - collections::HashMap, - net::{Ipv4Addr, SocketAddrV4}, - }; - - use super::*; - use rand::{rngs::ThreadRng, Rng}; - - fn tcp_address_record(rng: &mut ThreadRng) -> AddressRecord { - let peer = PeerId::random(); - let address = std::net::SocketAddr::V4(SocketAddrV4::new( - Ipv4Addr::new( - rng.gen_range(1..=255), - rng.gen_range(0..=255), - rng.gen_range(0..=255), - rng.gen_range(0..=255), - ), - rng.gen_range(1..=65535), - )); - let score: i32 = rng.gen(); - - AddressRecord::new( - &peer, - Multiaddr::empty() - .with(Protocol::from(address.ip())) - .with(Protocol::Tcp(address.port())), - score, - None, - ) - } - - fn ws_address_record(rng: &mut ThreadRng) -> AddressRecord { - let peer = PeerId::random(); - let address = std::net::SocketAddr::V4(SocketAddrV4::new( - Ipv4Addr::new( - rng.gen_range(1..=255), - rng.gen_range(0..=255), - rng.gen_range(0..=255), - rng.gen_range(0..=255), - ), - rng.gen_range(1..=65535), - )); - let score: i32 = rng.gen(); - - AddressRecord::new( - &peer, - Multiaddr::empty() - .with(Protocol::from(address.ip())) - .with(Protocol::Tcp(address.port())) - .with(Protocol::Ws(std::borrow::Cow::Owned("/".to_string()))), - score, - None, - ) - } - - fn quic_address_record(rng: &mut ThreadRng) -> AddressRecord { - let peer = PeerId::random(); - let address = std::net::SocketAddr::V4(SocketAddrV4::new( - Ipv4Addr::new( - rng.gen_range(1..=255), - rng.gen_range(0..=255), - rng.gen_range(0..=255), - rng.gen_range(0..=255), - ), - rng.gen_range(1..=65535), - )); - let score: i32 = rng.gen(); - - AddressRecord::new( - &peer, - Multiaddr::empty() - .with(Protocol::from(address.ip())) - .with(Protocol::Udp(address.port())) - .with(Protocol::QuicV1), - score, - None, - ) - } - - #[test] - fn take_multiple_records() { - let mut store = AddressStore::new(); - let mut rng = rand::thread_rng(); - - for _ in 0..rng.gen_range(1..5) { - store.insert(tcp_address_record(&mut rng)); - } - for _ in 0..rng.gen_range(1..5) { - store.insert(ws_address_record(&mut rng)); - } - for _ in 0..rng.gen_range(1..5) { - store.insert(quic_address_record(&mut rng)); - } - - let known_addresses = store.by_address.len(); - assert!(known_addresses >= 3); - - let taken = store.take(known_addresses - 2); - assert_eq!(known_addresses - 2, taken.len()); - assert!(!store.is_empty()); - - let mut prev: Option = None; - for record in taken { - assert!(!store.contains(record.address())); - - if let Some(previous) = prev { - assert!(previous.score > record.score); - } - - prev = Some(record); - } - } - - #[test] - fn attempt_to_take_excess_records() { - let mut store = AddressStore::new(); - let mut rng = rand::thread_rng(); - - store.insert(tcp_address_record(&mut rng)); - store.insert(ws_address_record(&mut rng)); - store.insert(quic_address_record(&mut rng)); - - assert_eq!(store.by_address.len(), 3); - - let taken = store.take(8usize); - assert_eq!(taken.len(), 3); - assert!(store.is_empty()); - - let mut prev: Option = None; - for record in taken { - if prev.is_none() { - prev = Some(record); - } else { - assert!(prev.unwrap().score > record.score); - prev = Some(record); - } - } - } - - #[test] - fn extend_from_iterator() { - let mut store = AddressStore::new(); - let mut rng = rand::thread_rng(); - - let records = (0..10) - .map(|i| { - if i % 2 == 0 { - tcp_address_record(&mut rng) - } else if i % 3 == 0 { - quic_address_record(&mut rng) - } else { - ws_address_record(&mut rng) - } - }) - .collect::>(); - - assert!(store.is_empty()); - let cloned = records - .iter() - .cloned() - .map(|record| (record.address().clone(), record)) - .collect::>(); - store.extend(records); - - for record in store.by_score { - let stored = cloned.get(record.address()).unwrap(); - assert_eq!(stored.score(), record.score()); - assert_eq!(stored.connection_id(), record.connection_id()); - assert_eq!(stored.address(), record.address()); - } - } - - #[test] - fn extend_from_iterator_ref() { - let mut store = AddressStore::new(); - let mut rng = rand::thread_rng(); - - let records = (0..10) - .map(|i| { - if i % 2 == 0 { - let record = tcp_address_record(&mut rng); - (record.address().clone(), record) - } else if i % 3 == 0 { - let record = quic_address_record(&mut rng); - (record.address().clone(), record) - } else { - let record = ws_address_record(&mut rng); - (record.address().clone(), record) - } - }) - .collect::>(); - - assert!(store.is_empty()); - let cloned = records.iter().cloned().collect::>(); - store.extend(records.iter().map(|(_, record)| record)); - - for record in store.by_score { - let stored = cloned.get(record.address()).unwrap(); - assert_eq!(stored.score(), record.score()); - assert_eq!(stored.connection_id(), record.connection_id()); - assert_eq!(stored.address(), record.address()); - } - } + use std::{ + collections::HashMap, + net::{Ipv4Addr, SocketAddrV4}, + }; + + use super::*; + use rand::{rngs::ThreadRng, Rng}; + + fn tcp_address_record(rng: &mut ThreadRng) -> AddressRecord { + let peer = PeerId::random(); + let address = std::net::SocketAddr::V4(SocketAddrV4::new( + Ipv4Addr::new( + rng.gen_range(1..=255), + rng.gen_range(0..=255), + rng.gen_range(0..=255), + rng.gen_range(0..=255), + ), + rng.gen_range(1..=65535), + )); + let score: i32 = rng.gen(); + + AddressRecord::new( + &peer, + Multiaddr::empty() + .with(Protocol::from(address.ip())) + .with(Protocol::Tcp(address.port())), + score, + None, + ) + } + + fn ws_address_record(rng: &mut ThreadRng) -> AddressRecord { + let peer = PeerId::random(); + let address = std::net::SocketAddr::V4(SocketAddrV4::new( + Ipv4Addr::new( + rng.gen_range(1..=255), + rng.gen_range(0..=255), + rng.gen_range(0..=255), + rng.gen_range(0..=255), + ), + rng.gen_range(1..=65535), + )); + let score: i32 = rng.gen(); + + AddressRecord::new( + &peer, + Multiaddr::empty() + .with(Protocol::from(address.ip())) + .with(Protocol::Tcp(address.port())) + .with(Protocol::Ws(std::borrow::Cow::Owned("/".to_string()))), + score, + None, + ) + } + + fn quic_address_record(rng: &mut ThreadRng) -> AddressRecord { + let peer = PeerId::random(); + let address = std::net::SocketAddr::V4(SocketAddrV4::new( + Ipv4Addr::new( + rng.gen_range(1..=255), + rng.gen_range(0..=255), + rng.gen_range(0..=255), + rng.gen_range(0..=255), + ), + rng.gen_range(1..=65535), + )); + let score: i32 = rng.gen(); + + AddressRecord::new( + &peer, + Multiaddr::empty() + .with(Protocol::from(address.ip())) + .with(Protocol::Udp(address.port())) + .with(Protocol::QuicV1), + score, + None, + ) + } + + #[test] + fn take_multiple_records() { + let mut store = AddressStore::new(); + let mut rng = rand::thread_rng(); + + for _ in 0..rng.gen_range(1..5) { + store.insert(tcp_address_record(&mut rng)); + } + for _ in 0..rng.gen_range(1..5) { + store.insert(ws_address_record(&mut rng)); + } + for _ in 0..rng.gen_range(1..5) { + store.insert(quic_address_record(&mut rng)); + } + + let known_addresses = store.by_address.len(); + assert!(known_addresses >= 3); + + let taken = store.take(known_addresses - 2); + assert_eq!(known_addresses - 2, taken.len()); + assert!(!store.is_empty()); + + let mut prev: Option = None; + for record in taken { + assert!(!store.contains(record.address())); + + if let Some(previous) = prev { + assert!(previous.score > record.score); + } + + prev = Some(record); + } + } + + #[test] + fn attempt_to_take_excess_records() { + let mut store = AddressStore::new(); + let mut rng = rand::thread_rng(); + + store.insert(tcp_address_record(&mut rng)); + store.insert(ws_address_record(&mut rng)); + store.insert(quic_address_record(&mut rng)); + + assert_eq!(store.by_address.len(), 3); + + let taken = store.take(8usize); + assert_eq!(taken.len(), 3); + assert!(store.is_empty()); + + let mut prev: Option = None; + for record in taken { + if prev.is_none() { + prev = Some(record); + } else { + assert!(prev.unwrap().score > record.score); + prev = Some(record); + } + } + } + + #[test] + fn extend_from_iterator() { + let mut store = AddressStore::new(); + let mut rng = rand::thread_rng(); + + let records = (0..10) + .map(|i| { + if i % 2 == 0 { + tcp_address_record(&mut rng) + } else if i % 3 == 0 { + quic_address_record(&mut rng) + } else { + ws_address_record(&mut rng) + } + }) + .collect::>(); + + assert!(store.is_empty()); + let cloned = records + .iter() + .cloned() + .map(|record| (record.address().clone(), record)) + .collect::>(); + store.extend(records); + + for record in store.by_score { + let stored = cloned.get(record.address()).unwrap(); + assert_eq!(stored.score(), record.score()); + assert_eq!(stored.connection_id(), record.connection_id()); + assert_eq!(stored.address(), record.address()); + } + } + + #[test] + fn extend_from_iterator_ref() { + let mut store = AddressStore::new(); + let mut rng = rand::thread_rng(); + + let records = (0..10) + .map(|i| { + if i % 2 == 0 { + let record = tcp_address_record(&mut rng); + (record.address().clone(), record) + } else if i % 3 == 0 { + let record = quic_address_record(&mut rng); + (record.address().clone(), record) + } else { + let record = ws_address_record(&mut rng); + (record.address().clone(), record) + } + }) + .collect::>(); + + assert!(store.is_empty()); + let cloned = records.iter().cloned().collect::>(); + store.extend(records.iter().map(|(_, record)| record)); + + for record in store.by_score { + let stored = cloned.get(record.address()).unwrap(); + assert_eq!(stored.score(), record.score()); + assert_eq!(stored.connection_id(), record.connection_id()); + assert_eq!(stored.address(), record.address()); + } + } } diff --git a/src/transport/manager/handle.rs b/src/transport/manager/handle.rs index e80bbfe3..c732296c 100644 --- a/src/transport/manager/handle.rs +++ b/src/transport/manager/handle.rs @@ -19,17 +19,17 @@ // DEALINGS IN THE SOFTWARE. use crate::{ - crypto::ed25519::Keypair, - error::{AddressError, Error}, - executor::Executor, - protocol::ProtocolSet, - transport::manager::{ - address::{AddressRecord, AddressStore}, - types::{PeerContext, PeerState, SupportedTransport}, - ProtocolContext, TransportManagerEvent, LOG_TARGET, - }, - types::{protocol::ProtocolName, ConnectionId}, - BandwidthSink, PeerId, + crypto::ed25519::Keypair, + error::{AddressError, Error}, + executor::Executor, + protocol::ProtocolSet, + transport::manager::{ + address::{AddressRecord, AddressStore}, + types::{PeerContext, PeerState, SupportedTransport}, + ProtocolContext, TransportManagerEvent, LOG_TARGET, + }, + types::{protocol::ProtocolName, ConnectionId}, + BandwidthSink, PeerId, }; use multiaddr::{Multiaddr, Protocol}; @@ -37,585 +37,597 @@ use parking_lot::RwLock; use tokio::sync::mpsc::{error::TrySendError, Sender}; use std::{ - collections::{HashMap, HashSet}, - sync::{ - atomic::{AtomicUsize, Ordering}, - Arc, - }, + collections::{HashMap, HashSet}, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }, }; /// Inner commands sent from [`TransportManagerHandle`] to /// [`crate::transport::manager::TransportManager`]. pub enum InnerTransportManagerCommand { - /// Dial peer. - DialPeer { - /// Remote peer ID. - peer: PeerId, - }, - - /// Dial address. - DialAddress { - /// Remote address. - address: Multiaddr, - }, + /// Dial peer. + DialPeer { + /// Remote peer ID. + peer: PeerId, + }, + + /// Dial address. + DialAddress { + /// Remote address. + address: Multiaddr, + }, } /// Handle for communicating with [`crate::transport::manager::TransportManager`]. #[derive(Debug, Clone)] pub struct TransportManagerHandle { - /// Local peer ID. - local_peer_id: PeerId, + /// Local peer ID. + local_peer_id: PeerId, - /// Peers. - peers: Arc>>, + /// Peers. + peers: Arc>>, - /// TX channel for sending commands to [`crate::transport::manager::TransportManager`]. - cmd_tx: Sender, + /// TX channel for sending commands to [`crate::transport::manager::TransportManager`]. + cmd_tx: Sender, - /// Supported transports. - supported_transport: HashSet, + /// Supported transports. + supported_transport: HashSet, - /// Local listen addresess. - listen_addresses: Arc>>, + /// Local listen addresess. + listen_addresses: Arc>>, } impl TransportManagerHandle { - /// Create new [`TransportManagerHandle`]. - pub fn new( - local_peer_id: PeerId, - peers: Arc>>, - cmd_tx: Sender, - supported_transport: HashSet, - listen_addresses: Arc>>, - ) -> Self { - Self { peers, cmd_tx, local_peer_id, listen_addresses, supported_transport } - } - - /// Register new transport to [`TransportManagerHandle`]. - pub(crate) fn register_transport(&mut self, transport: SupportedTransport) { - self.supported_transport.insert(transport); - } - - /// Check if `address` is supported by one of the enabled transports. - pub fn supported_transport(&self, address: &Multiaddr) -> bool { - let mut iter = address.iter(); - - match iter.next() { - Some(Protocol::Ip4(address)) => - if address.is_unspecified() { - return false; - }, - Some(Protocol::Ip6(address)) => - if address.is_unspecified() { - return false; - }, - Some(Protocol::Dns(_)) | Some(Protocol::Dns4(_)) | Some(Protocol::Dns6(_)) => {}, - _ => return false, - } - - match iter.next() { - None => return false, - Some(Protocol::Tcp(_)) => match ( - iter.next(), - self.supported_transport.contains(&SupportedTransport::WebSocket), - ) { - (Some(Protocol::Ws(_)), true) => true, - (Some(Protocol::Wss(_)), true) => true, - (Some(Protocol::P2p(_)), _) => - self.supported_transport.contains(&SupportedTransport::Tcp), - _ => return false, - }, - Some(Protocol::Udp(_)) => - match (iter.next(), self.supported_transport.contains(&SupportedTransport::Quic)) { - (Some(Protocol::QuicV1), true) => true, - _ => false, - }, - _ => false, - } - } - - /// Check if the address is a local listen address and if so, discard it. - fn is_local_address(&self, address: &Multiaddr) -> bool { - let address: Multiaddr = address - .iter() - .take_while(|protocol| !std::matches!(protocol, Protocol::P2p(_))) - .collect(); - - self.listen_addresses.read().contains(&address) - } - - /// Add one or more known addresses for peer. - /// - /// If peer doesn't exist, it will be added to known peers. - /// - /// Returns the number of added addresses after non-supported transports were filtered out. - pub fn add_known_address( - &mut self, - peer: &PeerId, - addresses: impl Iterator, - ) -> usize { - let mut peers = self.peers.write(); - let addresses = addresses - .filter_map(|address| { - (self.supported_transport(&address) && !self.is_local_address(&address)) - .then_some(AddressRecord::from_multiaddr(address)?) - }) - .collect::>(); - - // if all of the added addresses belonged to unsupported transports, exit early - let num_added = addresses.len(); - if num_added == 0 { - tracing::debug!( - target: LOG_TARGET, - ?peer, - "didn't add any addresses for peer because transport is not supported", - ); - - return 0usize; - } - - tracing::trace!( - target: LOG_TARGET, - ?peer, - ?addresses, - "add known addresses", - ); - - match peers.get_mut(&peer) { - Some(context) => - for record in addresses { - if !context.addresses.contains(record.address()) { - context.addresses.insert(record); - } - }, - None => { - peers.insert( - *peer, - PeerContext { - state: PeerState::Disconnected { dial_record: None }, - addresses: AddressStore::from_iter(addresses.into_iter()), - secondary_connection: None, - }, - ); - }, - } - - num_added - } - - /// Dial peer using `PeerId`. - /// - /// Returns an error if the peer is unknown or the peer is already connected. - pub fn dial(&self, peer: &PeerId) -> crate::Result<()> { - if peer == &self.local_peer_id { - return Err(Error::TriedToDialSelf); - } - - { - match self.peers.read().get(&peer) { - Some(PeerContext { state: PeerState::Connected { .. }, .. }) => - return Err(Error::AlreadyConnected), - Some(PeerContext { - state: PeerState::Disconnected { dial_record }, - addresses, - .. - }) => { - if addresses.is_empty() { - return Err(Error::NoAddressAvailable(*peer)); - } - - // peer is already being dialed, don't dial again until the first dial concluded - if dial_record.is_some() { - tracing::debug!( - target: LOG_TARGET, - ?peer, - ?dial_record, - "peer is aready being dialed", - ); - return Ok(()); - } - }, - Some(PeerContext { - state: PeerState::Dialing { .. } | PeerState::Opening { .. }, - .. - }) => return Ok(()), - None => return Err(Error::PeerDoesntExist(*peer)), - } - } - - self.cmd_tx - .try_send(InnerTransportManagerCommand::DialPeer { peer: *peer }) - .map_err(|error| match error { - TrySendError::Full(_) => Error::ChannelClogged, - TrySendError::Closed(_) => Error::EssentialTaskClosed, - }) - } - - /// Dial peer using `Multiaddr`. - /// - /// Returns an error if address it not valid. - pub fn dial_address(&self, address: Multiaddr) -> crate::Result<()> { - if !address.iter().any(|protocol| std::matches!(protocol, Protocol::P2p(_))) { - return Err(Error::AddressError(AddressError::PeerIdMissing)); - } - - self.cmd_tx - .try_send(InnerTransportManagerCommand::DialAddress { address }) - .map_err(|error| match error { - TrySendError::Full(_) => Error::ChannelClogged, - TrySendError::Closed(_) => Error::EssentialTaskClosed, - }) - } + /// Create new [`TransportManagerHandle`]. + pub fn new( + local_peer_id: PeerId, + peers: Arc>>, + cmd_tx: Sender, + supported_transport: HashSet, + listen_addresses: Arc>>, + ) -> Self { + Self { + peers, + cmd_tx, + local_peer_id, + listen_addresses, + supported_transport, + } + } + + /// Register new transport to [`TransportManagerHandle`]. + pub(crate) fn register_transport(&mut self, transport: SupportedTransport) { + self.supported_transport.insert(transport); + } + + /// Check if `address` is supported by one of the enabled transports. + pub fn supported_transport(&self, address: &Multiaddr) -> bool { + let mut iter = address.iter(); + + match iter.next() { + Some(Protocol::Ip4(address)) => + if address.is_unspecified() { + return false; + }, + Some(Protocol::Ip6(address)) => + if address.is_unspecified() { + return false; + }, + Some(Protocol::Dns(_)) | Some(Protocol::Dns4(_)) | Some(Protocol::Dns6(_)) => {} + _ => return false, + } + + match iter.next() { + None => return false, + Some(Protocol::Tcp(_)) => match ( + iter.next(), + self.supported_transport.contains(&SupportedTransport::WebSocket), + ) { + (Some(Protocol::Ws(_)), true) => true, + (Some(Protocol::Wss(_)), true) => true, + (Some(Protocol::P2p(_)), _) => + self.supported_transport.contains(&SupportedTransport::Tcp), + _ => return false, + }, + Some(Protocol::Udp(_)) => match ( + iter.next(), + self.supported_transport.contains(&SupportedTransport::Quic), + ) { + (Some(Protocol::QuicV1), true) => true, + _ => false, + }, + _ => false, + } + } + + /// Check if the address is a local listen address and if so, discard it. + fn is_local_address(&self, address: &Multiaddr) -> bool { + let address: Multiaddr = address + .iter() + .take_while(|protocol| !std::matches!(protocol, Protocol::P2p(_))) + .collect(); + + self.listen_addresses.read().contains(&address) + } + + /// Add one or more known addresses for peer. + /// + /// If peer doesn't exist, it will be added to known peers. + /// + /// Returns the number of added addresses after non-supported transports were filtered out. + pub fn add_known_address( + &mut self, + peer: &PeerId, + addresses: impl Iterator, + ) -> usize { + let mut peers = self.peers.write(); + let addresses = addresses + .filter_map(|address| { + (self.supported_transport(&address) && !self.is_local_address(&address)) + .then_some(AddressRecord::from_multiaddr(address)?) + }) + .collect::>(); + + // if all of the added addresses belonged to unsupported transports, exit early + let num_added = addresses.len(); + if num_added == 0 { + tracing::debug!( + target: LOG_TARGET, + ?peer, + "didn't add any addresses for peer because transport is not supported", + ); + + return 0usize; + } + + tracing::trace!( + target: LOG_TARGET, + ?peer, + ?addresses, + "add known addresses", + ); + + match peers.get_mut(&peer) { + Some(context) => + for record in addresses { + if !context.addresses.contains(record.address()) { + context.addresses.insert(record); + } + }, + None => { + peers.insert( + *peer, + PeerContext { + state: PeerState::Disconnected { dial_record: None }, + addresses: AddressStore::from_iter(addresses.into_iter()), + secondary_connection: None, + }, + ); + } + } + + num_added + } + + /// Dial peer using `PeerId`. + /// + /// Returns an error if the peer is unknown or the peer is already connected. + pub fn dial(&self, peer: &PeerId) -> crate::Result<()> { + if peer == &self.local_peer_id { + return Err(Error::TriedToDialSelf); + } + + { + match self.peers.read().get(&peer) { + Some(PeerContext { + state: PeerState::Connected { .. }, + .. + }) => return Err(Error::AlreadyConnected), + Some(PeerContext { + state: PeerState::Disconnected { dial_record }, + addresses, + .. + }) => { + if addresses.is_empty() { + return Err(Error::NoAddressAvailable(*peer)); + } + + // peer is already being dialed, don't dial again until the first dial concluded + if dial_record.is_some() { + tracing::debug!( + target: LOG_TARGET, + ?peer, + ?dial_record, + "peer is aready being dialed", + ); + return Ok(()); + } + } + Some(PeerContext { + state: PeerState::Dialing { .. } | PeerState::Opening { .. }, + .. + }) => return Ok(()), + None => return Err(Error::PeerDoesntExist(*peer)), + } + } + + self.cmd_tx + .try_send(InnerTransportManagerCommand::DialPeer { peer: *peer }) + .map_err(|error| match error { + TrySendError::Full(_) => Error::ChannelClogged, + TrySendError::Closed(_) => Error::EssentialTaskClosed, + }) + } + + /// Dial peer using `Multiaddr`. + /// + /// Returns an error if address it not valid. + pub fn dial_address(&self, address: Multiaddr) -> crate::Result<()> { + if !address.iter().any(|protocol| std::matches!(protocol, Protocol::P2p(_))) { + return Err(Error::AddressError(AddressError::PeerIdMissing)); + } + + self.cmd_tx + .try_send(InnerTransportManagerCommand::DialAddress { address }) + .map_err(|error| match error { + TrySendError::Full(_) => Error::ChannelClogged, + TrySendError::Closed(_) => Error::EssentialTaskClosed, + }) + } } // TODO: add getters for these pub struct TransportHandle { - pub keypair: Keypair, - pub tx: Sender, - pub protocols: HashMap, - pub next_connection_id: Arc, - pub next_substream_id: Arc, - pub protocol_names: Vec, - pub bandwidth_sink: BandwidthSink, - pub executor: Arc, + pub keypair: Keypair, + pub tx: Sender, + pub protocols: HashMap, + pub next_connection_id: Arc, + pub next_substream_id: Arc, + pub protocol_names: Vec, + pub bandwidth_sink: BandwidthSink, + pub executor: Arc, } impl TransportHandle { - pub fn protocol_set(&self, connection_id: ConnectionId) -> ProtocolSet { - ProtocolSet::new( - connection_id, - self.tx.clone(), - self.next_substream_id.clone(), - self.protocols.clone(), - ) - } - - /// Get next connection ID. - pub fn next_connection_id(&mut self) -> ConnectionId { - let connection_id = self.next_connection_id.fetch_add(1usize, Ordering::Relaxed); - - ConnectionId::from(connection_id) - } + pub fn protocol_set(&self, connection_id: ConnectionId) -> ProtocolSet { + ProtocolSet::new( + connection_id, + self.tx.clone(), + self.next_substream_id.clone(), + self.protocols.clone(), + ) + } + + /// Get next connection ID. + pub fn next_connection_id(&mut self) -> ConnectionId { + let connection_id = self.next_connection_id.fetch_add(1usize, Ordering::Relaxed); + + ConnectionId::from(connection_id) + } } #[cfg(test)] mod tests { - use super::*; - use multihash::Multihash; - use tokio::sync::mpsc::{channel, Receiver}; - - fn make_transport_manager_handle( - ) -> (TransportManagerHandle, Receiver) { - let (cmd_tx, cmd_rx) = channel(64); - - ( - TransportManagerHandle { - local_peer_id: PeerId::random(), - cmd_tx, - peers: Default::default(), - supported_transport: HashSet::new(), - listen_addresses: Default::default(), - }, - cmd_rx, - ) - } - - #[tokio::test] - async fn tcp_and_websocket_supported() { - let (mut handle, _rx) = make_transport_manager_handle(); - handle.supported_transport.insert(SupportedTransport::Tcp); - handle.supported_transport.insert(SupportedTransport::WebSocket); - - let address = + use super::*; + use multihash::Multihash; + use tokio::sync::mpsc::{channel, Receiver}; + + fn make_transport_manager_handle() -> ( + TransportManagerHandle, + Receiver, + ) { + let (cmd_tx, cmd_rx) = channel(64); + + ( + TransportManagerHandle { + local_peer_id: PeerId::random(), + cmd_tx, + peers: Default::default(), + supported_transport: HashSet::new(), + listen_addresses: Default::default(), + }, + cmd_rx, + ) + } + + #[tokio::test] + async fn tcp_and_websocket_supported() { + let (mut handle, _rx) = make_transport_manager_handle(); + handle.supported_transport.insert(SupportedTransport::Tcp); + handle.supported_transport.insert(SupportedTransport::WebSocket); + + let address = "/dns4/google.com/tcp/24928/ws/p2p/12D3KooWKrUnV42yDR7G6DewmgHtFaVCJWLjQRi2G9t5eJD3BvTy" .parse() .unwrap(); - assert!(handle.supported_transport(&address)); - } - - #[test] - fn transport_not_supported() { - let (handle, _rx) = make_transport_manager_handle(); - - // only peer id (used by Polkadot sometimes) - assert!(!handle.supported_transport( - &Multiaddr::empty().with(Protocol::P2p(Multihash::from(PeerId::random()))) - )); - - // only one transport - assert!(!handle.supported_transport( - &Multiaddr::empty().with(Protocol::Ip4(std::net::Ipv4Addr::new(127, 0, 0, 1))) - )); - - // any udp-based protocol other than quic - assert!(!handle.supported_transport( - &Multiaddr::empty() - .with(Protocol::Ip4(std::net::Ipv4Addr::new(127, 0, 0, 1))) - .with(Protocol::Udp(8888)) - .with(Protocol::Utp) - )); - - // any other protocol other than tcp - assert!(!handle.supported_transport( - &Multiaddr::empty() - .with(Protocol::Ip4(std::net::Ipv4Addr::new(127, 0, 0, 1))) - .with(Protocol::Sctp(8888)) - )); - } - - #[test] - fn zero_addresses_added() { - let (mut handle, _rx) = make_transport_manager_handle(); - handle.supported_transport.insert(SupportedTransport::Quic); - - assert!( - handle.add_known_address( - &PeerId::random(), - vec![ - Multiaddr::empty() - .with(Protocol::Ip4(std::net::Ipv4Addr::new(127, 0, 0, 1))) - .with(Protocol::Udp(8888)) - .with(Protocol::Utp), - Multiaddr::empty() - .with(Protocol::Ip4(std::net::Ipv4Addr::new(127, 0, 0, 1))) - .with(Protocol::Tcp(8888)), - Multiaddr::empty() - .with(Protocol::Ip4(std::net::Ipv4Addr::new(127, 0, 0, 1))) - .with(Protocol::Tcp(8888)) - .with(Protocol::Wss(std::borrow::Cow::Owned("/".to_string()))), - ] - .into_iter() - ) == 0usize - ); - } - - #[tokio::test] - async fn dial_already_connected_peer() { - let (mut handle, _rx) = make_transport_manager_handle(); - handle.supported_transport.insert(SupportedTransport::Tcp); - - let peer = { - let peer = PeerId::random(); - let mut peers = handle.peers.write(); - - peers.insert( - peer, - PeerContext { - state: PeerState::Connected { - record: AddressRecord::from_multiaddr( - Multiaddr::empty() - .with(Protocol::Ip4(std::net::Ipv4Addr::new(127, 0, 0, 1))) - .with(Protocol::Tcp(8888)) - .with(Protocol::P2p(Multihash::from(peer))), - ) - .unwrap(), - dial_record: None, - }, - secondary_connection: None, - addresses: AddressStore::from_iter( - vec![Multiaddr::empty() - .with(Protocol::Ip4(std::net::Ipv4Addr::new(127, 0, 0, 1))) - .with(Protocol::Tcp(8888)) - .with(Protocol::P2p(Multihash::from(peer)))] - .into_iter(), - ), - }, - ); - drop(peers); - - peer - }; - - match handle.dial(&peer) { - Err(Error::AlreadyConnected) => {}, - _ => panic!("invalid return value"), - } - } - - #[tokio::test] - async fn peer_already_being_dialed() { - let (mut handle, _rx) = make_transport_manager_handle(); - handle.supported_transport.insert(SupportedTransport::Tcp); - - let peer = { - let peer = PeerId::random(); - let mut peers = handle.peers.write(); - - peers.insert( - peer, - PeerContext { - state: PeerState::Dialing { - record: AddressRecord::from_multiaddr( - Multiaddr::empty() - .with(Protocol::Ip4(std::net::Ipv4Addr::new(127, 0, 0, 1))) - .with(Protocol::Tcp(8888)) - .with(Protocol::P2p(Multihash::from(peer))), - ) - .unwrap(), - }, - secondary_connection: None, - addresses: AddressStore::from_iter( - vec![Multiaddr::empty() - .with(Protocol::Ip4(std::net::Ipv4Addr::new(127, 0, 0, 1))) - .with(Protocol::Tcp(8888)) - .with(Protocol::P2p(Multihash::from(peer)))] - .into_iter(), - ), - }, - ); - drop(peers); - - peer - }; - - match handle.dial(&peer) { - Ok(()) => {}, - _ => panic!("invalid return value"), - } - } - - #[tokio::test] - async fn no_address_available_for_peer() { - let (mut handle, _rx) = make_transport_manager_handle(); - handle.supported_transport.insert(SupportedTransport::Tcp); - - let peer = { - let peer = PeerId::random(); - let mut peers = handle.peers.write(); - - peers.insert( - peer, - PeerContext { - state: PeerState::Disconnected { dial_record: None }, - secondary_connection: None, - addresses: AddressStore::new(), - }, - ); - drop(peers); - - peer - }; - - match handle.dial(&peer) { - Err(Error::NoAddressAvailable(failed_peer)) => { - assert_eq!(failed_peer, peer); - }, - _ => panic!("invalid return value"), - } - } - - #[tokio::test] - async fn pending_connection_for_disconnected_peer() { - let (mut handle, mut rx) = make_transport_manager_handle(); - handle.supported_transport.insert(SupportedTransport::Tcp); - - let peer = { - let peer = PeerId::random(); - let mut peers = handle.peers.write(); - - peers.insert( - peer, - PeerContext { - state: PeerState::Disconnected { - dial_record: Some( - AddressRecord::from_multiaddr( - Multiaddr::empty() - .with(Protocol::Ip4(std::net::Ipv4Addr::new(127, 0, 0, 1))) - .with(Protocol::Tcp(8888)) - .with(Protocol::P2p(Multihash::from(peer))), - ) - .unwrap(), - ), - }, - secondary_connection: None, - addresses: AddressStore::from_iter( - vec![Multiaddr::empty() - .with(Protocol::Ip4(std::net::Ipv4Addr::new(127, 0, 0, 1))) - .with(Protocol::Tcp(8888)) - .with(Protocol::P2p(Multihash::from(peer)))] - .into_iter(), - ), - }, - ); - drop(peers); - - peer - }; - - match handle.dial(&peer) { - Ok(()) => {}, - _ => panic!("invalid return value"), - } - assert!(rx.try_recv().is_err()); - } - - #[tokio::test] - async fn try_to_dial_self() { - let (mut handle, mut rx) = make_transport_manager_handle(); - handle.supported_transport.insert(SupportedTransport::Tcp); - - match handle.dial(&handle.local_peer_id) { - Err(Error::TriedToDialSelf) => {}, - _ => panic!("invalid return value"), - } - assert!(rx.try_recv().is_err()); - } - - #[test] - fn is_local_address() { - let (cmd_tx, _cmd_rx) = channel(64); - - let handle = TransportManagerHandle { - local_peer_id: PeerId::random(), - cmd_tx, - peers: Default::default(), - supported_transport: HashSet::new(), - listen_addresses: Arc::new(RwLock::new(HashSet::from_iter([ - "/ip6/::1/tcp/8888".parse().expect("valid multiaddress"), - "/ip4/127.0.0.1/tcp/8888".parse().expect("valid multiaddress"), - "/ip6/::1/tcp/8888/p2p/12D3KooWT2ouvz5uMmCvHJGzAGRHiqDts5hzXR7NdoQ27pGdzp9Q" - .parse() - .expect("valid multiaddress"), - "/ip4/127.0.0.1/tcp/8888/p2p/12D3KooWT2ouvz5uMmCvHJGzAGRHiqDts5hzXR7NdoQ27pGdzp9Q" - .parse() - .expect("valid multiaddress"), - ]))), - }; - - // local addresses - assert!(handle.is_local_address( - &"/ip6/::1/tcp/8888".parse::().expect("valid multiaddress") - )); - assert!(handle - .is_local_address(&"/ip4/127.0.0.1/tcp/8888".parse().expect("valid multiaddress"))); - assert!(handle.is_local_address( - &"/ip6/::1/tcp/8888/p2p/12D3KooWT2ouvz5uMmCvHJGzAGRHiqDts5hzXR7NdoQ27pGdzp9Q" - .parse() - .expect("valid multiaddress") - )); - assert!(handle.is_local_address( - &"/ip4/127.0.0.1/tcp/8888/p2p/12D3KooWT2ouvz5uMmCvHJGzAGRHiqDts5hzXR7NdoQ27pGdzp9Q" - .parse() - .expect("valid multiaddress") - )); - - // same address but different peer id - assert!(handle.is_local_address( - &"/ip6/::1/tcp/8888/p2p/12D3KooWPGxxxQiBEBZ52RY31Z2chn4xsDrGCMouZ88izJrak2T1" - .parse::() - .expect("valid multiaddress") - )); - assert!(handle.is_local_address( - &"/ip4/127.0.0.1/tcp/8888/p2p/12D3KooWPGxxxQiBEBZ52RY31Z2chn4xsDrGCMouZ88izJrak2T1" - .parse() - .expect("valid multiaddress") - )); - - // different address - assert!(!handle - .is_local_address(&"/ip4/127.0.0.1/tcp/9999".parse().expect("valid multiaddress"))); - // different address - assert!(!handle - .is_local_address(&"/ip4/127.0.0.1/tcp/7777".parse().expect("valid multiaddress"))); - } + assert!(handle.supported_transport(&address)); + } + + #[test] + fn transport_not_supported() { + let (handle, _rx) = make_transport_manager_handle(); + + // only peer id (used by Polkadot sometimes) + assert!(!handle.supported_transport( + &Multiaddr::empty().with(Protocol::P2p(Multihash::from(PeerId::random()))) + )); + + // only one transport + assert!(!handle.supported_transport( + &Multiaddr::empty().with(Protocol::Ip4(std::net::Ipv4Addr::new(127, 0, 0, 1))) + )); + + // any udp-based protocol other than quic + assert!(!handle.supported_transport( + &Multiaddr::empty() + .with(Protocol::Ip4(std::net::Ipv4Addr::new(127, 0, 0, 1))) + .with(Protocol::Udp(8888)) + .with(Protocol::Utp) + )); + + // any other protocol other than tcp + assert!(!handle.supported_transport( + &Multiaddr::empty() + .with(Protocol::Ip4(std::net::Ipv4Addr::new(127, 0, 0, 1))) + .with(Protocol::Sctp(8888)) + )); + } + + #[test] + fn zero_addresses_added() { + let (mut handle, _rx) = make_transport_manager_handle(); + handle.supported_transport.insert(SupportedTransport::Quic); + + assert!( + handle.add_known_address( + &PeerId::random(), + vec![ + Multiaddr::empty() + .with(Protocol::Ip4(std::net::Ipv4Addr::new(127, 0, 0, 1))) + .with(Protocol::Udp(8888)) + .with(Protocol::Utp), + Multiaddr::empty() + .with(Protocol::Ip4(std::net::Ipv4Addr::new(127, 0, 0, 1))) + .with(Protocol::Tcp(8888)), + Multiaddr::empty() + .with(Protocol::Ip4(std::net::Ipv4Addr::new(127, 0, 0, 1))) + .with(Protocol::Tcp(8888)) + .with(Protocol::Wss(std::borrow::Cow::Owned("/".to_string()))), + ] + .into_iter() + ) == 0usize + ); + } + + #[tokio::test] + async fn dial_already_connected_peer() { + let (mut handle, _rx) = make_transport_manager_handle(); + handle.supported_transport.insert(SupportedTransport::Tcp); + + let peer = { + let peer = PeerId::random(); + let mut peers = handle.peers.write(); + + peers.insert( + peer, + PeerContext { + state: PeerState::Connected { + record: AddressRecord::from_multiaddr( + Multiaddr::empty() + .with(Protocol::Ip4(std::net::Ipv4Addr::new(127, 0, 0, 1))) + .with(Protocol::Tcp(8888)) + .with(Protocol::P2p(Multihash::from(peer))), + ) + .unwrap(), + dial_record: None, + }, + secondary_connection: None, + addresses: AddressStore::from_iter( + vec![Multiaddr::empty() + .with(Protocol::Ip4(std::net::Ipv4Addr::new(127, 0, 0, 1))) + .with(Protocol::Tcp(8888)) + .with(Protocol::P2p(Multihash::from(peer)))] + .into_iter(), + ), + }, + ); + drop(peers); + + peer + }; + + match handle.dial(&peer) { + Err(Error::AlreadyConnected) => {} + _ => panic!("invalid return value"), + } + } + + #[tokio::test] + async fn peer_already_being_dialed() { + let (mut handle, _rx) = make_transport_manager_handle(); + handle.supported_transport.insert(SupportedTransport::Tcp); + + let peer = { + let peer = PeerId::random(); + let mut peers = handle.peers.write(); + + peers.insert( + peer, + PeerContext { + state: PeerState::Dialing { + record: AddressRecord::from_multiaddr( + Multiaddr::empty() + .with(Protocol::Ip4(std::net::Ipv4Addr::new(127, 0, 0, 1))) + .with(Protocol::Tcp(8888)) + .with(Protocol::P2p(Multihash::from(peer))), + ) + .unwrap(), + }, + secondary_connection: None, + addresses: AddressStore::from_iter( + vec![Multiaddr::empty() + .with(Protocol::Ip4(std::net::Ipv4Addr::new(127, 0, 0, 1))) + .with(Protocol::Tcp(8888)) + .with(Protocol::P2p(Multihash::from(peer)))] + .into_iter(), + ), + }, + ); + drop(peers); + + peer + }; + + match handle.dial(&peer) { + Ok(()) => {} + _ => panic!("invalid return value"), + } + } + + #[tokio::test] + async fn no_address_available_for_peer() { + let (mut handle, _rx) = make_transport_manager_handle(); + handle.supported_transport.insert(SupportedTransport::Tcp); + + let peer = { + let peer = PeerId::random(); + let mut peers = handle.peers.write(); + + peers.insert( + peer, + PeerContext { + state: PeerState::Disconnected { dial_record: None }, + secondary_connection: None, + addresses: AddressStore::new(), + }, + ); + drop(peers); + + peer + }; + + match handle.dial(&peer) { + Err(Error::NoAddressAvailable(failed_peer)) => { + assert_eq!(failed_peer, peer); + } + _ => panic!("invalid return value"), + } + } + + #[tokio::test] + async fn pending_connection_for_disconnected_peer() { + let (mut handle, mut rx) = make_transport_manager_handle(); + handle.supported_transport.insert(SupportedTransport::Tcp); + + let peer = { + let peer = PeerId::random(); + let mut peers = handle.peers.write(); + + peers.insert( + peer, + PeerContext { + state: PeerState::Disconnected { + dial_record: Some( + AddressRecord::from_multiaddr( + Multiaddr::empty() + .with(Protocol::Ip4(std::net::Ipv4Addr::new(127, 0, 0, 1))) + .with(Protocol::Tcp(8888)) + .with(Protocol::P2p(Multihash::from(peer))), + ) + .unwrap(), + ), + }, + secondary_connection: None, + addresses: AddressStore::from_iter( + vec![Multiaddr::empty() + .with(Protocol::Ip4(std::net::Ipv4Addr::new(127, 0, 0, 1))) + .with(Protocol::Tcp(8888)) + .with(Protocol::P2p(Multihash::from(peer)))] + .into_iter(), + ), + }, + ); + drop(peers); + + peer + }; + + match handle.dial(&peer) { + Ok(()) => {} + _ => panic!("invalid return value"), + } + assert!(rx.try_recv().is_err()); + } + + #[tokio::test] + async fn try_to_dial_self() { + let (mut handle, mut rx) = make_transport_manager_handle(); + handle.supported_transport.insert(SupportedTransport::Tcp); + + match handle.dial(&handle.local_peer_id) { + Err(Error::TriedToDialSelf) => {} + _ => panic!("invalid return value"), + } + assert!(rx.try_recv().is_err()); + } + + #[test] + fn is_local_address() { + let (cmd_tx, _cmd_rx) = channel(64); + + let handle = TransportManagerHandle { + local_peer_id: PeerId::random(), + cmd_tx, + peers: Default::default(), + supported_transport: HashSet::new(), + listen_addresses: Arc::new(RwLock::new(HashSet::from_iter([ + "/ip6/::1/tcp/8888".parse().expect("valid multiaddress"), + "/ip4/127.0.0.1/tcp/8888".parse().expect("valid multiaddress"), + "/ip6/::1/tcp/8888/p2p/12D3KooWT2ouvz5uMmCvHJGzAGRHiqDts5hzXR7NdoQ27pGdzp9Q" + .parse() + .expect("valid multiaddress"), + "/ip4/127.0.0.1/tcp/8888/p2p/12D3KooWT2ouvz5uMmCvHJGzAGRHiqDts5hzXR7NdoQ27pGdzp9Q" + .parse() + .expect("valid multiaddress"), + ]))), + }; + + // local addresses + assert!(handle.is_local_address( + &"/ip6/::1/tcp/8888".parse::().expect("valid multiaddress") + )); + assert!(handle + .is_local_address(&"/ip4/127.0.0.1/tcp/8888".parse().expect("valid multiaddress"))); + assert!(handle.is_local_address( + &"/ip6/::1/tcp/8888/p2p/12D3KooWT2ouvz5uMmCvHJGzAGRHiqDts5hzXR7NdoQ27pGdzp9Q" + .parse() + .expect("valid multiaddress") + )); + assert!(handle.is_local_address( + &"/ip4/127.0.0.1/tcp/8888/p2p/12D3KooWT2ouvz5uMmCvHJGzAGRHiqDts5hzXR7NdoQ27pGdzp9Q" + .parse() + .expect("valid multiaddress") + )); + + // same address but different peer id + assert!(handle.is_local_address( + &"/ip6/::1/tcp/8888/p2p/12D3KooWPGxxxQiBEBZ52RY31Z2chn4xsDrGCMouZ88izJrak2T1" + .parse::() + .expect("valid multiaddress") + )); + assert!(handle.is_local_address( + &"/ip4/127.0.0.1/tcp/8888/p2p/12D3KooWPGxxxQiBEBZ52RY31Z2chn4xsDrGCMouZ88izJrak2T1" + .parse() + .expect("valid multiaddress") + )); + + // different address + assert!(!handle + .is_local_address(&"/ip4/127.0.0.1/tcp/9999".parse().expect("valid multiaddress"))); + // different address + assert!(!handle + .is_local_address(&"/ip4/127.0.0.1/tcp/7777".parse().expect("valid multiaddress"))); + } } diff --git a/src/transport/manager/mod.rs b/src/transport/manager/mod.rs index 74aa151d..61d7bdd5 100644 --- a/src/transport/manager/mod.rs +++ b/src/transport/manager/mod.rs @@ -19,21 +19,21 @@ // DEALINGS IN THE SOFTWARE. use crate::{ - codec::ProtocolCodec, - crypto::ed25519::Keypair, - error::{AddressError, Error}, - executor::Executor, - protocol::{InnerTransportEvent, TransportService}, - transport::{ - manager::{ - address::{AddressRecord, AddressStore}, - handle::InnerTransportManagerCommand, - types::{PeerContext, PeerState}, - }, - Endpoint, Transport, TransportEvent, - }, - types::{protocol::ProtocolName, ConnectionId}, - BandwidthSink, PeerId, + codec::ProtocolCodec, + crypto::ed25519::Keypair, + error::{AddressError, Error}, + executor::Executor, + protocol::{InnerTransportEvent, TransportService}, + transport::{ + manager::{ + address::{AddressRecord, AddressStore}, + handle::InnerTransportManagerCommand, + types::{PeerContext, PeerState}, + }, + Endpoint, Transport, TransportEvent, + }, + types::{protocol::ProtocolName, ConnectionId}, + BandwidthSink, PeerId, }; use futures::{Stream, StreamExt}; @@ -44,13 +44,13 @@ use parking_lot::RwLock; use tokio::sync::mpsc::{channel, Receiver, Sender}; use std::{ - collections::{HashMap, HashSet}, - pin::Pin, - sync::{ - atomic::{AtomicUsize, Ordering}, - Arc, - }, - task::{Context, Poll}, + collections::{HashMap, HashSet}, + pin::Pin, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }, + task::{Context, Poll}, }; pub use handle::{TransportHandle, TransportManagerHandle}; @@ -77,2928 +77,3163 @@ const SCORE_DIAL_FAILURE: i32 = -100i32; /// TODO: enum ConnectionEstablishedResult { - /// Accept connection and inform `Litep2p` about the connection. - Accept, + /// Accept connection and inform `Litep2p` about the connection. + Accept, - /// Reject connection. - Reject, + /// Reject connection. + Reject, } /// [`crate::transport::manager::TransportManager`] events. pub enum TransportManagerEvent { - /// Connection closed to remote peer. - ConnectionClosed { - /// Peer ID. - peer: PeerId, - - /// Connection ID. - connection: ConnectionId, - }, + /// Connection closed to remote peer. + ConnectionClosed { + /// Peer ID. + peer: PeerId, + + /// Connection ID. + connection: ConnectionId, + }, } // Protocol context. #[derive(Debug, Clone)] pub struct ProtocolContext { - /// Codec used by the protocol. - pub codec: ProtocolCodec, + /// Codec used by the protocol. + pub codec: ProtocolCodec, - /// TX channel for sending events to protocol. - pub tx: Sender, + /// TX channel for sending events to protocol. + pub tx: Sender, - /// Fallback names for the protocol. - pub fallback_names: Vec, + /// Fallback names for the protocol. + pub fallback_names: Vec, } impl ProtocolContext { - /// Create new [`ProtocolContext`]. - fn new( - codec: ProtocolCodec, - tx: Sender, - fallback_names: Vec, - ) -> Self { - Self { tx, codec, fallback_names } - } + /// Create new [`ProtocolContext`]. + fn new( + codec: ProtocolCodec, + tx: Sender, + fallback_names: Vec, + ) -> Self { + Self { + tx, + codec, + fallback_names, + } + } } /// Transport context for enabled transports. struct TransportContext { - /// Polling index. - index: usize, + /// Polling index. + index: usize, - /// Registered transports. - transports: IndexMap>>, + /// Registered transports. + transports: IndexMap>>, } impl TransportContext { - /// Create new [`TransportContext`]. - pub fn new() -> Self { - Self { index: 0usize, transports: IndexMap::new() } - } - - /// Get an iterator of supported transports. - pub fn keys(&self) -> impl Iterator { - self.transports.keys() - } - - /// Get mutable access to transport. - pub fn get_mut( - &mut self, - key: &SupportedTransport, - ) -> Option<&mut Box>> { - self.transports.get_mut(key) - } - - /// Register `transport` to `TransportContext`. - pub fn register_transport( - &mut self, - name: SupportedTransport, - transport: Box>, - ) { - assert!(self.transports.insert(name, transport).is_none()); - } + /// Create new [`TransportContext`]. + pub fn new() -> Self { + Self { + index: 0usize, + transports: IndexMap::new(), + } + } + + /// Get an iterator of supported transports. + pub fn keys(&self) -> impl Iterator { + self.transports.keys() + } + + /// Get mutable access to transport. + pub fn get_mut( + &mut self, + key: &SupportedTransport, + ) -> Option<&mut Box>> { + self.transports.get_mut(key) + } + + /// Register `transport` to `TransportContext`. + pub fn register_transport( + &mut self, + name: SupportedTransport, + transport: Box>, + ) { + assert!(self.transports.insert(name, transport).is_none()); + } } impl Stream for TransportContext { - type Item = (SupportedTransport, TransportEvent); - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let len = match self.transports.len() { - 0 => return Poll::Ready(None), - len => len, - }; - let start_index = self.index; - - loop { - let index = self.index % len; - self.index += 1; - - let (key, stream) = self.transports.get_index_mut(index).expect("transport to exist"); - match stream.poll_next_unpin(cx) { - Poll::Pending => {}, - Poll::Ready(None) => return Poll::Ready(None), - Poll::Ready(Some(event)) => return Poll::Ready(Some((*key, event))), - } - - if self.index == start_index + len { - break Poll::Pending; - } - } - } + type Item = (SupportedTransport, TransportEvent); + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let len = match self.transports.len() { + 0 => return Poll::Ready(None), + len => len, + }; + let start_index = self.index; + + loop { + let index = self.index % len; + self.index += 1; + + let (key, stream) = self.transports.get_index_mut(index).expect("transport to exist"); + match stream.poll_next_unpin(cx) { + Poll::Pending => {} + Poll::Ready(None) => return Poll::Ready(None), + Poll::Ready(Some(event)) => return Poll::Ready(Some((*key, event))), + } + + if self.index == start_index + len { + break Poll::Pending; + } + } + } } /// Litep2p connection manager. pub struct TransportManager { - /// Local peer ID. - local_peer_id: PeerId, + /// Local peer ID. + local_peer_id: PeerId, - /// Keypair. - keypair: Keypair, + /// Keypair. + keypair: Keypair, - /// Bandwidth sink. - bandwidth_sink: BandwidthSink, + /// Bandwidth sink. + bandwidth_sink: BandwidthSink, - /// Maximum parallel dial attempts per peer. - max_parallel_dials: usize, + /// Maximum parallel dial attempts per peer. + max_parallel_dials: usize, - /// Installed protocols. - protocols: HashMap, + /// Installed protocols. + protocols: HashMap, - /// All names (main and fallback(s)) of the installed protocols. - protocol_names: HashSet, + /// All names (main and fallback(s)) of the installed protocols. + protocol_names: HashSet, - /// Listen addresses. - listen_addresses: Arc>>, + /// Listen addresses. + listen_addresses: Arc>>, - /// Next connection ID. - next_connection_id: Arc, + /// Next connection ID. + next_connection_id: Arc, - /// Next substream ID. - next_substream_id: Arc, + /// Next substream ID. + next_substream_id: Arc, - /// Installed transports. - transports: TransportContext, + /// Installed transports. + transports: TransportContext, - /// Peers - peers: Arc>>, + /// Peers + peers: Arc>>, - /// Handle to [`crate::transport::manager::TransportManager`]. - transport_manager_handle: TransportManagerHandle, + /// Handle to [`crate::transport::manager::TransportManager`]. + transport_manager_handle: TransportManagerHandle, - /// RX channel for receiving events from installed transports. - event_rx: Receiver, + /// RX channel for receiving events from installed transports. + event_rx: Receiver, - /// RX channel for receiving commands from installed protocols. - cmd_rx: Receiver, + /// RX channel for receiving commands from installed protocols. + cmd_rx: Receiver, - /// TX channel for transport events that is given to installed transports. - event_tx: Sender, + /// TX channel for transport events that is given to installed transports. + event_tx: Sender, - /// Pending connections. - pending_connections: HashMap, + /// Pending connections. + pending_connections: HashMap, } impl TransportManager { - /// Create new [`crate::transport::manager::TransportManager`]. - // TODO: don't return handle here - pub fn new( - keypair: Keypair, - supported_transports: HashSet, - bandwidth_sink: BandwidthSink, - max_parallel_dials: usize, - ) -> (Self, TransportManagerHandle) { - let local_peer_id = PeerId::from_public_key(&keypair.public().into()); - let peers = Arc::new(RwLock::new(HashMap::new())); - let (cmd_tx, cmd_rx) = channel(256); - let (event_tx, event_rx) = channel(256); - let listen_addresses = Arc::new(RwLock::new(HashSet::new())); - let handle = TransportManagerHandle::new( - local_peer_id, - peers.clone(), - cmd_tx, - supported_transports, - Arc::clone(&listen_addresses), - ); - - ( - Self { - peers, - cmd_rx, - keypair, - event_tx, - event_rx, - local_peer_id, - bandwidth_sink, - listen_addresses, - max_parallel_dials, - protocols: HashMap::new(), - transports: TransportContext::new(), - protocol_names: HashSet::new(), - transport_manager_handle: handle.clone(), - pending_connections: HashMap::new(), - next_substream_id: Arc::new(AtomicUsize::new(0usize)), - next_connection_id: Arc::new(AtomicUsize::new(0usize)), - }, - handle, - ) - } - - /// Get iterator to installed protocols. - pub fn protocols(&self) -> impl Iterator { - self.protocols.keys() - } - - /// Get iterator to installed transports - pub fn installed_transports(&self) -> impl Iterator { - self.transports.keys() - } - - /// Get next connection ID. - fn next_connection_id(&mut self) -> ConnectionId { - let connection_id = self.next_connection_id.fetch_add(1usize, Ordering::Relaxed); - - ConnectionId::from(connection_id) - } - - /// Register protocol to the [`crate::transport::manager::TransportManager`]. - /// - /// This allocates new context for the protocol and returns a handle - /// which the protocol can use the interact with the transport subsystem. - pub fn register_protocol( - &mut self, - protocol: ProtocolName, - fallback_names: Vec, - codec: ProtocolCodec, - ) -> TransportService { - assert!(!self.protocol_names.contains(&protocol)); - - for fallback in &fallback_names { - if self.protocol_names.contains(fallback) { - panic!("duplicate fallback protocol given: {fallback:?}"); - } - } - - let (service, sender) = TransportService::new( - self.local_peer_id, - protocol.clone(), - fallback_names.clone(), - self.next_substream_id.clone(), - self.transport_manager_handle.clone(), - ); - - self.protocols - .insert(protocol.clone(), ProtocolContext::new(codec, sender, fallback_names.clone())); - self.protocol_names.insert(protocol); - self.protocol_names.extend(fallback_names); - - service - } - - /// Acquire `TransportHandle`. - pub fn transport_handle(&self, executor: Arc) -> TransportHandle { - TransportHandle { - tx: self.event_tx.clone(), - executor, - keypair: self.keypair.clone(), - protocols: self.protocols.clone(), - bandwidth_sink: self.bandwidth_sink.clone(), - protocol_names: self.protocol_names.iter().cloned().collect(), - next_substream_id: self.next_substream_id.clone(), - next_connection_id: self.next_connection_id.clone(), - } - } - - /// Register transport to `TransportManager`. - pub(crate) fn register_transport( - &mut self, - name: SupportedTransport, - transport: Box>, - ) { - tracing::debug!(target: LOG_TARGET, transport = ?name, "register transport"); - - self.transports.register_transport(name, transport); - self.transport_manager_handle.register_transport(name); - } - - /// Register local listen address. - pub fn register_listen_address(&mut self, address: Multiaddr) { - assert!(!address.iter().any(|protocol| std::matches!(protocol, Protocol::P2p(_)))); - - let mut listen_addresses = self.listen_addresses.write(); - - listen_addresses.insert(address.clone()); - listen_addresses.insert( - address.with(Protocol::P2p( - Multihash::from_bytes(&self.local_peer_id.to_bytes()).unwrap(), - )), - ); - } - - /// Add one or more known addresses for `peer`. - pub fn add_known_address( - &mut self, - peer: PeerId, - address: impl Iterator, - ) -> usize { - self.transport_manager_handle.add_known_address(&peer, address) - } - - /// Dial peer using `PeerId`. - /// - /// Returns an error if the peer is unknown or the peer is already connected. - pub async fn dial(&mut self, peer: PeerId) -> crate::Result<()> { - if peer == self.local_peer_id { - return Err(Error::TriedToDialSelf); - } - let mut peers = self.peers.write(); - - // if the peer is disconnected, return its context - // - // otherwise set the state back what it was and return dial status to caller - let PeerContext { state, secondary_connection, mut addresses } = match peers.remove(&peer) { - None => return Err(Error::PeerDoesntExist(peer)), - Some(context @ PeerContext { state: PeerState::Connected { .. }, .. }) => { - peers.insert(peer, context); - return Err(Error::AlreadyConnected); - }, - Some( - context @ PeerContext { - state: PeerState::Dialing { .. } | PeerState::Opening { .. }, - .. - }, - ) => { - peers.insert(peer, context); - return Ok(()); - }, - Some(context) => context, - }; - - if let PeerState::Disconnected { dial_record: Some(_) } = &state { - tracing::debug!( - target: LOG_TARGET, - ?peer, - "peer is aready being dialed", - ); - - peers.insert(peer, PeerContext { state, secondary_connection, addresses }); - - return Ok(()); - } - - let mut records: HashMap<_, _> = addresses - .take(self.max_parallel_dials) - .into_iter() - .map(|record| (record.address().clone(), record)) - .collect(); - - if records.is_empty() { - return Err(Error::NoAddressAvailable(peer)); - } - - for (_, record) in &records { - if self.listen_addresses.read().contains(record.as_ref()) { - tracing::warn!( - target: LOG_TARGET, - ?peer, - ?record, - "tried to dial self", - ); - - debug_assert!(false); - return Err(Error::TriedToDialSelf); - } - } - - // set connection id for the address record and put peer into `Opening` state - let connection_id = - ConnectionId::from(self.next_connection_id.fetch_add(1usize, Ordering::Relaxed)); - - tracing::debug!( - target: LOG_TARGET, - ?connection_id, - addresses = ?records, - "dial remote peer", - ); - - let mut transports = HashSet::new(); - let mut websocket = Vec::new(); - let mut quic = Vec::new(); - let mut tcp = Vec::new(); - - for (address, record) in &mut records { - record.set_connection_id(connection_id); - - let mut iter = address.iter(); - match iter.find(|protocol| std::matches!(protocol, Protocol::QuicV1)) { - Some(_) => { - quic.push(address.clone()); - transports.insert(SupportedTransport::Quic); - }, - _ => match address - .iter() - .find(|protocol| std::matches!(protocol, Protocol::Ws(_) | Protocol::Wss(_))) - { - Some(_) => { - websocket.push(address.clone()); - transports.insert(SupportedTransport::WebSocket); - }, - None => { - tcp.push(address.clone()); - transports.insert(SupportedTransport::Tcp); - }, - }, - } - } - - peers.insert( - peer, - PeerContext { - state: PeerState::Opening { records, connection_id, transports }, - secondary_connection, - addresses, - }, - ); - - if !tcp.is_empty() { - self.transports - .get_mut(&SupportedTransport::Tcp) - .expect("transport to be supported") - .open(connection_id, tcp)?; - } - - if !quic.is_empty() { - self.transports - .get_mut(&SupportedTransport::Quic) - .expect("transport to be supported") - .open(connection_id, quic)?; - } - - if !websocket.is_empty() { - self.transports - .get_mut(&SupportedTransport::WebSocket) - .expect("transport to be supported") - .open(connection_id, websocket)?; - } - - self.pending_connections.insert(connection_id, peer); - - Ok(()) - } - - /// Dial peer using `Multiaddr`. - /// - /// Returns an error if address it not valid. - pub async fn dial_address(&mut self, address: Multiaddr) -> crate::Result<()> { - let mut record = AddressRecord::from_multiaddr(address) - .ok_or(Error::AddressError(AddressError::PeerIdMissing))?; - - if self.listen_addresses.read().contains(record.as_ref()) { - return Err(Error::TriedToDialSelf); - } - - tracing::debug!(target: LOG_TARGET, address = ?record.address(), "dial remote peer over address"); - - let mut protocol_stack = record.as_ref().iter(); - match protocol_stack - .next() - .ok_or_else(|| Error::TransportNotSupported(record.address().clone()))? - { - Protocol::Ip4(_) | Protocol::Ip6(_) => {}, - Protocol::Dns(_) | Protocol::Dns4(_) | Protocol::Dns6(_) => {}, - transport => { - tracing::error!( - target: LOG_TARGET, - ?transport, - "invalid transport, expected `ip4`/`ip6`" - ); - return Err(Error::TransportNotSupported(record.address().clone())); - }, - }; - - let supported_transport = match protocol_stack - .next() - .ok_or_else(|| Error::TransportNotSupported(record.address().clone()))? - { - Protocol::Tcp(_) => match protocol_stack.next() { - Some(Protocol::Ws(_)) | Some(Protocol::Wss(_)) => SupportedTransport::WebSocket, - Some(Protocol::P2p(_)) => SupportedTransport::Tcp, - _ => return Err(Error::TransportNotSupported(record.address().clone())), - }, - Protocol::Udp(_) => match protocol_stack - .next() - .ok_or_else(|| Error::TransportNotSupported(record.address().clone()))? - { - Protocol::QuicV1 => SupportedTransport::Quic, - _ => { - tracing::debug!(target: LOG_TARGET, address = ?record.address(), "expected `quic-v1`"); - return Err(Error::TransportNotSupported(record.address().clone())); - }, - }, - protocol => { - tracing::error!( - target: LOG_TARGET, - ?protocol, - "invalid protocol, expected `tcp`" - ); - - return Err(Error::TransportNotSupported(record.address().clone())); - }, - }; - - // when constructing `AddressRecord`, `PeerId` was verified to be part of the address - let remote_peer_id = - PeerId::try_from_multiaddr(record.address()).expect("`PeerId` to exist"); - - // set connection id for the address record and put peer into `Dialing` state - let connection_id = self.next_connection_id(); - record.set_connection_id(connection_id); - - { - let mut peers = self.peers.write(); - - match peers.get_mut(&remote_peer_id) { - None => { - drop(peers); - self.peers.write().insert( - remote_peer_id, - PeerContext { - state: PeerState::Dialing { record: record.clone() }, - addresses: AddressStore::new(), - secondary_connection: None, - }, - ); - }, - Some(PeerContext { - state: - PeerState::Dialing { .. } | - PeerState::Connected { .. } | - PeerState::Opening { .. }, - .. - }) => return Ok(()), - Some(PeerContext { ref mut state, .. }) => { - // TODO: verify that the address is not in `addresses` already - // addresses.insert(address.clone()); - *state = PeerState::Dialing { record: record.clone() }; - }, - } - } - - self.transports - .get_mut(&supported_transport) - .ok_or(Error::TransportNotSupported(record.address().clone()))? - .dial(connection_id, record.address().clone())?; - self.pending_connections.insert(connection_id, remote_peer_id); - - Ok(()) - } - - /// Handle dial failure. - fn on_dial_failure(&mut self, connection_id: ConnectionId) -> crate::Result<()> { - let peer = self.pending_connections.remove(&connection_id).ok_or_else(|| { - tracing::error!( - target: LOG_TARGET, - ?connection_id, - "dial failed for a connection that doesn't exist", - ); - debug_assert!(false); - - Error::InvalidState - })?; - - let mut peers = self.peers.write(); - let context = peers.get_mut(&peer).ok_or_else(|| { - tracing::error!( - target: LOG_TARGET, - ?peer, - ?connection_id, - "dial failed for a peer that doens't exist", - ); - debug_assert!(false); - - Error::InvalidState - })?; - - match std::mem::replace(&mut context.state, PeerState::Disconnected { dial_record: None }) { - PeerState::Dialing { ref mut record } => { - debug_assert_eq!(record.connection_id(), &Some(connection_id)); - - record.update_score(SCORE_DIAL_FAILURE); - context.addresses.insert(record.clone()); - - context.state = PeerState::Disconnected { dial_record: None }; - Ok(()) - }, - PeerState::Opening { .. } => { - todo!(); - }, - PeerState::Connected { record, dial_record: Some(mut dial_record) } => { - dial_record.update_score(SCORE_DIAL_FAILURE); - context.addresses.insert(dial_record); - - context.state = PeerState::Connected { record, dial_record: None }; - Ok(()) - }, - PeerState::Disconnected { dial_record: Some(mut dial_record) } => { - tracing::debug!( - target: LOG_TARGET, - ?connection_id, - ?dial_record, - "dial failed for a disconnected peer", - ); - - dial_record.update_score(SCORE_DIAL_FAILURE); - context.addresses.insert(dial_record); - - Ok(()) - }, - state => { - tracing::warn!( - target: LOG_TARGET, - ?peer, - ?connection_id, - ?state, - "invalid state for dial failure", - ); - context.state = state; - - debug_assert!(false); - Ok(()) - }, - } - } - - /// Handle closed connection. - /// - /// Returns `bool` which indicates whether the event should be returned or not. - fn on_connection_closed( - &mut self, - peer: PeerId, - connection_id: ConnectionId, - ) -> crate::Result> { - let mut peers = self.peers.write(); - let Some(context) = peers.get_mut(&peer) else { - tracing::warn!( - target: LOG_TARGET, - ?peer, - ?connection_id, - "cannot handle closed connection: peer doesn't exist", - ); - debug_assert!(false); - return Err(Error::PeerDoesntExist(peer)); - }; - - tracing::trace!( - target: LOG_TARGET, - ?peer, - ?connection_id, - "connection closed", - ); - - match std::mem::replace(&mut context.state, PeerState::Disconnected { dial_record: None }) { - PeerState::Connected { record, dial_record: actual_dial_record } => match record - .connection_id() == - &Some(connection_id) - { - // primary connection was closed - // - // if secondary connection exists, switch to using it while keeping peer in - // `Connected` state and if there's only one connection, set peer - // state to `Disconnected` - true => match context.secondary_connection.take() { - None => { - context.addresses.insert(record); - context.state = PeerState::Disconnected { dial_record: actual_dial_record }; - - return Ok(Some(TransportEvent::ConnectionClosed { peer, connection_id })); - }, - Some(secondary_connection) => { - context.addresses.insert(record); - context.state = PeerState::Connected { - record: secondary_connection, - dial_record: actual_dial_record, - }; - - return Ok(None); - }, - }, - // secondary connection was closed - false => match context.secondary_connection.take() { - Some(secondary_connection) => { - if secondary_connection.connection_id() != &Some(connection_id) { - tracing::debug!( - target: LOG_TARGET, - ?peer, - ?connection_id, - "unknown connection was closed, potentially ignored tertiary connection", - ); - - context.secondary_connection = Some(secondary_connection); - context.state = - PeerState::Connected { record, dial_record: actual_dial_record }; - - return Ok(None); - } - - tracing::trace!( - target: LOG_TARGET, - ?peer, - ?connection_id, - "secondary connection closed", - ); - - context.addresses.insert(secondary_connection); - context.state = - PeerState::Connected { record, dial_record: actual_dial_record }; - return Ok(None); - }, - None => { - tracing::warn!( - target: LOG_TARGET, - ?peer, - ?connection_id, - "non-primary connection was closed but secondary connection doesn't exist", - ); - - debug_assert!(false); - return Err(Error::InvalidState); - }, - }, - }, - PeerState::Disconnected { dial_record } => match context.secondary_connection.take() { - Some(record) => { - tracing::warn!( - target: LOG_TARGET, - ?peer, - ?connection_id, - ?record, - ?dial_record, - "peer is disconnected but secondary connection exists", - ); - - debug_assert!(false); - context.state = PeerState::Disconnected { dial_record }; - return Err(Error::InvalidState); - }, - None => { - context.state = PeerState::Disconnected { dial_record }; - - Ok(Some(TransportEvent::ConnectionClosed { peer, connection_id })) - }, - }, - state => { - tracing::warn!(target: LOG_TARGET, ?peer, ?connection_id, ?state, "invalid state for a closed connection"); - debug_assert!(false); - return Err(Error::InvalidState); - }, - } - } - - fn on_connection_established( - &mut self, - peer: PeerId, - endpoint: &Endpoint, - ) -> crate::Result { - if let Some(dialed_peer) = self.pending_connections.remove(&endpoint.connection_id()) { - if dialed_peer != peer { - tracing::warn!( - target: LOG_TARGET, - ?dialed_peer, - ?peer, - ?endpoint, - "peer ids do not match but transport was supposed to reject connection" - ); - debug_assert!(false); - return Err(Error::InvalidState); - } - }; - - let mut peers = self.peers.write(); - match peers.get_mut(&peer) { - Some(context) => match context.state { - PeerState::Connected { .. } => match context.secondary_connection { - Some(_) => { - tracing::debug!( - target: LOG_TARGET, - ?peer, - connection_id = ?endpoint.connection_id(), - ?endpoint, - "secondary connection already exists, ignoring connection", - ); - - // insert address into the store only if we're the dialer - // - // if we're the listener, remote might have dialed with an ephemeral port - // which it might not be listening, making this address useless - if endpoint.is_listener() { - context.addresses.insert(AddressRecord::new( - &peer, - endpoint.address().clone(), - SCORE_DIAL_SUCCESS, - None, - )) - } - - return Ok(ConnectionEstablishedResult::Reject); - }, - None => { - tracing::debug!( - target: LOG_TARGET, - ?peer, - connection_id = ?endpoint.connection_id(), - address = ?endpoint.address(), - "secondary connection", - ); - - context.secondary_connection = Some(AddressRecord::new( - &peer, - endpoint.address().clone(), - SCORE_DIAL_SUCCESS, - Some(endpoint.connection_id()), - )); - }, - }, - PeerState::Dialing { ref record, .. } => { - match record.connection_id() == &Some(endpoint.connection_id()) { - true => { - tracing::trace!( - target: LOG_TARGET, - ?peer, - connection_id = ?endpoint.connection_id(), - ?endpoint, - ?record, - "connection opened to remote", - ); - - context.state = - PeerState::Connected { record: record.clone(), dial_record: None }; - }, - false => { - tracing::trace!( - target: LOG_TARGET, - ?peer, - connection_id = ?endpoint.connection_id(), - ?endpoint, - "connection opened by remote while local node was dialing", - ); - - context.state = PeerState::Connected { - record: AddressRecord::new( - &peer, - endpoint.address().clone(), - SCORE_DIAL_SUCCESS, - Some(endpoint.connection_id()), - ), - dial_record: Some(record.clone()), - }; - }, - } - }, - PeerState::Opening { ref mut records, connection_id, ref transports } => { - debug_assert!(std::matches!(endpoint, &Endpoint::Listener { .. })); - - tracing::trace!( - target: LOG_TARGET, - ?peer, - dial_connection_id = ?connection_id, - dial_records = ?records, - dial_transports = ?transports, - listener_endpoint = ?endpoint, - "inbound connection while opening an outbound connection", - ); - - // cancel all pending dials - transports.iter().for_each(|transport| { - self.transports - .get_mut(transport) - .expect("transport to exist") - .cancel(connection_id); - }); - - // since an inbound connection was removed, the outbound connection can be - // removed from pendind dials - // - // all records have the same `ConnectionId` so it doens't matter which of them - // is used to remove the pending dial - self.pending_connections.remove( - &records - .iter() - .next() - .expect("record to exist") - .1 - .connection_id() - .expect("`ConnectionId` to exist"), - ); - - let record = match records.remove(endpoint.address()) { - Some(mut record) => { - record.update_score(SCORE_DIAL_SUCCESS); - record.set_connection_id(endpoint.connection_id()); - record - }, - None => AddressRecord::new( - &peer, - endpoint.address().clone(), - SCORE_DIAL_SUCCESS, - Some(endpoint.connection_id()), - ), - }; - context.addresses.extend(records.iter().map(|(_, record)| record)); - - context.state = PeerState::Connected { record, dial_record: None }; - }, - PeerState::Disconnected { ref mut dial_record } => { - tracing::trace!( - target: LOG_TARGET, - ?peer, - connection_id = ?endpoint.connection_id(), - ?endpoint, - ?dial_record, - "connection opened by remote or delayed dial succeeded", - ); - - let (record, dial_record) = match dial_record.take() { - Some(mut dial_record) => - if dial_record.address() == endpoint.address() { - dial_record.set_connection_id(endpoint.connection_id()); - (dial_record, None) - } else { - ( - AddressRecord::new( - &peer, - endpoint.address().clone(), - SCORE_DIAL_SUCCESS, - Some(endpoint.connection_id()), - ), - Some(dial_record), - ) - }, - None => ( - AddressRecord::new( - &peer, - endpoint.address().clone(), - SCORE_DIAL_SUCCESS, - Some(endpoint.connection_id()), - ), - None, - ), - }; - - context.state = PeerState::Connected { record, dial_record }; - }, - }, - None => { - peers.insert( - peer, - PeerContext { - state: PeerState::Connected { - record: AddressRecord::new( - &peer, - endpoint.address().clone(), - SCORE_DIAL_SUCCESS, - Some(endpoint.connection_id()), - ), - dial_record: None, - }, - addresses: AddressStore::new(), - secondary_connection: None, - }, - ); - }, - } - - Ok(ConnectionEstablishedResult::Accept) - } - - fn on_connection_opened( - &mut self, - transport: SupportedTransport, - connection_id: ConnectionId, - address: Multiaddr, - ) -> crate::Result<()> { - let Some(peer) = self.pending_connections.remove(&connection_id) else { - tracing::warn!( - target: LOG_TARGET, - ?connection_id, - ?transport, - ?address, - "connection opened but dial record doesn't exist", - ); - - debug_assert!(false); - return Err(Error::InvalidState); - }; - - let mut peers = self.peers.write(); - let context = peers.get_mut(&peer).ok_or_else(|| { - tracing::warn!( - target: LOG_TARGET, - ?peer, - ?connection_id, - "connection opened but peer doesn't exist", - ); - - debug_assert!(false); - Error::InvalidState - })?; - - match std::mem::replace(&mut context.state, PeerState::Disconnected { dial_record: None }) { - PeerState::Opening { mut records, connection_id, transports } => { - tracing::trace!( - target: LOG_TARGET, - ?peer, - ?connection_id, - ?address, - ?transport, - "connection opened to peer", - ); - - // cancel open attempts for other transports as connection already exists - for transport in transports.iter() { - let _ = self - .transports - .get_mut(&transport) - .expect("transport to exist") - .cancel(connection_id); - } - - // set peer state to `Dialing` to signal that the connection is fully opening - // - // set the succeeded `AddressRecord` as the one that is used for dialing and move - // all other address records back to `AddressStore`. and ask - // transport to negotiate the - let mut dial_record = records.remove(&address).expect("address to exist"); - dial_record.update_score(SCORE_DIAL_SUCCESS); - - // negotiate the connection - match self - .transports - .get_mut(&transport) - .expect("transport to exist") - .negotiate(connection_id) - { - Ok(()) => { - tracing::trace!( - target: LOG_TARGET, - ?peer, - ?connection_id, - ?dial_record, - ?transport, - "negotiation started" - ); - - self.pending_connections.insert(connection_id, peer); - - context.state = PeerState::Dialing { record: dial_record }; - - for (_, record) in records { - context.addresses.insert(record); - } - - Ok(()) - }, - Err(error) => { - tracing::warn!( - target: LOG_TARGET, - ?peer, - ?connection_id, - ?error, - "failed to negotiate connection", - ); - context.state = PeerState::Disconnected { dial_record: None }; - - debug_assert!(false); - Err(Error::InvalidState) - }, - } - }, - state => { - tracing::warn!( - target: LOG_TARGET, - ?peer, - ?connection_id, - ?state, - "connection opened but `PeerState` is not `Opening`", - ); - context.state = state; - - debug_assert!(false); - Err(Error::InvalidState) - }, - } - } - - /// Handle open failure for dialing attempt for `transport` - fn on_open_failure( - &mut self, - transport: SupportedTransport, - connection_id: ConnectionId, - ) -> crate::Result> { - let Some(peer) = self.pending_connections.remove(&connection_id) else { - tracing::warn!( - target: LOG_TARGET, - ?connection_id, - "open failure but dial record doesn't exist", - ); - - debug_assert!(false); - return Err(Error::InvalidState); - }; - - let mut peers = self.peers.write(); - let context = peers.get_mut(&peer).ok_or_else(|| { - tracing::warn!( - target: LOG_TARGET, - ?peer, - ?connection_id, - "open failure but peer doesn't exist", - ); - - debug_assert!(false); - Error::InvalidState - })?; - - match std::mem::replace(&mut context.state, PeerState::Disconnected { dial_record: None }) { - PeerState::Opening { records, connection_id, mut transports } => { - tracing::trace!( - target: LOG_TARGET, - ?peer, - ?connection_id, - ?transport, - "open failure for peer", - ); - transports.remove(&transport); - - if transports.is_empty() { - for (_, mut record) in records { - record.update_score(SCORE_DIAL_FAILURE); - context.addresses.insert(record); - } - - tracing::trace!( - target: LOG_TARGET, - ?peer, - ?connection_id, - "open failure for last transport", - ); - - return Ok(Some(peer)); - } - - self.pending_connections.insert(connection_id, peer); - context.state = PeerState::Opening { records, connection_id, transports }; - - Ok(None) - }, - state => { - tracing::warn!( - target: LOG_TARGET, - ?peer, - ?connection_id, - ?state, - "open failure but `PeerState` is not `Opening`", - ); - context.state = state; - - debug_assert!(false); - Err(Error::InvalidState) - }, - } - } - - /// Poll next event from [`crate::transport::manager::TransportManager`]. - pub async fn next(&mut self) -> Option { - loop { - tokio::select! { - event = self.event_rx.recv() => match event? { - TransportManagerEvent::ConnectionClosed { - peer, - connection: connection_id, - } => match self.on_connection_closed(peer, connection_id) { - Ok(None) => {} - Ok(Some(event)) => return Some(event), - Err(error) => tracing::error!( - target: LOG_TARGET, - ?error, - "failed to handle closed connection", - ), - } - }, - command = self.cmd_rx.recv() => match command? { - InnerTransportManagerCommand::DialPeer { peer } => { - if let Err(error) = self.dial(peer).await { - tracing::debug!(target: LOG_TARGET, ?peer, ?error, "failed to dial peer") - } - } - InnerTransportManagerCommand::DialAddress { address } => { - if let Err(error) = self.dial_address(address).await { - tracing::debug!(target: LOG_TARGET, ?error, "failed to dial peer") - } - } - }, - event = self.transports.next() => { - let (transport, event) = event?; - - match event { - TransportEvent::DialFailure { connection_id, address, error } => { - tracing::debug!( - target: LOG_TARGET, - ?connection_id, - ?address, - ?error, - "failed to dial peer", - ); - - if let Ok(()) = self.on_dial_failure(connection_id) { - match address.iter().last() { - Some(Protocol::P2p(hash)) => match PeerId::from_multihash(hash) { - Ok(peer) => { - tracing::trace!( - target: LOG_TARGET, - ?connection_id, - ?error, - ?address, - num_protocols = self.protocols.len(), - "dial failure, notify protocols", - ); - - for (protocol, context) in &self.protocols { - tracing::trace!( - target: LOG_TARGET, - ?connection_id, - ?error, - ?address, - ?protocol, - "dial failure, notify protocol", - ); - match context.tx.try_send(InnerTransportEvent::DialFailure { - peer, - address: address.clone(), - }) { - Ok(()) => {} - Err(_) => { - tracing::trace!( - target: LOG_TARGET, - ?connection_id, - ?error, - ?address, - ?protocol, - "dial failure, channel to protocol clogged, use await", - ); - let _ = context - .tx - .send(InnerTransportEvent::DialFailure { - peer, - address: address.clone(), - }) - .await; - } - } - } - - tracing::trace!( - target: LOG_TARGET, - ?connection_id, - ?error, - ?address, - "all protocols notified", - ); - } - Err(error) => { - tracing::warn!( - target: LOG_TARGET, - ?address, - ?connection_id, - ?error, - "failed to parse `PeerId` from `Multiaddr`", - ); - debug_assert!(false); - } - }, - _ => { - tracing::warn!(target: LOG_TARGET, ?address, ?connection_id, "address doesn't contain `PeerId`"); - debug_assert!(false); - } - } - - return Some(TransportEvent::DialFailure { - connection_id, - address, - error, - }) - } - } - TransportEvent::ConnectionEstablished { peer, endpoint } => { - match self.on_connection_established(peer, &endpoint) { - Err(error) => { - tracing::debug!( - target: LOG_TARGET, - ?peer, - ?endpoint, - ?error, - "failed to handle established connection", - ); - - let _ = self - .transports - .get_mut(&transport) - .expect("transport to exist") - .reject(endpoint.connection_id()); - } - Ok(ConnectionEstablishedResult::Accept) => { - tracing::trace!( - target: LOG_TARGET, - ?peer, - ?endpoint, - "accept connection", - ); - - let _ = self - .transports - .get_mut(&transport) - .expect("transport to exist") - .accept(endpoint.connection_id()); - - return Some(TransportEvent::ConnectionEstablished { - peer, - endpoint: endpoint, - }); - } - Ok(ConnectionEstablishedResult::Reject) => { - tracing::trace!( - target: LOG_TARGET, - ?peer, - ?endpoint, - "reject connection", - ); - - let _ = self - .transports - .get_mut(&transport) - .expect("transport to exist") - .reject(endpoint.connection_id()); - } - } - } - TransportEvent::ConnectionOpened { connection_id, address } => { - if let Err(error) = self.on_connection_opened(transport, connection_id, address) { - tracing::debug!( - target: LOG_TARGET, - ?connection_id, - ?error, - "failed to handle opened connection", - ); - } - } - TransportEvent::OpenFailure { connection_id } => { - match self.on_open_failure(transport, connection_id) { - Err(error) => tracing::debug!( - target: LOG_TARGET, - ?connection_id, - ?error, - "failed to handle opened connection", - ), - Ok(Some(peer)) => { - tracing::trace!( - target: LOG_TARGET, - ?peer, - ?connection_id, - num_protocols = self.protocols.len(), - "inform protocols about open failure", - ); - - for (protocol, context) in &self.protocols { - let _ = match context - .tx - .try_send(InnerTransportEvent::DialFailure { - peer, - address: Multiaddr::empty(), - }) { - Ok(_) => Ok(()), - Err(_) => { - tracing::trace!( - target: LOG_TARGET, - ?peer, - %protocol, - ?connection_id, - "call to protocol would, block try sending in a blocking way", - ); - - context - .tx - .send(InnerTransportEvent::DialFailure { - peer, - address: Multiaddr::empty(), - }) - .await - } - }; - } - - return Some(TransportEvent::DialFailure { - connection_id, - address: Multiaddr::empty(), - error: Error::Unknown, - }) - } - Ok(None) => {} - } - } - _ => panic!("event not supported"), - } - }, - } - } - } + /// Create new [`crate::transport::manager::TransportManager`]. + // TODO: don't return handle here + pub fn new( + keypair: Keypair, + supported_transports: HashSet, + bandwidth_sink: BandwidthSink, + max_parallel_dials: usize, + ) -> (Self, TransportManagerHandle) { + let local_peer_id = PeerId::from_public_key(&keypair.public().into()); + let peers = Arc::new(RwLock::new(HashMap::new())); + let (cmd_tx, cmd_rx) = channel(256); + let (event_tx, event_rx) = channel(256); + let listen_addresses = Arc::new(RwLock::new(HashSet::new())); + let handle = TransportManagerHandle::new( + local_peer_id, + peers.clone(), + cmd_tx, + supported_transports, + Arc::clone(&listen_addresses), + ); + + ( + Self { + peers, + cmd_rx, + keypair, + event_tx, + event_rx, + local_peer_id, + bandwidth_sink, + listen_addresses, + max_parallel_dials, + protocols: HashMap::new(), + transports: TransportContext::new(), + protocol_names: HashSet::new(), + transport_manager_handle: handle.clone(), + pending_connections: HashMap::new(), + next_substream_id: Arc::new(AtomicUsize::new(0usize)), + next_connection_id: Arc::new(AtomicUsize::new(0usize)), + }, + handle, + ) + } + + /// Get iterator to installed protocols. + pub fn protocols(&self) -> impl Iterator { + self.protocols.keys() + } + + /// Get iterator to installed transports + pub fn installed_transports(&self) -> impl Iterator { + self.transports.keys() + } + + /// Get next connection ID. + fn next_connection_id(&mut self) -> ConnectionId { + let connection_id = self.next_connection_id.fetch_add(1usize, Ordering::Relaxed); + + ConnectionId::from(connection_id) + } + + /// Register protocol to the [`crate::transport::manager::TransportManager`]. + /// + /// This allocates new context for the protocol and returns a handle + /// which the protocol can use the interact with the transport subsystem. + pub fn register_protocol( + &mut self, + protocol: ProtocolName, + fallback_names: Vec, + codec: ProtocolCodec, + ) -> TransportService { + assert!(!self.protocol_names.contains(&protocol)); + + for fallback in &fallback_names { + if self.protocol_names.contains(fallback) { + panic!("duplicate fallback protocol given: {fallback:?}"); + } + } + + let (service, sender) = TransportService::new( + self.local_peer_id, + protocol.clone(), + fallback_names.clone(), + self.next_substream_id.clone(), + self.transport_manager_handle.clone(), + ); + + self.protocols.insert( + protocol.clone(), + ProtocolContext::new(codec, sender, fallback_names.clone()), + ); + self.protocol_names.insert(protocol); + self.protocol_names.extend(fallback_names); + + service + } + + /// Acquire `TransportHandle`. + pub fn transport_handle(&self, executor: Arc) -> TransportHandle { + TransportHandle { + tx: self.event_tx.clone(), + executor, + keypair: self.keypair.clone(), + protocols: self.protocols.clone(), + bandwidth_sink: self.bandwidth_sink.clone(), + protocol_names: self.protocol_names.iter().cloned().collect(), + next_substream_id: self.next_substream_id.clone(), + next_connection_id: self.next_connection_id.clone(), + } + } + + /// Register transport to `TransportManager`. + pub(crate) fn register_transport( + &mut self, + name: SupportedTransport, + transport: Box>, + ) { + tracing::debug!(target: LOG_TARGET, transport = ?name, "register transport"); + + self.transports.register_transport(name, transport); + self.transport_manager_handle.register_transport(name); + } + + /// Register local listen address. + pub fn register_listen_address(&mut self, address: Multiaddr) { + assert!(!address.iter().any(|protocol| std::matches!(protocol, Protocol::P2p(_)))); + + let mut listen_addresses = self.listen_addresses.write(); + + listen_addresses.insert(address.clone()); + listen_addresses.insert(address.with(Protocol::P2p( + Multihash::from_bytes(&self.local_peer_id.to_bytes()).unwrap(), + ))); + } + + /// Add one or more known addresses for `peer`. + pub fn add_known_address( + &mut self, + peer: PeerId, + address: impl Iterator, + ) -> usize { + self.transport_manager_handle.add_known_address(&peer, address) + } + + /// Dial peer using `PeerId`. + /// + /// Returns an error if the peer is unknown or the peer is already connected. + pub async fn dial(&mut self, peer: PeerId) -> crate::Result<()> { + if peer == self.local_peer_id { + return Err(Error::TriedToDialSelf); + } + let mut peers = self.peers.write(); + + // if the peer is disconnected, return its context + // + // otherwise set the state back what it was and return dial status to caller + let PeerContext { + state, + secondary_connection, + mut addresses, + } = match peers.remove(&peer) { + None => return Err(Error::PeerDoesntExist(peer)), + Some( + context @ PeerContext { + state: PeerState::Connected { .. }, + .. + }, + ) => { + peers.insert(peer, context); + return Err(Error::AlreadyConnected); + } + Some( + context @ PeerContext { + state: PeerState::Dialing { .. } | PeerState::Opening { .. }, + .. + }, + ) => { + peers.insert(peer, context); + return Ok(()); + } + Some(context) => context, + }; + + if let PeerState::Disconnected { + dial_record: Some(_), + } = &state + { + tracing::debug!( + target: LOG_TARGET, + ?peer, + "peer is aready being dialed", + ); + + peers.insert( + peer, + PeerContext { + state, + secondary_connection, + addresses, + }, + ); + + return Ok(()); + } + + let mut records: HashMap<_, _> = addresses + .take(self.max_parallel_dials) + .into_iter() + .map(|record| (record.address().clone(), record)) + .collect(); + + if records.is_empty() { + return Err(Error::NoAddressAvailable(peer)); + } + + for (_, record) in &records { + if self.listen_addresses.read().contains(record.as_ref()) { + tracing::warn!( + target: LOG_TARGET, + ?peer, + ?record, + "tried to dial self", + ); + + debug_assert!(false); + return Err(Error::TriedToDialSelf); + } + } + + // set connection id for the address record and put peer into `Opening` state + let connection_id = + ConnectionId::from(self.next_connection_id.fetch_add(1usize, Ordering::Relaxed)); + + tracing::debug!( + target: LOG_TARGET, + ?connection_id, + addresses = ?records, + "dial remote peer", + ); + + let mut transports = HashSet::new(); + let mut websocket = Vec::new(); + let mut quic = Vec::new(); + let mut tcp = Vec::new(); + + for (address, record) in &mut records { + record.set_connection_id(connection_id); + + let mut iter = address.iter(); + match iter.find(|protocol| std::matches!(protocol, Protocol::QuicV1)) { + Some(_) => { + quic.push(address.clone()); + transports.insert(SupportedTransport::Quic); + } + _ => match address + .iter() + .find(|protocol| std::matches!(protocol, Protocol::Ws(_) | Protocol::Wss(_))) + { + Some(_) => { + websocket.push(address.clone()); + transports.insert(SupportedTransport::WebSocket); + } + None => { + tcp.push(address.clone()); + transports.insert(SupportedTransport::Tcp); + } + }, + } + } + + peers.insert( + peer, + PeerContext { + state: PeerState::Opening { + records, + connection_id, + transports, + }, + secondary_connection, + addresses, + }, + ); + + if !tcp.is_empty() { + self.transports + .get_mut(&SupportedTransport::Tcp) + .expect("transport to be supported") + .open(connection_id, tcp)?; + } + + if !quic.is_empty() { + self.transports + .get_mut(&SupportedTransport::Quic) + .expect("transport to be supported") + .open(connection_id, quic)?; + } + + if !websocket.is_empty() { + self.transports + .get_mut(&SupportedTransport::WebSocket) + .expect("transport to be supported") + .open(connection_id, websocket)?; + } + + self.pending_connections.insert(connection_id, peer); + + Ok(()) + } + + /// Dial peer using `Multiaddr`. + /// + /// Returns an error if address it not valid. + pub async fn dial_address(&mut self, address: Multiaddr) -> crate::Result<()> { + let mut record = AddressRecord::from_multiaddr(address) + .ok_or(Error::AddressError(AddressError::PeerIdMissing))?; + + if self.listen_addresses.read().contains(record.as_ref()) { + return Err(Error::TriedToDialSelf); + } + + tracing::debug!(target: LOG_TARGET, address = ?record.address(), "dial remote peer over address"); + + let mut protocol_stack = record.as_ref().iter(); + match protocol_stack + .next() + .ok_or_else(|| Error::TransportNotSupported(record.address().clone()))? + { + Protocol::Ip4(_) | Protocol::Ip6(_) => {} + Protocol::Dns(_) | Protocol::Dns4(_) | Protocol::Dns6(_) => {} + transport => { + tracing::error!( + target: LOG_TARGET, + ?transport, + "invalid transport, expected `ip4`/`ip6`" + ); + return Err(Error::TransportNotSupported(record.address().clone())); + } + }; + + let supported_transport = match protocol_stack + .next() + .ok_or_else(|| Error::TransportNotSupported(record.address().clone()))? + { + Protocol::Tcp(_) => match protocol_stack.next() { + Some(Protocol::Ws(_)) | Some(Protocol::Wss(_)) => SupportedTransport::WebSocket, + Some(Protocol::P2p(_)) => SupportedTransport::Tcp, + _ => return Err(Error::TransportNotSupported(record.address().clone())), + }, + Protocol::Udp(_) => match protocol_stack + .next() + .ok_or_else(|| Error::TransportNotSupported(record.address().clone()))? + { + Protocol::QuicV1 => SupportedTransport::Quic, + _ => { + tracing::debug!(target: LOG_TARGET, address = ?record.address(), "expected `quic-v1`"); + return Err(Error::TransportNotSupported(record.address().clone())); + } + }, + protocol => { + tracing::error!( + target: LOG_TARGET, + ?protocol, + "invalid protocol, expected `tcp`" + ); + + return Err(Error::TransportNotSupported(record.address().clone())); + } + }; + + // when constructing `AddressRecord`, `PeerId` was verified to be part of the address + let remote_peer_id = + PeerId::try_from_multiaddr(record.address()).expect("`PeerId` to exist"); + + // set connection id for the address record and put peer into `Dialing` state + let connection_id = self.next_connection_id(); + record.set_connection_id(connection_id); + + { + let mut peers = self.peers.write(); + + match peers.get_mut(&remote_peer_id) { + None => { + drop(peers); + self.peers.write().insert( + remote_peer_id, + PeerContext { + state: PeerState::Dialing { + record: record.clone(), + }, + addresses: AddressStore::new(), + secondary_connection: None, + }, + ); + } + Some(PeerContext { + state: + PeerState::Dialing { .. } + | PeerState::Connected { .. } + | PeerState::Opening { .. }, + .. + }) => return Ok(()), + Some(PeerContext { ref mut state, .. }) => { + // TODO: verify that the address is not in `addresses` already + // addresses.insert(address.clone()); + *state = PeerState::Dialing { + record: record.clone(), + }; + } + } + } + + self.transports + .get_mut(&supported_transport) + .ok_or(Error::TransportNotSupported(record.address().clone()))? + .dial(connection_id, record.address().clone())?; + self.pending_connections.insert(connection_id, remote_peer_id); + + Ok(()) + } + + /// Handle dial failure. + fn on_dial_failure(&mut self, connection_id: ConnectionId) -> crate::Result<()> { + let peer = self.pending_connections.remove(&connection_id).ok_or_else(|| { + tracing::error!( + target: LOG_TARGET, + ?connection_id, + "dial failed for a connection that doesn't exist", + ); + debug_assert!(false); + + Error::InvalidState + })?; + + let mut peers = self.peers.write(); + let context = peers.get_mut(&peer).ok_or_else(|| { + tracing::error!( + target: LOG_TARGET, + ?peer, + ?connection_id, + "dial failed for a peer that doens't exist", + ); + debug_assert!(false); + + Error::InvalidState + })?; + + match std::mem::replace( + &mut context.state, + PeerState::Disconnected { dial_record: None }, + ) { + PeerState::Dialing { ref mut record } => { + debug_assert_eq!(record.connection_id(), &Some(connection_id)); + + record.update_score(SCORE_DIAL_FAILURE); + context.addresses.insert(record.clone()); + + context.state = PeerState::Disconnected { dial_record: None }; + Ok(()) + } + PeerState::Opening { .. } => { + todo!(); + } + PeerState::Connected { + record, + dial_record: Some(mut dial_record), + } => { + dial_record.update_score(SCORE_DIAL_FAILURE); + context.addresses.insert(dial_record); + + context.state = PeerState::Connected { + record, + dial_record: None, + }; + Ok(()) + } + PeerState::Disconnected { + dial_record: Some(mut dial_record), + } => { + tracing::debug!( + target: LOG_TARGET, + ?connection_id, + ?dial_record, + "dial failed for a disconnected peer", + ); + + dial_record.update_score(SCORE_DIAL_FAILURE); + context.addresses.insert(dial_record); + + Ok(()) + } + state => { + tracing::warn!( + target: LOG_TARGET, + ?peer, + ?connection_id, + ?state, + "invalid state for dial failure", + ); + context.state = state; + + debug_assert!(false); + Ok(()) + } + } + } + + /// Handle closed connection. + /// + /// Returns `bool` which indicates whether the event should be returned or not. + fn on_connection_closed( + &mut self, + peer: PeerId, + connection_id: ConnectionId, + ) -> crate::Result> { + let mut peers = self.peers.write(); + let Some(context) = peers.get_mut(&peer) else { + tracing::warn!( + target: LOG_TARGET, + ?peer, + ?connection_id, + "cannot handle closed connection: peer doesn't exist", + ); + debug_assert!(false); + return Err(Error::PeerDoesntExist(peer)); + }; + + tracing::trace!( + target: LOG_TARGET, + ?peer, + ?connection_id, + "connection closed", + ); + + match std::mem::replace( + &mut context.state, + PeerState::Disconnected { dial_record: None }, + ) { + PeerState::Connected { + record, + dial_record: actual_dial_record, + } => match record.connection_id() == &Some(connection_id) { + // primary connection was closed + // + // if secondary connection exists, switch to using it while keeping peer in + // `Connected` state and if there's only one connection, set peer + // state to `Disconnected` + true => match context.secondary_connection.take() { + None => { + context.addresses.insert(record); + context.state = PeerState::Disconnected { + dial_record: actual_dial_record, + }; + + return Ok(Some(TransportEvent::ConnectionClosed { + peer, + connection_id, + })); + } + Some(secondary_connection) => { + context.addresses.insert(record); + context.state = PeerState::Connected { + record: secondary_connection, + dial_record: actual_dial_record, + }; + + return Ok(None); + } + }, + // secondary connection was closed + false => match context.secondary_connection.take() { + Some(secondary_connection) => { + if secondary_connection.connection_id() != &Some(connection_id) { + tracing::debug!( + target: LOG_TARGET, + ?peer, + ?connection_id, + "unknown connection was closed, potentially ignored tertiary connection", + ); + + context.secondary_connection = Some(secondary_connection); + context.state = PeerState::Connected { + record, + dial_record: actual_dial_record, + }; + + return Ok(None); + } + + tracing::trace!( + target: LOG_TARGET, + ?peer, + ?connection_id, + "secondary connection closed", + ); + + context.addresses.insert(secondary_connection); + context.state = PeerState::Connected { + record, + dial_record: actual_dial_record, + }; + return Ok(None); + } + None => { + tracing::warn!( + target: LOG_TARGET, + ?peer, + ?connection_id, + "non-primary connection was closed but secondary connection doesn't exist", + ); + + debug_assert!(false); + return Err(Error::InvalidState); + } + }, + }, + PeerState::Disconnected { dial_record } => match context.secondary_connection.take() { + Some(record) => { + tracing::warn!( + target: LOG_TARGET, + ?peer, + ?connection_id, + ?record, + ?dial_record, + "peer is disconnected but secondary connection exists", + ); + + debug_assert!(false); + context.state = PeerState::Disconnected { dial_record }; + return Err(Error::InvalidState); + } + None => { + context.state = PeerState::Disconnected { dial_record }; + + Ok(Some(TransportEvent::ConnectionClosed { + peer, + connection_id, + })) + } + }, + state => { + tracing::warn!(target: LOG_TARGET, ?peer, ?connection_id, ?state, "invalid state for a closed connection"); + debug_assert!(false); + return Err(Error::InvalidState); + } + } + } + + fn on_connection_established( + &mut self, + peer: PeerId, + endpoint: &Endpoint, + ) -> crate::Result { + if let Some(dialed_peer) = self.pending_connections.remove(&endpoint.connection_id()) { + if dialed_peer != peer { + tracing::warn!( + target: LOG_TARGET, + ?dialed_peer, + ?peer, + ?endpoint, + "peer ids do not match but transport was supposed to reject connection" + ); + debug_assert!(false); + return Err(Error::InvalidState); + } + }; + + let mut peers = self.peers.write(); + match peers.get_mut(&peer) { + Some(context) => match context.state { + PeerState::Connected { .. } => match context.secondary_connection { + Some(_) => { + tracing::debug!( + target: LOG_TARGET, + ?peer, + connection_id = ?endpoint.connection_id(), + ?endpoint, + "secondary connection already exists, ignoring connection", + ); + + // insert address into the store only if we're the dialer + // + // if we're the listener, remote might have dialed with an ephemeral port + // which it might not be listening, making this address useless + if endpoint.is_listener() { + context.addresses.insert(AddressRecord::new( + &peer, + endpoint.address().clone(), + SCORE_DIAL_SUCCESS, + None, + )) + } + + return Ok(ConnectionEstablishedResult::Reject); + } + None => { + tracing::debug!( + target: LOG_TARGET, + ?peer, + connection_id = ?endpoint.connection_id(), + address = ?endpoint.address(), + "secondary connection", + ); + + context.secondary_connection = Some(AddressRecord::new( + &peer, + endpoint.address().clone(), + SCORE_DIAL_SUCCESS, + Some(endpoint.connection_id()), + )); + } + }, + PeerState::Dialing { ref record, .. } => { + match record.connection_id() == &Some(endpoint.connection_id()) { + true => { + tracing::trace!( + target: LOG_TARGET, + ?peer, + connection_id = ?endpoint.connection_id(), + ?endpoint, + ?record, + "connection opened to remote", + ); + + context.state = PeerState::Connected { + record: record.clone(), + dial_record: None, + }; + } + false => { + tracing::trace!( + target: LOG_TARGET, + ?peer, + connection_id = ?endpoint.connection_id(), + ?endpoint, + "connection opened by remote while local node was dialing", + ); + + context.state = PeerState::Connected { + record: AddressRecord::new( + &peer, + endpoint.address().clone(), + SCORE_DIAL_SUCCESS, + Some(endpoint.connection_id()), + ), + dial_record: Some(record.clone()), + }; + } + } + } + PeerState::Opening { + ref mut records, + connection_id, + ref transports, + } => { + debug_assert!(std::matches!(endpoint, &Endpoint::Listener { .. })); + + tracing::trace!( + target: LOG_TARGET, + ?peer, + dial_connection_id = ?connection_id, + dial_records = ?records, + dial_transports = ?transports, + listener_endpoint = ?endpoint, + "inbound connection while opening an outbound connection", + ); + + // cancel all pending dials + transports.iter().for_each(|transport| { + self.transports + .get_mut(transport) + .expect("transport to exist") + .cancel(connection_id); + }); + + // since an inbound connection was removed, the outbound connection can be + // removed from pendind dials + // + // all records have the same `ConnectionId` so it doens't matter which of them + // is used to remove the pending dial + self.pending_connections.remove( + &records + .iter() + .next() + .expect("record to exist") + .1 + .connection_id() + .expect("`ConnectionId` to exist"), + ); + + let record = match records.remove(endpoint.address()) { + Some(mut record) => { + record.update_score(SCORE_DIAL_SUCCESS); + record.set_connection_id(endpoint.connection_id()); + record + } + None => AddressRecord::new( + &peer, + endpoint.address().clone(), + SCORE_DIAL_SUCCESS, + Some(endpoint.connection_id()), + ), + }; + context.addresses.extend(records.iter().map(|(_, record)| record)); + + context.state = PeerState::Connected { + record, + dial_record: None, + }; + } + PeerState::Disconnected { + ref mut dial_record, + } => { + tracing::trace!( + target: LOG_TARGET, + ?peer, + connection_id = ?endpoint.connection_id(), + ?endpoint, + ?dial_record, + "connection opened by remote or delayed dial succeeded", + ); + + let (record, dial_record) = match dial_record.take() { + Some(mut dial_record) => + if dial_record.address() == endpoint.address() { + dial_record.set_connection_id(endpoint.connection_id()); + (dial_record, None) + } else { + ( + AddressRecord::new( + &peer, + endpoint.address().clone(), + SCORE_DIAL_SUCCESS, + Some(endpoint.connection_id()), + ), + Some(dial_record), + ) + }, + None => ( + AddressRecord::new( + &peer, + endpoint.address().clone(), + SCORE_DIAL_SUCCESS, + Some(endpoint.connection_id()), + ), + None, + ), + }; + + context.state = PeerState::Connected { + record, + dial_record, + }; + } + }, + None => { + peers.insert( + peer, + PeerContext { + state: PeerState::Connected { + record: AddressRecord::new( + &peer, + endpoint.address().clone(), + SCORE_DIAL_SUCCESS, + Some(endpoint.connection_id()), + ), + dial_record: None, + }, + addresses: AddressStore::new(), + secondary_connection: None, + }, + ); + } + } + + Ok(ConnectionEstablishedResult::Accept) + } + + fn on_connection_opened( + &mut self, + transport: SupportedTransport, + connection_id: ConnectionId, + address: Multiaddr, + ) -> crate::Result<()> { + let Some(peer) = self.pending_connections.remove(&connection_id) else { + tracing::warn!( + target: LOG_TARGET, + ?connection_id, + ?transport, + ?address, + "connection opened but dial record doesn't exist", + ); + + debug_assert!(false); + return Err(Error::InvalidState); + }; + + let mut peers = self.peers.write(); + let context = peers.get_mut(&peer).ok_or_else(|| { + tracing::warn!( + target: LOG_TARGET, + ?peer, + ?connection_id, + "connection opened but peer doesn't exist", + ); + + debug_assert!(false); + Error::InvalidState + })?; + + match std::mem::replace( + &mut context.state, + PeerState::Disconnected { dial_record: None }, + ) { + PeerState::Opening { + mut records, + connection_id, + transports, + } => { + tracing::trace!( + target: LOG_TARGET, + ?peer, + ?connection_id, + ?address, + ?transport, + "connection opened to peer", + ); + + // cancel open attempts for other transports as connection already exists + for transport in transports.iter() { + let _ = self + .transports + .get_mut(&transport) + .expect("transport to exist") + .cancel(connection_id); + } + + // set peer state to `Dialing` to signal that the connection is fully opening + // + // set the succeeded `AddressRecord` as the one that is used for dialing and move + // all other address records back to `AddressStore`. and ask + // transport to negotiate the + let mut dial_record = records.remove(&address).expect("address to exist"); + dial_record.update_score(SCORE_DIAL_SUCCESS); + + // negotiate the connection + match self + .transports + .get_mut(&transport) + .expect("transport to exist") + .negotiate(connection_id) + { + Ok(()) => { + tracing::trace!( + target: LOG_TARGET, + ?peer, + ?connection_id, + ?dial_record, + ?transport, + "negotiation started" + ); + + self.pending_connections.insert(connection_id, peer); + + context.state = PeerState::Dialing { + record: dial_record, + }; + + for (_, record) in records { + context.addresses.insert(record); + } + + Ok(()) + } + Err(error) => { + tracing::warn!( + target: LOG_TARGET, + ?peer, + ?connection_id, + ?error, + "failed to negotiate connection", + ); + context.state = PeerState::Disconnected { dial_record: None }; + + debug_assert!(false); + Err(Error::InvalidState) + } + } + } + state => { + tracing::warn!( + target: LOG_TARGET, + ?peer, + ?connection_id, + ?state, + "connection opened but `PeerState` is not `Opening`", + ); + context.state = state; + + debug_assert!(false); + Err(Error::InvalidState) + } + } + } + + /// Handle open failure for dialing attempt for `transport` + fn on_open_failure( + &mut self, + transport: SupportedTransport, + connection_id: ConnectionId, + ) -> crate::Result> { + let Some(peer) = self.pending_connections.remove(&connection_id) else { + tracing::warn!( + target: LOG_TARGET, + ?connection_id, + "open failure but dial record doesn't exist", + ); + + debug_assert!(false); + return Err(Error::InvalidState); + }; + + let mut peers = self.peers.write(); + let context = peers.get_mut(&peer).ok_or_else(|| { + tracing::warn!( + target: LOG_TARGET, + ?peer, + ?connection_id, + "open failure but peer doesn't exist", + ); + + debug_assert!(false); + Error::InvalidState + })?; + + match std::mem::replace( + &mut context.state, + PeerState::Disconnected { dial_record: None }, + ) { + PeerState::Opening { + records, + connection_id, + mut transports, + } => { + tracing::trace!( + target: LOG_TARGET, + ?peer, + ?connection_id, + ?transport, + "open failure for peer", + ); + transports.remove(&transport); + + if transports.is_empty() { + for (_, mut record) in records { + record.update_score(SCORE_DIAL_FAILURE); + context.addresses.insert(record); + } + + tracing::trace!( + target: LOG_TARGET, + ?peer, + ?connection_id, + "open failure for last transport", + ); + + return Ok(Some(peer)); + } + + self.pending_connections.insert(connection_id, peer); + context.state = PeerState::Opening { + records, + connection_id, + transports, + }; + + Ok(None) + } + state => { + tracing::warn!( + target: LOG_TARGET, + ?peer, + ?connection_id, + ?state, + "open failure but `PeerState` is not `Opening`", + ); + context.state = state; + + debug_assert!(false); + Err(Error::InvalidState) + } + } + } + + /// Poll next event from [`crate::transport::manager::TransportManager`]. + pub async fn next(&mut self) -> Option { + loop { + tokio::select! { + event = self.event_rx.recv() => match event? { + TransportManagerEvent::ConnectionClosed { + peer, + connection: connection_id, + } => match self.on_connection_closed(peer, connection_id) { + Ok(None) => {} + Ok(Some(event)) => return Some(event), + Err(error) => tracing::error!( + target: LOG_TARGET, + ?error, + "failed to handle closed connection", + ), + } + }, + command = self.cmd_rx.recv() => match command? { + InnerTransportManagerCommand::DialPeer { peer } => { + if let Err(error) = self.dial(peer).await { + tracing::debug!(target: LOG_TARGET, ?peer, ?error, "failed to dial peer") + } + } + InnerTransportManagerCommand::DialAddress { address } => { + if let Err(error) = self.dial_address(address).await { + tracing::debug!(target: LOG_TARGET, ?error, "failed to dial peer") + } + } + }, + event = self.transports.next() => { + let (transport, event) = event?; + + match event { + TransportEvent::DialFailure { connection_id, address, error } => { + tracing::debug!( + target: LOG_TARGET, + ?connection_id, + ?address, + ?error, + "failed to dial peer", + ); + + if let Ok(()) = self.on_dial_failure(connection_id) { + match address.iter().last() { + Some(Protocol::P2p(hash)) => match PeerId::from_multihash(hash) { + Ok(peer) => { + tracing::trace!( + target: LOG_TARGET, + ?connection_id, + ?error, + ?address, + num_protocols = self.protocols.len(), + "dial failure, notify protocols", + ); + + for (protocol, context) in &self.protocols { + tracing::trace!( + target: LOG_TARGET, + ?connection_id, + ?error, + ?address, + ?protocol, + "dial failure, notify protocol", + ); + match context.tx.try_send(InnerTransportEvent::DialFailure { + peer, + address: address.clone(), + }) { + Ok(()) => {} + Err(_) => { + tracing::trace!( + target: LOG_TARGET, + ?connection_id, + ?error, + ?address, + ?protocol, + "dial failure, channel to protocol clogged, use await", + ); + let _ = context + .tx + .send(InnerTransportEvent::DialFailure { + peer, + address: address.clone(), + }) + .await; + } + } + } + + tracing::trace!( + target: LOG_TARGET, + ?connection_id, + ?error, + ?address, + "all protocols notified", + ); + } + Err(error) => { + tracing::warn!( + target: LOG_TARGET, + ?address, + ?connection_id, + ?error, + "failed to parse `PeerId` from `Multiaddr`", + ); + debug_assert!(false); + } + }, + _ => { + tracing::warn!(target: LOG_TARGET, ?address, ?connection_id, "address doesn't contain `PeerId`"); + debug_assert!(false); + } + } + + return Some(TransportEvent::DialFailure { + connection_id, + address, + error, + }) + } + } + TransportEvent::ConnectionEstablished { peer, endpoint } => { + match self.on_connection_established(peer, &endpoint) { + Err(error) => { + tracing::debug!( + target: LOG_TARGET, + ?peer, + ?endpoint, + ?error, + "failed to handle established connection", + ); + + let _ = self + .transports + .get_mut(&transport) + .expect("transport to exist") + .reject(endpoint.connection_id()); + } + Ok(ConnectionEstablishedResult::Accept) => { + tracing::trace!( + target: LOG_TARGET, + ?peer, + ?endpoint, + "accept connection", + ); + + let _ = self + .transports + .get_mut(&transport) + .expect("transport to exist") + .accept(endpoint.connection_id()); + + return Some(TransportEvent::ConnectionEstablished { + peer, + endpoint: endpoint, + }); + } + Ok(ConnectionEstablishedResult::Reject) => { + tracing::trace!( + target: LOG_TARGET, + ?peer, + ?endpoint, + "reject connection", + ); + + let _ = self + .transports + .get_mut(&transport) + .expect("transport to exist") + .reject(endpoint.connection_id()); + } + } + } + TransportEvent::ConnectionOpened { connection_id, address } => { + if let Err(error) = self.on_connection_opened(transport, connection_id, address) { + tracing::debug!( + target: LOG_TARGET, + ?connection_id, + ?error, + "failed to handle opened connection", + ); + } + } + TransportEvent::OpenFailure { connection_id } => { + match self.on_open_failure(transport, connection_id) { + Err(error) => tracing::debug!( + target: LOG_TARGET, + ?connection_id, + ?error, + "failed to handle opened connection", + ), + Ok(Some(peer)) => { + tracing::trace!( + target: LOG_TARGET, + ?peer, + ?connection_id, + num_protocols = self.protocols.len(), + "inform protocols about open failure", + ); + + for (protocol, context) in &self.protocols { + let _ = match context + .tx + .try_send(InnerTransportEvent::DialFailure { + peer, + address: Multiaddr::empty(), + }) { + Ok(_) => Ok(()), + Err(_) => { + tracing::trace!( + target: LOG_TARGET, + ?peer, + %protocol, + ?connection_id, + "call to protocol would, block try sending in a blocking way", + ); + + context + .tx + .send(InnerTransportEvent::DialFailure { + peer, + address: Multiaddr::empty(), + }) + .await + } + }; + } + + return Some(TransportEvent::DialFailure { + connection_id, + address: Multiaddr::empty(), + error: Error::Unknown, + }) + } + Ok(None) => {} + } + } + _ => panic!("event not supported"), + } + }, + } + } + } } #[cfg(test)] mod tests { - use super::*; - use crate::{ - crypto::ed25519::Keypair, executor::DefaultExecutor, transport::dummy::DummyTransport, - }; - use std::{ - net::{Ipv4Addr, Ipv6Addr}, - sync::Arc, - }; - - #[test] - #[should_panic] - #[cfg(debug_assertions)] - fn duplicate_protocol() { - let sink = BandwidthSink::new(); - let (mut manager, _handle) = - TransportManager::new(Keypair::generate(), HashSet::new(), sink, 8usize); - - manager.register_protocol( - ProtocolName::from("/notif/1"), - Vec::new(), - ProtocolCodec::UnsignedVarint(None), - ); - manager.register_protocol( - ProtocolName::from("/notif/1"), - Vec::new(), - ProtocolCodec::UnsignedVarint(None), - ); - } - - #[test] - #[should_panic] - #[cfg(debug_assertions)] - fn fallback_protocol_as_duplicate_main_protocol() { - let sink = BandwidthSink::new(); - let (mut manager, _handle) = - TransportManager::new(Keypair::generate(), HashSet::new(), sink, 8usize); - - manager.register_protocol( - ProtocolName::from("/notif/1"), - Vec::new(), - ProtocolCodec::UnsignedVarint(None), - ); - manager.register_protocol( - ProtocolName::from("/notif/2"), - vec![ProtocolName::from("/notif/2/new"), ProtocolName::from("/notif/1")], - ProtocolCodec::UnsignedVarint(None), - ); - } - - #[test] - #[should_panic] - #[cfg(debug_assertions)] - fn duplicate_fallback_protocol() { - let sink = BandwidthSink::new(); - let (mut manager, _handle) = - TransportManager::new(Keypair::generate(), HashSet::new(), sink, 8usize); - - manager.register_protocol( - ProtocolName::from("/notif/1"), - vec![ProtocolName::from("/notif/1/new"), ProtocolName::from("/notif/1")], - ProtocolCodec::UnsignedVarint(None), - ); - manager.register_protocol( - ProtocolName::from("/notif/2"), - vec![ProtocolName::from("/notif/2/new"), ProtocolName::from("/notif/1/new")], - ProtocolCodec::UnsignedVarint(None), - ); - } - - #[test] - #[should_panic] - #[cfg(debug_assertions)] - fn duplicate_transport() { - let sink = BandwidthSink::new(); - let (mut manager, _handle) = - TransportManager::new(Keypair::generate(), HashSet::new(), sink, 8usize); - - manager.register_transport(SupportedTransport::Tcp, Box::new(DummyTransport::new())); - manager.register_transport(SupportedTransport::Tcp, Box::new(DummyTransport::new())); - } - - #[tokio::test] - async fn tried_to_self_using_peer_id() { - let keypair = Keypair::generate(); - let local_peer_id = PeerId::from_public_key(&keypair.public().into()); - let sink = BandwidthSink::new(); - let (mut manager, _handle) = TransportManager::new(keypair, HashSet::new(), sink, 8usize); - - assert!(manager.dial(local_peer_id).await.is_err()); - } - - #[tokio::test] - async fn try_to_dial_over_disabled_transport() { - let (mut manager, _handle) = TransportManager::new( - Keypair::generate(), - HashSet::new(), - BandwidthSink::new(), - 8usize, - ); - let _handle = manager.transport_handle(Arc::new(DefaultExecutor {})); - manager.register_transport(SupportedTransport::Tcp, Box::new(DummyTransport::new())); - - let address = Multiaddr::empty() - .with(Protocol::Ip4(Ipv4Addr::new(127, 0, 0, 1))) - .with(Protocol::Udp(8888)) - .with(Protocol::QuicV1) - .with(Protocol::P2p(Multihash::from_bytes(&PeerId::random().to_bytes()).unwrap())); - - assert!(std::matches!( - manager.dial_address(address).await, - Err(Error::TransportNotSupported(_)) - )); - } - - #[tokio::test] - async fn successful_dial_reported_to_transport_manager() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (mut manager, _handle) = TransportManager::new( - Keypair::generate(), - HashSet::new(), - BandwidthSink::new(), - 8usize, - ); - let peer = PeerId::random(); - let dial_address = Multiaddr::empty() - .with(Protocol::Ip4(Ipv4Addr::new(127, 0, 0, 1))) - .with(Protocol::Tcp(8888)) - .with(Protocol::P2p(Multihash::from_bytes(&peer.to_bytes()).unwrap())); - - let transport = Box::new({ - let mut transport = DummyTransport::new(); - transport.inject_event(TransportEvent::ConnectionEstablished { - peer, - endpoint: Endpoint::dialer(dial_address.clone(), ConnectionId::from(0usize)), - }); - transport - }); - manager.register_transport(SupportedTransport::Tcp, transport); - - assert!(manager.dial_address(dial_address.clone()).await.is_ok()); - assert!(!manager.pending_connections.is_empty()); - - { - let peers = manager.peers.read(); - - match peers.get(&peer) { - Some(PeerContext { state: PeerState::Dialing { .. }, .. }) => {}, - state => panic!("invalid state for peer: {state:?}"), - } - } - - match manager.next().await.unwrap() { - TransportEvent::ConnectionEstablished { - peer: event_peer, - endpoint: event_endpoint, - .. - } => { - assert_eq!(peer, event_peer); - assert_eq!( - event_endpoint, - Endpoint::dialer(dial_address.clone(), ConnectionId::from(0usize)) - ) - }, - event => panic!("invalid event: {event:?}"), - } - } - - #[tokio::test] - async fn try_to_dial_same_peer_twice() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (mut manager, _handle) = TransportManager::new( - Keypair::generate(), - HashSet::new(), - BandwidthSink::new(), - 8usize, - ); - let _handle = manager.transport_handle(Arc::new(DefaultExecutor {})); - manager.register_transport(SupportedTransport::Tcp, Box::new(DummyTransport::new())); - - let peer = PeerId::random(); - let dial_address = Multiaddr::empty() - .with(Protocol::Ip4(Ipv4Addr::new(127, 0, 0, 1))) - .with(Protocol::Tcp(8888)) - .with(Protocol::P2p(Multihash::from_bytes(&peer.to_bytes()).unwrap())); - - assert!(manager.dial_address(dial_address.clone()).await.is_ok()); - assert_eq!(manager.pending_connections.len(), 1); - - assert!(manager.dial_address(dial_address.clone()).await.is_ok()); - assert_eq!(manager.pending_connections.len(), 1); - } - - #[tokio::test] - async fn try_to_dial_same_peer_twice_diffrent_address() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (mut manager, _handle) = TransportManager::new( - Keypair::generate(), - HashSet::new(), - BandwidthSink::new(), - 8usize, - ); - let _handle = manager.transport_handle(Arc::new(DefaultExecutor {})); - manager.register_transport(SupportedTransport::Tcp, Box::new(DummyTransport::new())); - - let peer = PeerId::random(); - - assert!(manager - .dial_address( - Multiaddr::empty() - .with(Protocol::Ip4(Ipv4Addr::new(127, 0, 0, 1))) - .with(Protocol::Tcp(8888)) - .with(Protocol::P2p(Multihash::from_bytes(&peer.to_bytes()).unwrap(),)) - ) - .await - .is_ok()); - assert_eq!(manager.pending_connections.len(), 1); - - assert!(manager - .dial_address( - Multiaddr::empty() - .with(Protocol::Ip6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1))) - .with(Protocol::Tcp(8888)) - .with(Protocol::P2p(Multihash::from_bytes(&peer.to_bytes()).unwrap(),)) - ) - .await - .is_ok()); - assert_eq!(manager.pending_connections.len(), 1); - } - - #[tokio::test] - async fn dial_non_existent_peer() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (mut manager, _handle) = TransportManager::new( - Keypair::generate(), - HashSet::new(), - BandwidthSink::new(), - 8usize, - ); - let _handle = manager.transport_handle(Arc::new(DefaultExecutor {})); - manager.register_transport(SupportedTransport::Tcp, Box::new(DummyTransport::new())); - - assert!(manager.dial(PeerId::random()).await.is_err()); - } - - #[tokio::test] - async fn dial_non_peer_with_no_known_addresses() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (mut manager, _handle) = TransportManager::new( - Keypair::generate(), - HashSet::new(), - BandwidthSink::new(), - 8usize, - ); - let _handle = manager.transport_handle(Arc::new(DefaultExecutor {})); - manager.register_transport(SupportedTransport::Tcp, Box::new(DummyTransport::new())); - - let peer = PeerId::random(); - manager.peers.write().insert( - peer, - PeerContext { - state: PeerState::Disconnected { dial_record: None }, - addresses: AddressStore::new(), - secondary_connection: None, - }, - ); - - assert!(manager.dial(peer).await.is_err()); - } - - #[tokio::test] - async fn check_supported_transport_when_adding_known_address() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (_manager, handle) = TransportManager::new( - Keypair::generate(), - HashSet::from_iter([SupportedTransport::Tcp, SupportedTransport::Quic]), - BandwidthSink::new(), - 8usize, - ); - - // ipv6 - let address = Multiaddr::empty() - .with(Protocol::Ip6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1))) - .with(Protocol::Tcp(8888)) - .with(Protocol::P2p(Multihash::from_bytes(&PeerId::random().to_bytes()).unwrap())); - assert!(handle.supported_transport(&address)); - - // ipv4 - let address = Multiaddr::empty() - .with(Protocol::Ip4(Ipv4Addr::new(127, 0, 0, 1))) - .with(Protocol::Tcp(8888)) - .with(Protocol::P2p(Multihash::from_bytes(&PeerId::random().to_bytes()).unwrap())); - assert!(handle.supported_transport(&address)); - - // quic - let address = Multiaddr::empty() - .with(Protocol::Ip4(Ipv4Addr::new(127, 0, 0, 1))) - .with(Protocol::Udp(8888)) - .with(Protocol::QuicV1) - .with(Protocol::P2p(Multihash::from_bytes(&PeerId::random().to_bytes()).unwrap())); - assert!(handle.supported_transport(&address)); - - // websocket - let address = Multiaddr::empty() - .with(Protocol::Ip4(Ipv4Addr::new(127, 0, 0, 1))) - .with(Protocol::Tcp(8888)) - .with(Protocol::Ws(std::borrow::Cow::Owned("/".to_string()))); - assert!(!handle.supported_transport(&address)); - - // websocket secure - let address = Multiaddr::empty() - .with(Protocol::Ip4(Ipv4Addr::new(127, 0, 0, 1))) - .with(Protocol::Tcp(8888)) - .with(Protocol::Wss(std::borrow::Cow::Owned("/".to_string()))); - assert!(!handle.supported_transport(&address)); - } - - // local node tried to dial a node and it failed but in the mean - // time the remote node dialed local node and that succeeded. - #[tokio::test] - async fn on_dial_failure_already_connected() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (mut manager, _handle) = TransportManager::new( - Keypair::generate(), - HashSet::new(), - BandwidthSink::new(), - 8usize, - ); - let _handle = manager.transport_handle(Arc::new(DefaultExecutor {})); - manager.register_transport(SupportedTransport::Tcp, Box::new(DummyTransport::new())); - - let peer = PeerId::random(); - let dial_address = Multiaddr::empty() - .with(Protocol::Ip4(Ipv4Addr::new(127, 0, 0, 1))) - .with(Protocol::Tcp(8888)) - .with(Protocol::P2p(Multihash::from_bytes(&peer.to_bytes()).unwrap())); - let connect_address = Multiaddr::empty() - .with(Protocol::Ip4(Ipv4Addr::new(192, 168, 1, 173))) - .with(Protocol::Tcp(8888)) - .with(Protocol::P2p(Multihash::from_bytes(&peer.to_bytes()).unwrap())); - assert!(manager.dial_address(dial_address.clone()).await.is_ok()); - assert_eq!(manager.pending_connections.len(), 1); - - match &manager.peers.read().get(&peer).unwrap().state { - PeerState::Dialing { record } => { - assert_eq!(record.address(), &dial_address); - }, - state => panic!("invalid state for peer: {state:?}"), - } - - // remote peer connected to local node from a different address that was dialed - manager - .on_connection_established( - peer, - &Endpoint::dialer(connect_address, ConnectionId::from(1usize)), - ) - .unwrap(); - - // dialing the peer failed - manager.on_dial_failure(ConnectionId::from(0usize)).unwrap(); - - let peers = manager.peers.read(); - let peer = peers.get(&peer).unwrap(); - - match &peer.state { - PeerState::Connected { dial_record, .. } => { - assert!(dial_record.is_none()); - assert!(peer.addresses.contains(&dial_address)); - }, - state => panic!("invalid state: {state:?}"), - } - } - - // local node tried to dial a node and it failed but in the mean - // time the remote node dialed local node and that succeeded. - // - // while the dial was still in progresss, the remote node disconnected after which - // the dial failure was reported. - #[tokio::test] - async fn on_dial_failure_already_connected_and_disconnected() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (mut manager, _handle) = TransportManager::new( - Keypair::generate(), - HashSet::new(), - BandwidthSink::new(), - 8usize, - ); - let _handle = manager.transport_handle(Arc::new(DefaultExecutor {})); - manager.register_transport(SupportedTransport::Tcp, Box::new(DummyTransport::new())); - - let peer = PeerId::random(); - let dial_address = Multiaddr::empty() - .with(Protocol::Ip4(Ipv4Addr::new(127, 0, 0, 1))) - .with(Protocol::Tcp(8888)) - .with(Protocol::P2p(Multihash::from_bytes(&peer.to_bytes()).unwrap())); - let connect_address = Multiaddr::empty() - .with(Protocol::Ip4(Ipv4Addr::new(192, 168, 1, 173))) - .with(Protocol::Tcp(8888)) - .with(Protocol::P2p(Multihash::from_bytes(&peer.to_bytes()).unwrap())); - assert!(manager.dial_address(dial_address.clone()).await.is_ok()); - assert_eq!(manager.pending_connections.len(), 1); - - match &manager.peers.read().get(&peer).unwrap().state { - PeerState::Dialing { record } => { - assert_eq!(record.address(), &dial_address); - }, - state => panic!("invalid state for peer: {state:?}"), - } - - // remote peer connected to local node from a different address that was dialed - manager - .on_connection_established( - peer, - &Endpoint::listener(connect_address, ConnectionId::from(1usize)), - ) - .unwrap(); - - // connection to remote was closed while the dial was still in progress - manager.on_connection_closed(peer, ConnectionId::from(1usize)).unwrap(); - - // verify that the peer state is `Disconnected` - { - let peers = manager.peers.read(); - let peer = peers.get(&peer).unwrap(); - - match &peer.state { - PeerState::Disconnected { dial_record: Some(dial_record), .. } => { - assert_eq!(dial_record.address(), &dial_address); - }, - state => panic!("invalid state: {state:?}"), - } - } - - // dialing the peer failed - manager.on_dial_failure(ConnectionId::from(0usize)).unwrap(); - - let peers = manager.peers.read(); - let peer = peers.get(&peer).unwrap(); - - match &peer.state { - PeerState::Disconnected { dial_record: None, .. } => { - assert!(peer.addresses.contains(&dial_address)); - }, - state => panic!("invalid state: {state:?}"), - } - } - - // local node tried to dial a node and it failed but in the mean - // time the remote node dialed local node and that succeeded. - // - // while the dial was still in progresss, the remote node disconnected after which - // the dial failure was reported. - #[tokio::test] - async fn on_dial_success_while_connected_and_disconnected() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (mut manager, _handle) = TransportManager::new( - Keypair::generate(), - HashSet::new(), - BandwidthSink::new(), - 8usize, - ); - let _handle = manager.transport_handle(Arc::new(DefaultExecutor {})); - manager.register_transport(SupportedTransport::Tcp, Box::new(DummyTransport::new())); - - let peer = PeerId::random(); - let dial_address = Multiaddr::empty() - .with(Protocol::Ip4(Ipv4Addr::new(127, 0, 0, 1))) - .with(Protocol::Tcp(8888)) - .with(Protocol::P2p(Multihash::from_bytes(&peer.to_bytes()).unwrap())); - let connect_address = Multiaddr::empty() - .with(Protocol::Ip4(Ipv4Addr::new(192, 168, 1, 173))) - .with(Protocol::Tcp(8888)) - .with(Protocol::P2p(Multihash::from_bytes(&peer.to_bytes()).unwrap())); - assert!(manager.dial_address(dial_address.clone()).await.is_ok()); - assert_eq!(manager.pending_connections.len(), 1); - - match &manager.peers.read().get(&peer).unwrap().state { - PeerState::Dialing { record } => { - assert_eq!(record.address(), &dial_address); - }, - state => panic!("invalid state for peer: {state:?}"), - } - - // remote peer connected to local node from a different address that was dialed - manager - .on_connection_established( - peer, - &Endpoint::listener(connect_address, ConnectionId::from(1usize)), - ) - .unwrap(); - - // connection to remote was closed while the dial was still in progress - manager.on_connection_closed(peer, ConnectionId::from(1usize)).unwrap(); - - // verify that the peer state is `Disconnected` - { - let peers = manager.peers.read(); - let peer = peers.get(&peer).unwrap(); - - match &peer.state { - PeerState::Disconnected { dial_record: Some(dial_record), .. } => { - assert_eq!(dial_record.address(), &dial_address); - }, - state => panic!("invalid state: {state:?}"), - } - } - - // the original dial succeeded - manager - .on_connection_established( - peer, - &Endpoint::dialer(dial_address, ConnectionId::from(0usize)), - ) - .unwrap(); - - let peers = manager.peers.read(); - let peer = peers.get(&peer).unwrap(); - - match &peer.state { - PeerState::Connected { dial_record: None, .. } => {}, - state => panic!("invalid state: {state:?}"), - } - } - - #[tokio::test] - async fn secondary_connection_is_tracked() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (mut manager, _handle) = TransportManager::new( - Keypair::generate(), - HashSet::new(), - BandwidthSink::new(), - 8usize, - ); - manager.register_transport(SupportedTransport::Tcp, Box::new(DummyTransport::new())); - - let peer = PeerId::random(); - let address1 = Multiaddr::empty() - .with(Protocol::Ip4(Ipv4Addr::new(127, 0, 0, 1))) - .with(Protocol::Tcp(8888)) - .with(Protocol::P2p(Multihash::from_bytes(&peer.to_bytes()).unwrap())); - let address2 = Multiaddr::empty() - .with(Protocol::Ip4(Ipv4Addr::new(192, 168, 1, 173))) - .with(Protocol::Tcp(8888)) - .with(Protocol::P2p(Multihash::from_bytes(&peer.to_bytes()).unwrap())); - let address3 = Multiaddr::empty() - .with(Protocol::Ip4(Ipv4Addr::new(192, 168, 10, 64))) - .with(Protocol::Tcp(9999)) - .with(Protocol::P2p(Multihash::from_bytes(&peer.to_bytes()).unwrap())); - - // remote peer connected to local node - manager - .on_connection_established( - peer, - &Endpoint::listener(address1, ConnectionId::from(0usize)), - ) - .unwrap(); - - // verify that the peer state is `Connected` with no seconary connection - { - let peers = manager.peers.read(); - let peer = peers.get(&peer).unwrap(); - - match &peer.state { - PeerState::Connected { dial_record: None, .. } => { - assert!(peer.secondary_connection.is_none()); - }, - state => panic!("invalid state: {state:?}"), - } - } - - // second connection is established, verify that the seconary connection is tracked - manager - .on_connection_established( - peer, - &Endpoint::listener(address2.clone(), ConnectionId::from(1usize)), - ) - .unwrap(); - - let peers = manager.peers.read(); - let context = peers.get(&peer).unwrap(); - - match &context.state { - PeerState::Connected { dial_record: None, .. } => { - let seconary_connection = context.secondary_connection.as_ref().unwrap(); - assert_eq!(seconary_connection.address(), &address2); - assert_eq!(seconary_connection.score(), SCORE_DIAL_SUCCESS); - }, - state => panic!("invalid state: {state:?}"), - } - drop(peers); - - // tertiary connection is ignored - manager - .on_connection_established( - peer, - &Endpoint::listener(address3.clone(), ConnectionId::from(2usize)), - ) - .unwrap(); - - let peers = manager.peers.read(); - let peer = peers.get(&peer).unwrap(); - - match &peer.state { - PeerState::Connected { dial_record: None, .. } => { - let seconary_connection = peer.secondary_connection.as_ref().unwrap(); - assert_eq!(seconary_connection.address(), &address2); - assert_eq!(seconary_connection.score(), SCORE_DIAL_SUCCESS); - assert!(peer.addresses.contains(&address3)); - }, - state => panic!("invalid state: {state:?}"), - } - } - - #[tokio::test] - async fn secondary_connection_closed() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (mut manager, _handle) = TransportManager::new( - Keypair::generate(), - HashSet::new(), - BandwidthSink::new(), - 8usize, - ); - manager.register_transport(SupportedTransport::Tcp, Box::new(DummyTransport::new())); - - let peer = PeerId::random(); - let address1 = Multiaddr::empty() - .with(Protocol::Ip4(Ipv4Addr::new(127, 0, 0, 1))) - .with(Protocol::Tcp(8888)) - .with(Protocol::P2p(Multihash::from_bytes(&peer.to_bytes()).unwrap())); - let address2 = Multiaddr::empty() - .with(Protocol::Ip4(Ipv4Addr::new(192, 168, 1, 173))) - .with(Protocol::Tcp(8888)) - .with(Protocol::P2p(Multihash::from_bytes(&peer.to_bytes()).unwrap())); - - // remote peer connected to local node - let emit_event = manager - .on_connection_established( - peer, - &Endpoint::listener(address1, ConnectionId::from(0usize)), - ) - .unwrap(); - assert!(std::matches!(emit_event, ConnectionEstablishedResult::Accept)); - - // verify that the peer state is `Connected` with no seconary connection - { - let peers = manager.peers.read(); - let peer = peers.get(&peer).unwrap(); - - match &peer.state { - PeerState::Connected { dial_record: None, .. } => { - assert!(peer.secondary_connection.is_none()); - }, - state => panic!("invalid state: {state:?}"), - } - } - - // second connection is established, verify that the seconary connection is tracked - let emit_event = manager - .on_connection_established( - peer, - &Endpoint::dialer(address2.clone(), ConnectionId::from(1usize)), - ) - .unwrap(); - assert!(std::matches!(emit_event, ConnectionEstablishedResult::Accept)); - - let peers = manager.peers.read(); - let context = peers.get(&peer).unwrap(); - - match &context.state { - PeerState::Connected { dial_record: None, .. } => { - let seconary_connection = context.secondary_connection.as_ref().unwrap(); - assert_eq!(seconary_connection.address(), &address2); - assert_eq!(seconary_connection.score(), SCORE_DIAL_SUCCESS); - }, - state => panic!("invalid state: {state:?}"), - } - drop(peers); - - // close the secondary connection and verify that the peer remains connected - let emit_event = manager.on_connection_closed(peer, ConnectionId::from(1usize)).unwrap(); - assert!(emit_event.is_none()); - - let peers = manager.peers.read(); - let context = peers.get(&peer).unwrap(); - - match &context.state { - PeerState::Connected { dial_record: None, record } => { - assert!(context.secondary_connection.is_none()); - assert!(context.addresses.contains(&address2)); - assert_eq!(record.connection_id(), &Some(ConnectionId::from(0usize))); - }, - state => panic!("invalid state: {state:?}"), - } - } - - #[tokio::test] - async fn switch_to_secondary_connection() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (mut manager, _handle) = TransportManager::new( - Keypair::generate(), - HashSet::new(), - BandwidthSink::new(), - 8usize, - ); - manager.register_transport(SupportedTransport::Tcp, Box::new(DummyTransport::new())); - - let peer = PeerId::random(); - let address1 = Multiaddr::empty() - .with(Protocol::Ip4(Ipv4Addr::new(127, 0, 0, 1))) - .with(Protocol::Tcp(8888)) - .with(Protocol::P2p(Multihash::from_bytes(&peer.to_bytes()).unwrap())); - let address2 = Multiaddr::empty() - .with(Protocol::Ip4(Ipv4Addr::new(192, 168, 1, 173))) - .with(Protocol::Tcp(8888)) - .with(Protocol::P2p(Multihash::from_bytes(&peer.to_bytes()).unwrap())); - - // remote peer connected to local node - let emit_event = manager - .on_connection_established( - peer, - &Endpoint::listener(address1.clone(), ConnectionId::from(0usize)), - ) - .unwrap(); - assert!(std::matches!(emit_event, ConnectionEstablishedResult::Accept)); - - // verify that the peer state is `Connected` with no seconary connection - { - let peers = manager.peers.read(); - let peer = peers.get(&peer).unwrap(); - - match &peer.state { - PeerState::Connected { dial_record: None, .. } => { - assert!(peer.secondary_connection.is_none()); - }, - state => panic!("invalid state: {state:?}"), - } - } - - // second connection is established, verify that the seconary connection is tracked - let emit_event = manager - .on_connection_established( - peer, - &Endpoint::dialer(address2.clone(), ConnectionId::from(1usize)), - ) - .unwrap(); - assert!(std::matches!(emit_event, ConnectionEstablishedResult::Accept)); - - let peers = manager.peers.read(); - let context = peers.get(&peer).unwrap(); - - match &context.state { - PeerState::Connected { dial_record: None, .. } => { - let seconary_connection = context.secondary_connection.as_ref().unwrap(); - assert_eq!(seconary_connection.address(), &address2); - assert_eq!(seconary_connection.score(), SCORE_DIAL_SUCCESS); - }, - state => panic!("invalid state: {state:?}"), - } - drop(peers); - - // close the primary connection and verify that the peer remains connected - // while the primary connection address is stored in peer addresses - let emit_event = manager.on_connection_closed(peer, ConnectionId::from(0usize)).unwrap(); - assert!(emit_event.is_none()); - - let peers = manager.peers.read(); - let context = peers.get(&peer).unwrap(); - - match &context.state { - PeerState::Connected { dial_record: None, record } => { - assert!(context.secondary_connection.is_none()); - assert!(context.addresses.contains(&address1)); - assert_eq!(record.connection_id(), &Some(ConnectionId::from(1usize))); - }, - state => panic!("invalid state: {state:?}"), - } - } - - // two connections already exist and a third was opened which is ignored by - // `on_connection_established()`, when that connection is closed, verify that - // it's handled gracefully - #[tokio::test] - async fn tertiary_connection_closed() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (mut manager, _handle) = TransportManager::new( - Keypair::generate(), - HashSet::new(), - BandwidthSink::new(), - 8usize, - ); - manager.register_transport(SupportedTransport::Tcp, Box::new(DummyTransport::new())); - - let peer = PeerId::random(); - let address1 = Multiaddr::empty() - .with(Protocol::Ip4(Ipv4Addr::new(127, 0, 0, 1))) - .with(Protocol::Tcp(8888)) - .with(Protocol::P2p(Multihash::from_bytes(&peer.to_bytes()).unwrap())); - let address2 = Multiaddr::empty() - .with(Protocol::Ip4(Ipv4Addr::new(192, 168, 1, 173))) - .with(Protocol::Tcp(8888)) - .with(Protocol::P2p(Multihash::from_bytes(&peer.to_bytes()).unwrap())); - let address3 = Multiaddr::empty() - .with(Protocol::Ip4(Ipv4Addr::new(192, 168, 1, 173))) - .with(Protocol::Tcp(9999)) - .with(Protocol::P2p(Multihash::from_bytes(&peer.to_bytes()).unwrap())); - - // remote peer connected to local node - let emit_event = manager - .on_connection_established( - peer, - &Endpoint::listener(address1, ConnectionId::from(0usize)), - ) - .unwrap(); - assert!(std::matches!(emit_event, ConnectionEstablishedResult::Accept)); - - // verify that the peer state is `Connected` with no seconary connection - { - let peers = manager.peers.read(); - let peer = peers.get(&peer).unwrap(); - - match &peer.state { - PeerState::Connected { dial_record: None, .. } => { - assert!(peer.secondary_connection.is_none()); - }, - state => panic!("invalid state: {state:?}"), - } - } - - // second connection is established, verify that the seconary connection is tracked - let emit_event = manager - .on_connection_established( - peer, - &Endpoint::dialer(address2.clone(), ConnectionId::from(1usize)), - ) - .unwrap(); - assert!(std::matches!(emit_event, ConnectionEstablishedResult::Accept)); - - let peers = manager.peers.read(); - let context = peers.get(&peer).unwrap(); - - match &context.state { - PeerState::Connected { dial_record: None, .. } => { - let seconary_connection = context.secondary_connection.as_ref().unwrap(); - assert_eq!(seconary_connection.address(), &address2); - assert_eq!(seconary_connection.score(), SCORE_DIAL_SUCCESS); - }, - state => panic!("invalid state: {state:?}"), - } - drop(peers); - - // third connection is established, verify that it's discarded - let emit_event = manager - .on_connection_established( - peer, - &Endpoint::listener(address3.clone(), ConnectionId::from(2usize)), - ) - .unwrap(); - assert!(std::matches!(emit_event, ConnectionEstablishedResult::Reject)); - - let peers = manager.peers.read(); - let context = peers.get(&peer).unwrap(); - assert!(context.addresses.contains(&address3)); - drop(peers); - - // close the tertiary connection that was ignored - let emit_event = manager.on_connection_closed(peer, ConnectionId::from(2usize)).unwrap(); - assert!(emit_event.is_none()); - - // verify that the state remains unchanged - let peers = manager.peers.read(); - let context = peers.get(&peer).unwrap(); - - match &context.state { - PeerState::Connected { dial_record: None, .. } => { - let seconary_connection = context.secondary_connection.as_ref().unwrap(); - assert_eq!(seconary_connection.address(), &address2); - assert_eq!(seconary_connection.score(), SCORE_DIAL_SUCCESS); - }, - state => panic!("invalid state: {state:?}"), - } - drop(peers); - } - - #[tokio::test] - #[cfg(debug_assertions)] - #[should_panic] - async fn dial_failure_for_unknow_connection() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (mut manager, _handle) = TransportManager::new( - Keypair::generate(), - HashSet::new(), - BandwidthSink::new(), - 8usize, - ); - - manager.on_dial_failure(ConnectionId::random()).unwrap(); - } - - #[tokio::test] - #[cfg(debug_assertions)] - #[should_panic] - async fn dial_failure_for_unknow_peer() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (mut manager, _handle) = TransportManager::new( - Keypair::generate(), - HashSet::new(), - BandwidthSink::new(), - 8usize, - ); - let connection_id = ConnectionId::random(); - let peer = PeerId::random(); - manager.pending_connections.insert(connection_id, peer); - manager.on_dial_failure(connection_id).unwrap(); - } - - #[tokio::test] - #[cfg(debug_assertions)] - #[should_panic] - async fn connection_closed_for_unknown_peer() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (mut manager, _handle) = TransportManager::new( - Keypair::generate(), - HashSet::new(), - BandwidthSink::new(), - 8usize, - ); - manager.on_connection_closed(PeerId::random(), ConnectionId::random()).unwrap(); - } - - #[tokio::test] - #[cfg(debug_assertions)] - #[should_panic] - async fn unknown_connection_opened() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (mut manager, _handle) = TransportManager::new( - Keypair::generate(), - HashSet::new(), - BandwidthSink::new(), - 8usize, - ); - manager - .on_connection_opened( - SupportedTransport::Tcp, - ConnectionId::random(), - Multiaddr::empty(), - ) - .unwrap(); - } - - #[tokio::test] - #[cfg(debug_assertions)] - #[should_panic] - async fn connection_opened_for_unknown_peer() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (mut manager, _handle) = TransportManager::new( - Keypair::generate(), - HashSet::new(), - BandwidthSink::new(), - 8usize, - ); - let connection_id = ConnectionId::random(); - let peer = PeerId::random(); - - manager.pending_connections.insert(connection_id, peer); - manager - .on_connection_opened(SupportedTransport::Tcp, connection_id, Multiaddr::empty()) - .unwrap(); - } - - #[tokio::test] - #[cfg(debug_assertions)] - #[should_panic] - async fn connection_established_for_wrong_peer() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (mut manager, _handle) = TransportManager::new( - Keypair::generate(), - HashSet::new(), - BandwidthSink::new(), - 8usize, - ); - let connection_id = ConnectionId::random(); - let peer = PeerId::random(); - - manager.pending_connections.insert(connection_id, peer); - manager - .on_connection_established( - PeerId::random(), - &Endpoint::dialer(Multiaddr::empty(), connection_id), - ) - .unwrap(); - } - - #[tokio::test] - #[cfg(debug_assertions)] - #[should_panic] - async fn open_failure_unknown_connection() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (mut manager, _handle) = TransportManager::new( - Keypair::generate(), - HashSet::new(), - BandwidthSink::new(), - 8usize, - ); - - manager - .on_open_failure(SupportedTransport::Tcp, ConnectionId::random()) - .unwrap(); - } - - #[tokio::test] - #[cfg(debug_assertions)] - #[should_panic] - async fn open_failure_unknown_peer() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (mut manager, _handle) = TransportManager::new( - Keypair::generate(), - HashSet::new(), - BandwidthSink::new(), - 8usize, - ); - let connection_id = ConnectionId::random(); - let peer = PeerId::random(); - - manager.pending_connections.insert(connection_id, peer); - manager.on_open_failure(SupportedTransport::Tcp, connection_id).unwrap(); - } - - #[tokio::test] - async fn no_transports() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (mut manager, _handle) = TransportManager::new( - Keypair::generate(), - HashSet::new(), - BandwidthSink::new(), - 8usize, - ); - - assert!(manager.next().await.is_none()); - } - - #[tokio::test] - async fn dial_already_connected_peer() { - let (mut manager, _handle) = TransportManager::new( - Keypair::generate(), - HashSet::new(), - BandwidthSink::new(), - 8usize, - ); - - let peer = { - let peer = PeerId::random(); - let mut peers = manager.peers.write(); - - peers.insert( - peer, - PeerContext { - state: PeerState::Connected { - record: AddressRecord::from_multiaddr( - Multiaddr::empty() - .with(Protocol::Ip4(std::net::Ipv4Addr::new(127, 0, 0, 1))) - .with(Protocol::Tcp(8888)) - .with(Protocol::P2p(Multihash::from(peer))), - ) - .unwrap(), - dial_record: None, - }, - secondary_connection: None, - addresses: AddressStore::from_iter( - vec![Multiaddr::empty() - .with(Protocol::Ip4(std::net::Ipv4Addr::new(127, 0, 0, 1))) - .with(Protocol::Tcp(8888)) - .with(Protocol::P2p(Multihash::from(peer)))] - .into_iter(), - ), - }, - ); - drop(peers); - - peer - }; - - match manager.dial(peer).await { - Err(Error::AlreadyConnected) => {}, - _ => panic!("invalid return value"), - } - } - - #[tokio::test] - async fn peer_already_being_dialed() { - let (mut manager, _handle) = TransportManager::new( - Keypair::generate(), - HashSet::new(), - BandwidthSink::new(), - 8usize, - ); - - let peer = { - let peer = PeerId::random(); - let mut peers = manager.peers.write(); - - peers.insert( - peer, - PeerContext { - state: PeerState::Dialing { - record: AddressRecord::from_multiaddr( - Multiaddr::empty() - .with(Protocol::Ip4(std::net::Ipv4Addr::new(127, 0, 0, 1))) - .with(Protocol::Tcp(8888)) - .with(Protocol::P2p(Multihash::from(peer))), - ) - .unwrap(), - }, - secondary_connection: None, - addresses: AddressStore::from_iter( - vec![Multiaddr::empty() - .with(Protocol::Ip4(std::net::Ipv4Addr::new(127, 0, 0, 1))) - .with(Protocol::Tcp(8888)) - .with(Protocol::P2p(Multihash::from(peer)))] - .into_iter(), - ), - }, - ); - drop(peers); - - peer - }; - - manager.dial(peer).await.unwrap(); - } - - #[tokio::test] - async fn pending_connection_for_disconnected_peer() { - let (mut manager, _handle) = TransportManager::new( - Keypair::generate(), - HashSet::new(), - BandwidthSink::new(), - 8usize, - ); - - let peer = { - let peer = PeerId::random(); - let mut peers = manager.peers.write(); - - peers.insert( - peer, - PeerContext { - state: PeerState::Disconnected { - dial_record: Some( - AddressRecord::from_multiaddr( - Multiaddr::empty() - .with(Protocol::Ip4(std::net::Ipv4Addr::new(127, 0, 0, 1))) - .with(Protocol::Tcp(8888)) - .with(Protocol::P2p(Multihash::from(peer))), - ) - .unwrap(), - ), - }, - secondary_connection: None, - addresses: AddressStore::new(), - }, - ); - drop(peers); - - peer - }; - - manager.dial(peer).await.unwrap(); - } - - #[tokio::test] - async fn dial_address_invalid_transport() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (mut manager, _handle) = TransportManager::new( - Keypair::generate(), - HashSet::new(), - BandwidthSink::new(), - 8usize, - ); - - // transport doesn't start with ip/dns - { - let address = Multiaddr::empty().with(Protocol::P2p(Multihash::from(PeerId::random()))); - match manager.dial_address(address.clone()).await { - Err(Error::TransportNotSupported(dial_address)) => { - assert_eq!(dial_address, address); - }, - _ => panic!("invalid return value"), - } - } - - { - // upd-based protocol but not quic - let address = Multiaddr::empty() - .with(Protocol::Ip4(std::net::Ipv4Addr::new(127, 0, 0, 1))) - .with(Protocol::Udp(8888)) - .with(Protocol::Utp) - .with(Protocol::P2p(Multihash::from(PeerId::random()))); - match manager.dial_address(address.clone()).await { - Err(Error::TransportNotSupported(dial_address)) => { - assert_eq!(dial_address, address); - }, - res => panic!("invalid return value: {res:?}"), - } - } - - // not tcp nor udp - { - let address = Multiaddr::empty() - .with(Protocol::Ip4(std::net::Ipv4Addr::new(127, 0, 0, 1))) - .with(Protocol::Sctp(8888)) - .with(Protocol::P2p(Multihash::from(PeerId::random()))); - match manager.dial_address(address.clone()).await { - Err(Error::TransportNotSupported(dial_address)) => { - assert_eq!(dial_address, address); - }, - _ => panic!("invalid return value"), - } - } - - // random protocol after tcp - { - let address = Multiaddr::empty() - .with(Protocol::Ip4(std::net::Ipv4Addr::new(127, 0, 0, 1))) - .with(Protocol::Tcp(8888)) - .with(Protocol::Utp) - .with(Protocol::P2p(Multihash::from(PeerId::random()))); - match manager.dial_address(address.clone()).await { - Err(Error::TransportNotSupported(dial_address)) => { - assert_eq!(dial_address, address); - }, - _ => panic!("invalid return value"), - } - } - } - - #[tokio::test] - async fn dial_address_peer_id_missing() { - let (mut manager, _handle) = TransportManager::new( - Keypair::generate(), - HashSet::new(), - BandwidthSink::new(), - 8usize, - ); - - async fn call_manager(manager: &mut TransportManager, address: Multiaddr) { - match manager.dial_address(address).await { - Err(Error::AddressError(AddressError::PeerIdMissing)) => {}, - _ => panic!("invalid return value"), - } - } - - { - call_manager( - &mut manager, - Multiaddr::empty() - .with(Protocol::Ip4(std::net::Ipv4Addr::new(127, 0, 0, 1))) - .with(Protocol::Tcp(8888)), - ) - .await; - } - - { - call_manager( - &mut manager, - Multiaddr::empty() - .with(Protocol::Ip4(std::net::Ipv4Addr::new(127, 0, 0, 1))) - .with(Protocol::Tcp(8888)) - .with(Protocol::Wss(std::borrow::Cow::Owned("".to_string()))), - ) - .await; - } - - { - call_manager( - &mut manager, - Multiaddr::empty() - .with(Protocol::Ip4(std::net::Ipv4Addr::new(127, 0, 0, 1))) - .with(Protocol::Udp(8888)) - .with(Protocol::QuicV1), - ) - .await; - } - } - - #[tokio::test] - async fn inbound_connection_while_dialing() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (mut manager, _handle) = TransportManager::new( - Keypair::generate(), - HashSet::new(), - BandwidthSink::new(), - 8usize, - ); - let peer = PeerId::random(); - let dial_address = Multiaddr::empty() - .with(Protocol::Ip4(Ipv4Addr::new(127, 0, 0, 1))) - .with(Protocol::Tcp(8888)) - .with(Protocol::P2p(Multihash::from_bytes(&peer.to_bytes()).unwrap())); - - let connection_id = ConnectionId::random(); - let transport = Box::new({ - let mut transport = DummyTransport::new(); - transport.inject_event(TransportEvent::ConnectionEstablished { - peer, - endpoint: Endpoint::listener(dial_address.clone(), connection_id), - }); - transport - }); - manager.register_transport(SupportedTransport::Tcp, transport); - manager.add_known_address( - peer, - vec![Multiaddr::empty() - .with(Protocol::Ip4(Ipv4Addr::new(192, 168, 1, 5))) - .with(Protocol::Tcp(8888)) - .with(Protocol::P2p(Multihash::from(peer)))] - .into_iter(), - ); - - assert!(manager.dial(peer).await.is_ok()); - assert!(!manager.pending_connections.is_empty()); - - { - let peers = manager.peers.read(); - - match peers.get(&peer) { - Some(PeerContext { state: PeerState::Opening { .. }, .. }) => {}, - state => panic!("invalid state for peer: {state:?}"), - } - } - - match manager.next().await.unwrap() { - TransportEvent::ConnectionEstablished { - peer: event_peer, - endpoint: event_endpoint, - .. - } => { - assert_eq!(peer, event_peer); - assert_eq!(event_endpoint, Endpoint::listener(dial_address.clone(), connection_id),); - }, - event => panic!("invalid event: {event:?}"), - } - assert!(manager.pending_connections.is_empty()); - - let peers = manager.peers.read(); - match peers.get(&peer).unwrap() { - PeerContext { - state: PeerState::Connected { record, dial_record }, - secondary_connection, - addresses, - } => { - assert!(!addresses.contains(record.address())); - assert!(dial_record.is_none()); - assert!(secondary_connection.is_none()); - assert_eq!(record.address(), &dial_address); - assert_eq!(record.connection_id(), &Some(connection_id)); - }, - state => panic!("invalid peer state: {state:?}"), - } - } - - #[tokio::test] - async fn inbound_connection_for_same_address_while_dialing() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (mut manager, _handle) = TransportManager::new( - Keypair::generate(), - HashSet::new(), - BandwidthSink::new(), - 8usize, - ); - let peer = PeerId::random(); - let dial_address = Multiaddr::empty() - .with(Protocol::Ip4(Ipv4Addr::new(127, 0, 0, 1))) - .with(Protocol::Tcp(8888)) - .with(Protocol::P2p(Multihash::from_bytes(&peer.to_bytes()).unwrap())); - - let connection_id = ConnectionId::random(); - let transport = Box::new({ - let mut transport = DummyTransport::new(); - transport.inject_event(TransportEvent::ConnectionEstablished { - peer, - endpoint: Endpoint::listener(dial_address.clone(), connection_id), - }); - transport - }); - manager.register_transport(SupportedTransport::Tcp, transport); - manager.add_known_address( - peer, - vec![Multiaddr::empty() - .with(Protocol::Ip4(Ipv4Addr::new(127, 0, 0, 1))) - .with(Protocol::Tcp(8888)) - .with(Protocol::P2p(Multihash::from(peer)))] - .into_iter(), - ); - - assert!(manager.dial(peer).await.is_ok()); - assert!(!manager.pending_connections.is_empty()); - - { - let peers = manager.peers.read(); - - match peers.get(&peer) { - Some(PeerContext { state: PeerState::Opening { .. }, .. }) => {}, - state => panic!("invalid state for peer: {state:?}"), - } - } - - match manager.next().await.unwrap() { - TransportEvent::ConnectionEstablished { - peer: event_peer, - endpoint: event_endpoint, - .. - } => { - assert_eq!(peer, event_peer); - assert_eq!(event_endpoint, Endpoint::listener(dial_address.clone(), connection_id),); - }, - event => panic!("invalid event: {event:?}"), - } - assert!(manager.pending_connections.is_empty()); - - let peers = manager.peers.read(); - match peers.get(&peer).unwrap() { - PeerContext { - state: PeerState::Connected { record, dial_record }, - secondary_connection, - addresses, - } => { - assert!(addresses.is_empty()); - assert!(dial_record.is_none()); - assert!(secondary_connection.is_none()); - assert_eq!(record.address(), &dial_address); - assert_eq!(record.connection_id(), &Some(connection_id)); - }, - state => panic!("invalid peer state: {state:?}"), - } - } + use super::*; + use crate::{ + crypto::ed25519::Keypair, executor::DefaultExecutor, transport::dummy::DummyTransport, + }; + use std::{ + net::{Ipv4Addr, Ipv6Addr}, + sync::Arc, + }; + + #[test] + #[should_panic] + #[cfg(debug_assertions)] + fn duplicate_protocol() { + let sink = BandwidthSink::new(); + let (mut manager, _handle) = + TransportManager::new(Keypair::generate(), HashSet::new(), sink, 8usize); + + manager.register_protocol( + ProtocolName::from("/notif/1"), + Vec::new(), + ProtocolCodec::UnsignedVarint(None), + ); + manager.register_protocol( + ProtocolName::from("/notif/1"), + Vec::new(), + ProtocolCodec::UnsignedVarint(None), + ); + } + + #[test] + #[should_panic] + #[cfg(debug_assertions)] + fn fallback_protocol_as_duplicate_main_protocol() { + let sink = BandwidthSink::new(); + let (mut manager, _handle) = + TransportManager::new(Keypair::generate(), HashSet::new(), sink, 8usize); + + manager.register_protocol( + ProtocolName::from("/notif/1"), + Vec::new(), + ProtocolCodec::UnsignedVarint(None), + ); + manager.register_protocol( + ProtocolName::from("/notif/2"), + vec![ + ProtocolName::from("/notif/2/new"), + ProtocolName::from("/notif/1"), + ], + ProtocolCodec::UnsignedVarint(None), + ); + } + + #[test] + #[should_panic] + #[cfg(debug_assertions)] + fn duplicate_fallback_protocol() { + let sink = BandwidthSink::new(); + let (mut manager, _handle) = + TransportManager::new(Keypair::generate(), HashSet::new(), sink, 8usize); + + manager.register_protocol( + ProtocolName::from("/notif/1"), + vec![ + ProtocolName::from("/notif/1/new"), + ProtocolName::from("/notif/1"), + ], + ProtocolCodec::UnsignedVarint(None), + ); + manager.register_protocol( + ProtocolName::from("/notif/2"), + vec![ + ProtocolName::from("/notif/2/new"), + ProtocolName::from("/notif/1/new"), + ], + ProtocolCodec::UnsignedVarint(None), + ); + } + + #[test] + #[should_panic] + #[cfg(debug_assertions)] + fn duplicate_transport() { + let sink = BandwidthSink::new(); + let (mut manager, _handle) = + TransportManager::new(Keypair::generate(), HashSet::new(), sink, 8usize); + + manager.register_transport(SupportedTransport::Tcp, Box::new(DummyTransport::new())); + manager.register_transport(SupportedTransport::Tcp, Box::new(DummyTransport::new())); + } + + #[tokio::test] + async fn tried_to_self_using_peer_id() { + let keypair = Keypair::generate(); + let local_peer_id = PeerId::from_public_key(&keypair.public().into()); + let sink = BandwidthSink::new(); + let (mut manager, _handle) = TransportManager::new(keypair, HashSet::new(), sink, 8usize); + + assert!(manager.dial(local_peer_id).await.is_err()); + } + + #[tokio::test] + async fn try_to_dial_over_disabled_transport() { + let (mut manager, _handle) = TransportManager::new( + Keypair::generate(), + HashSet::new(), + BandwidthSink::new(), + 8usize, + ); + let _handle = manager.transport_handle(Arc::new(DefaultExecutor {})); + manager.register_transport(SupportedTransport::Tcp, Box::new(DummyTransport::new())); + + let address = Multiaddr::empty() + .with(Protocol::Ip4(Ipv4Addr::new(127, 0, 0, 1))) + .with(Protocol::Udp(8888)) + .with(Protocol::QuicV1) + .with(Protocol::P2p( + Multihash::from_bytes(&PeerId::random().to_bytes()).unwrap(), + )); + + assert!(std::matches!( + manager.dial_address(address).await, + Err(Error::TransportNotSupported(_)) + )); + } + + #[tokio::test] + async fn successful_dial_reported_to_transport_manager() { + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (mut manager, _handle) = TransportManager::new( + Keypair::generate(), + HashSet::new(), + BandwidthSink::new(), + 8usize, + ); + let peer = PeerId::random(); + let dial_address = Multiaddr::empty() + .with(Protocol::Ip4(Ipv4Addr::new(127, 0, 0, 1))) + .with(Protocol::Tcp(8888)) + .with(Protocol::P2p( + Multihash::from_bytes(&peer.to_bytes()).unwrap(), + )); + + let transport = Box::new({ + let mut transport = DummyTransport::new(); + transport.inject_event(TransportEvent::ConnectionEstablished { + peer, + endpoint: Endpoint::dialer(dial_address.clone(), ConnectionId::from(0usize)), + }); + transport + }); + manager.register_transport(SupportedTransport::Tcp, transport); + + assert!(manager.dial_address(dial_address.clone()).await.is_ok()); + assert!(!manager.pending_connections.is_empty()); + + { + let peers = manager.peers.read(); + + match peers.get(&peer) { + Some(PeerContext { + state: PeerState::Dialing { .. }, + .. + }) => {} + state => panic!("invalid state for peer: {state:?}"), + } + } + + match manager.next().await.unwrap() { + TransportEvent::ConnectionEstablished { + peer: event_peer, + endpoint: event_endpoint, + .. + } => { + assert_eq!(peer, event_peer); + assert_eq!( + event_endpoint, + Endpoint::dialer(dial_address.clone(), ConnectionId::from(0usize)) + ) + } + event => panic!("invalid event: {event:?}"), + } + } + + #[tokio::test] + async fn try_to_dial_same_peer_twice() { + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (mut manager, _handle) = TransportManager::new( + Keypair::generate(), + HashSet::new(), + BandwidthSink::new(), + 8usize, + ); + let _handle = manager.transport_handle(Arc::new(DefaultExecutor {})); + manager.register_transport(SupportedTransport::Tcp, Box::new(DummyTransport::new())); + + let peer = PeerId::random(); + let dial_address = Multiaddr::empty() + .with(Protocol::Ip4(Ipv4Addr::new(127, 0, 0, 1))) + .with(Protocol::Tcp(8888)) + .with(Protocol::P2p( + Multihash::from_bytes(&peer.to_bytes()).unwrap(), + )); + + assert!(manager.dial_address(dial_address.clone()).await.is_ok()); + assert_eq!(manager.pending_connections.len(), 1); + + assert!(manager.dial_address(dial_address.clone()).await.is_ok()); + assert_eq!(manager.pending_connections.len(), 1); + } + + #[tokio::test] + async fn try_to_dial_same_peer_twice_diffrent_address() { + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (mut manager, _handle) = TransportManager::new( + Keypair::generate(), + HashSet::new(), + BandwidthSink::new(), + 8usize, + ); + let _handle = manager.transport_handle(Arc::new(DefaultExecutor {})); + manager.register_transport(SupportedTransport::Tcp, Box::new(DummyTransport::new())); + + let peer = PeerId::random(); + + assert!(manager + .dial_address( + Multiaddr::empty() + .with(Protocol::Ip4(Ipv4Addr::new(127, 0, 0, 1))) + .with(Protocol::Tcp(8888)) + .with(Protocol::P2p( + Multihash::from_bytes(&peer.to_bytes()).unwrap(), + )) + ) + .await + .is_ok()); + assert_eq!(manager.pending_connections.len(), 1); + + assert!(manager + .dial_address( + Multiaddr::empty() + .with(Protocol::Ip6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1))) + .with(Protocol::Tcp(8888)) + .with(Protocol::P2p( + Multihash::from_bytes(&peer.to_bytes()).unwrap(), + )) + ) + .await + .is_ok()); + assert_eq!(manager.pending_connections.len(), 1); + } + + #[tokio::test] + async fn dial_non_existent_peer() { + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (mut manager, _handle) = TransportManager::new( + Keypair::generate(), + HashSet::new(), + BandwidthSink::new(), + 8usize, + ); + let _handle = manager.transport_handle(Arc::new(DefaultExecutor {})); + manager.register_transport(SupportedTransport::Tcp, Box::new(DummyTransport::new())); + + assert!(manager.dial(PeerId::random()).await.is_err()); + } + + #[tokio::test] + async fn dial_non_peer_with_no_known_addresses() { + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (mut manager, _handle) = TransportManager::new( + Keypair::generate(), + HashSet::new(), + BandwidthSink::new(), + 8usize, + ); + let _handle = manager.transport_handle(Arc::new(DefaultExecutor {})); + manager.register_transport(SupportedTransport::Tcp, Box::new(DummyTransport::new())); + + let peer = PeerId::random(); + manager.peers.write().insert( + peer, + PeerContext { + state: PeerState::Disconnected { dial_record: None }, + addresses: AddressStore::new(), + secondary_connection: None, + }, + ); + + assert!(manager.dial(peer).await.is_err()); + } + + #[tokio::test] + async fn check_supported_transport_when_adding_known_address() { + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (_manager, handle) = TransportManager::new( + Keypair::generate(), + HashSet::from_iter([SupportedTransport::Tcp, SupportedTransport::Quic]), + BandwidthSink::new(), + 8usize, + ); + + // ipv6 + let address = Multiaddr::empty() + .with(Protocol::Ip6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1))) + .with(Protocol::Tcp(8888)) + .with(Protocol::P2p( + Multihash::from_bytes(&PeerId::random().to_bytes()).unwrap(), + )); + assert!(handle.supported_transport(&address)); + + // ipv4 + let address = Multiaddr::empty() + .with(Protocol::Ip4(Ipv4Addr::new(127, 0, 0, 1))) + .with(Protocol::Tcp(8888)) + .with(Protocol::P2p( + Multihash::from_bytes(&PeerId::random().to_bytes()).unwrap(), + )); + assert!(handle.supported_transport(&address)); + + // quic + let address = Multiaddr::empty() + .with(Protocol::Ip4(Ipv4Addr::new(127, 0, 0, 1))) + .with(Protocol::Udp(8888)) + .with(Protocol::QuicV1) + .with(Protocol::P2p( + Multihash::from_bytes(&PeerId::random().to_bytes()).unwrap(), + )); + assert!(handle.supported_transport(&address)); + + // websocket + let address = Multiaddr::empty() + .with(Protocol::Ip4(Ipv4Addr::new(127, 0, 0, 1))) + .with(Protocol::Tcp(8888)) + .with(Protocol::Ws(std::borrow::Cow::Owned("/".to_string()))); + assert!(!handle.supported_transport(&address)); + + // websocket secure + let address = Multiaddr::empty() + .with(Protocol::Ip4(Ipv4Addr::new(127, 0, 0, 1))) + .with(Protocol::Tcp(8888)) + .with(Protocol::Wss(std::borrow::Cow::Owned("/".to_string()))); + assert!(!handle.supported_transport(&address)); + } + + // local node tried to dial a node and it failed but in the mean + // time the remote node dialed local node and that succeeded. + #[tokio::test] + async fn on_dial_failure_already_connected() { + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (mut manager, _handle) = TransportManager::new( + Keypair::generate(), + HashSet::new(), + BandwidthSink::new(), + 8usize, + ); + let _handle = manager.transport_handle(Arc::new(DefaultExecutor {})); + manager.register_transport(SupportedTransport::Tcp, Box::new(DummyTransport::new())); + + let peer = PeerId::random(); + let dial_address = Multiaddr::empty() + .with(Protocol::Ip4(Ipv4Addr::new(127, 0, 0, 1))) + .with(Protocol::Tcp(8888)) + .with(Protocol::P2p( + Multihash::from_bytes(&peer.to_bytes()).unwrap(), + )); + let connect_address = Multiaddr::empty() + .with(Protocol::Ip4(Ipv4Addr::new(192, 168, 1, 173))) + .with(Protocol::Tcp(8888)) + .with(Protocol::P2p( + Multihash::from_bytes(&peer.to_bytes()).unwrap(), + )); + assert!(manager.dial_address(dial_address.clone()).await.is_ok()); + assert_eq!(manager.pending_connections.len(), 1); + + match &manager.peers.read().get(&peer).unwrap().state { + PeerState::Dialing { record } => { + assert_eq!(record.address(), &dial_address); + } + state => panic!("invalid state for peer: {state:?}"), + } + + // remote peer connected to local node from a different address that was dialed + manager + .on_connection_established( + peer, + &Endpoint::dialer(connect_address, ConnectionId::from(1usize)), + ) + .unwrap(); + + // dialing the peer failed + manager.on_dial_failure(ConnectionId::from(0usize)).unwrap(); + + let peers = manager.peers.read(); + let peer = peers.get(&peer).unwrap(); + + match &peer.state { + PeerState::Connected { dial_record, .. } => { + assert!(dial_record.is_none()); + assert!(peer.addresses.contains(&dial_address)); + } + state => panic!("invalid state: {state:?}"), + } + } + + // local node tried to dial a node and it failed but in the mean + // time the remote node dialed local node and that succeeded. + // + // while the dial was still in progresss, the remote node disconnected after which + // the dial failure was reported. + #[tokio::test] + async fn on_dial_failure_already_connected_and_disconnected() { + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (mut manager, _handle) = TransportManager::new( + Keypair::generate(), + HashSet::new(), + BandwidthSink::new(), + 8usize, + ); + let _handle = manager.transport_handle(Arc::new(DefaultExecutor {})); + manager.register_transport(SupportedTransport::Tcp, Box::new(DummyTransport::new())); + + let peer = PeerId::random(); + let dial_address = Multiaddr::empty() + .with(Protocol::Ip4(Ipv4Addr::new(127, 0, 0, 1))) + .with(Protocol::Tcp(8888)) + .with(Protocol::P2p( + Multihash::from_bytes(&peer.to_bytes()).unwrap(), + )); + let connect_address = Multiaddr::empty() + .with(Protocol::Ip4(Ipv4Addr::new(192, 168, 1, 173))) + .with(Protocol::Tcp(8888)) + .with(Protocol::P2p( + Multihash::from_bytes(&peer.to_bytes()).unwrap(), + )); + assert!(manager.dial_address(dial_address.clone()).await.is_ok()); + assert_eq!(manager.pending_connections.len(), 1); + + match &manager.peers.read().get(&peer).unwrap().state { + PeerState::Dialing { record } => { + assert_eq!(record.address(), &dial_address); + } + state => panic!("invalid state for peer: {state:?}"), + } + + // remote peer connected to local node from a different address that was dialed + manager + .on_connection_established( + peer, + &Endpoint::listener(connect_address, ConnectionId::from(1usize)), + ) + .unwrap(); + + // connection to remote was closed while the dial was still in progress + manager.on_connection_closed(peer, ConnectionId::from(1usize)).unwrap(); + + // verify that the peer state is `Disconnected` + { + let peers = manager.peers.read(); + let peer = peers.get(&peer).unwrap(); + + match &peer.state { + PeerState::Disconnected { + dial_record: Some(dial_record), + .. + } => { + assert_eq!(dial_record.address(), &dial_address); + } + state => panic!("invalid state: {state:?}"), + } + } + + // dialing the peer failed + manager.on_dial_failure(ConnectionId::from(0usize)).unwrap(); + + let peers = manager.peers.read(); + let peer = peers.get(&peer).unwrap(); + + match &peer.state { + PeerState::Disconnected { + dial_record: None, .. + } => { + assert!(peer.addresses.contains(&dial_address)); + } + state => panic!("invalid state: {state:?}"), + } + } + + // local node tried to dial a node and it failed but in the mean + // time the remote node dialed local node and that succeeded. + // + // while the dial was still in progresss, the remote node disconnected after which + // the dial failure was reported. + #[tokio::test] + async fn on_dial_success_while_connected_and_disconnected() { + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (mut manager, _handle) = TransportManager::new( + Keypair::generate(), + HashSet::new(), + BandwidthSink::new(), + 8usize, + ); + let _handle = manager.transport_handle(Arc::new(DefaultExecutor {})); + manager.register_transport(SupportedTransport::Tcp, Box::new(DummyTransport::new())); + + let peer = PeerId::random(); + let dial_address = Multiaddr::empty() + .with(Protocol::Ip4(Ipv4Addr::new(127, 0, 0, 1))) + .with(Protocol::Tcp(8888)) + .with(Protocol::P2p( + Multihash::from_bytes(&peer.to_bytes()).unwrap(), + )); + let connect_address = Multiaddr::empty() + .with(Protocol::Ip4(Ipv4Addr::new(192, 168, 1, 173))) + .with(Protocol::Tcp(8888)) + .with(Protocol::P2p( + Multihash::from_bytes(&peer.to_bytes()).unwrap(), + )); + assert!(manager.dial_address(dial_address.clone()).await.is_ok()); + assert_eq!(manager.pending_connections.len(), 1); + + match &manager.peers.read().get(&peer).unwrap().state { + PeerState::Dialing { record } => { + assert_eq!(record.address(), &dial_address); + } + state => panic!("invalid state for peer: {state:?}"), + } + + // remote peer connected to local node from a different address that was dialed + manager + .on_connection_established( + peer, + &Endpoint::listener(connect_address, ConnectionId::from(1usize)), + ) + .unwrap(); + + // connection to remote was closed while the dial was still in progress + manager.on_connection_closed(peer, ConnectionId::from(1usize)).unwrap(); + + // verify that the peer state is `Disconnected` + { + let peers = manager.peers.read(); + let peer = peers.get(&peer).unwrap(); + + match &peer.state { + PeerState::Disconnected { + dial_record: Some(dial_record), + .. + } => { + assert_eq!(dial_record.address(), &dial_address); + } + state => panic!("invalid state: {state:?}"), + } + } + + // the original dial succeeded + manager + .on_connection_established( + peer, + &Endpoint::dialer(dial_address, ConnectionId::from(0usize)), + ) + .unwrap(); + + let peers = manager.peers.read(); + let peer = peers.get(&peer).unwrap(); + + match &peer.state { + PeerState::Connected { + dial_record: None, .. + } => {} + state => panic!("invalid state: {state:?}"), + } + } + + #[tokio::test] + async fn secondary_connection_is_tracked() { + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (mut manager, _handle) = TransportManager::new( + Keypair::generate(), + HashSet::new(), + BandwidthSink::new(), + 8usize, + ); + manager.register_transport(SupportedTransport::Tcp, Box::new(DummyTransport::new())); + + let peer = PeerId::random(); + let address1 = Multiaddr::empty() + .with(Protocol::Ip4(Ipv4Addr::new(127, 0, 0, 1))) + .with(Protocol::Tcp(8888)) + .with(Protocol::P2p( + Multihash::from_bytes(&peer.to_bytes()).unwrap(), + )); + let address2 = Multiaddr::empty() + .with(Protocol::Ip4(Ipv4Addr::new(192, 168, 1, 173))) + .with(Protocol::Tcp(8888)) + .with(Protocol::P2p( + Multihash::from_bytes(&peer.to_bytes()).unwrap(), + )); + let address3 = Multiaddr::empty() + .with(Protocol::Ip4(Ipv4Addr::new(192, 168, 10, 64))) + .with(Protocol::Tcp(9999)) + .with(Protocol::P2p( + Multihash::from_bytes(&peer.to_bytes()).unwrap(), + )); + + // remote peer connected to local node + manager + .on_connection_established( + peer, + &Endpoint::listener(address1, ConnectionId::from(0usize)), + ) + .unwrap(); + + // verify that the peer state is `Connected` with no seconary connection + { + let peers = manager.peers.read(); + let peer = peers.get(&peer).unwrap(); + + match &peer.state { + PeerState::Connected { + dial_record: None, .. + } => { + assert!(peer.secondary_connection.is_none()); + } + state => panic!("invalid state: {state:?}"), + } + } + + // second connection is established, verify that the seconary connection is tracked + manager + .on_connection_established( + peer, + &Endpoint::listener(address2.clone(), ConnectionId::from(1usize)), + ) + .unwrap(); + + let peers = manager.peers.read(); + let context = peers.get(&peer).unwrap(); + + match &context.state { + PeerState::Connected { + dial_record: None, .. + } => { + let seconary_connection = context.secondary_connection.as_ref().unwrap(); + assert_eq!(seconary_connection.address(), &address2); + assert_eq!(seconary_connection.score(), SCORE_DIAL_SUCCESS); + } + state => panic!("invalid state: {state:?}"), + } + drop(peers); + + // tertiary connection is ignored + manager + .on_connection_established( + peer, + &Endpoint::listener(address3.clone(), ConnectionId::from(2usize)), + ) + .unwrap(); + + let peers = manager.peers.read(); + let peer = peers.get(&peer).unwrap(); + + match &peer.state { + PeerState::Connected { + dial_record: None, .. + } => { + let seconary_connection = peer.secondary_connection.as_ref().unwrap(); + assert_eq!(seconary_connection.address(), &address2); + assert_eq!(seconary_connection.score(), SCORE_DIAL_SUCCESS); + assert!(peer.addresses.contains(&address3)); + } + state => panic!("invalid state: {state:?}"), + } + } + + #[tokio::test] + async fn secondary_connection_closed() { + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (mut manager, _handle) = TransportManager::new( + Keypair::generate(), + HashSet::new(), + BandwidthSink::new(), + 8usize, + ); + manager.register_transport(SupportedTransport::Tcp, Box::new(DummyTransport::new())); + + let peer = PeerId::random(); + let address1 = Multiaddr::empty() + .with(Protocol::Ip4(Ipv4Addr::new(127, 0, 0, 1))) + .with(Protocol::Tcp(8888)) + .with(Protocol::P2p( + Multihash::from_bytes(&peer.to_bytes()).unwrap(), + )); + let address2 = Multiaddr::empty() + .with(Protocol::Ip4(Ipv4Addr::new(192, 168, 1, 173))) + .with(Protocol::Tcp(8888)) + .with(Protocol::P2p( + Multihash::from_bytes(&peer.to_bytes()).unwrap(), + )); + + // remote peer connected to local node + let emit_event = manager + .on_connection_established( + peer, + &Endpoint::listener(address1, ConnectionId::from(0usize)), + ) + .unwrap(); + assert!(std::matches!( + emit_event, + ConnectionEstablishedResult::Accept + )); + + // verify that the peer state is `Connected` with no seconary connection + { + let peers = manager.peers.read(); + let peer = peers.get(&peer).unwrap(); + + match &peer.state { + PeerState::Connected { + dial_record: None, .. + } => { + assert!(peer.secondary_connection.is_none()); + } + state => panic!("invalid state: {state:?}"), + } + } + + // second connection is established, verify that the seconary connection is tracked + let emit_event = manager + .on_connection_established( + peer, + &Endpoint::dialer(address2.clone(), ConnectionId::from(1usize)), + ) + .unwrap(); + assert!(std::matches!( + emit_event, + ConnectionEstablishedResult::Accept + )); + + let peers = manager.peers.read(); + let context = peers.get(&peer).unwrap(); + + match &context.state { + PeerState::Connected { + dial_record: None, .. + } => { + let seconary_connection = context.secondary_connection.as_ref().unwrap(); + assert_eq!(seconary_connection.address(), &address2); + assert_eq!(seconary_connection.score(), SCORE_DIAL_SUCCESS); + } + state => panic!("invalid state: {state:?}"), + } + drop(peers); + + // close the secondary connection and verify that the peer remains connected + let emit_event = manager.on_connection_closed(peer, ConnectionId::from(1usize)).unwrap(); + assert!(emit_event.is_none()); + + let peers = manager.peers.read(); + let context = peers.get(&peer).unwrap(); + + match &context.state { + PeerState::Connected { + dial_record: None, + record, + } => { + assert!(context.secondary_connection.is_none()); + assert!(context.addresses.contains(&address2)); + assert_eq!(record.connection_id(), &Some(ConnectionId::from(0usize))); + } + state => panic!("invalid state: {state:?}"), + } + } + + #[tokio::test] + async fn switch_to_secondary_connection() { + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (mut manager, _handle) = TransportManager::new( + Keypair::generate(), + HashSet::new(), + BandwidthSink::new(), + 8usize, + ); + manager.register_transport(SupportedTransport::Tcp, Box::new(DummyTransport::new())); + + let peer = PeerId::random(); + let address1 = Multiaddr::empty() + .with(Protocol::Ip4(Ipv4Addr::new(127, 0, 0, 1))) + .with(Protocol::Tcp(8888)) + .with(Protocol::P2p( + Multihash::from_bytes(&peer.to_bytes()).unwrap(), + )); + let address2 = Multiaddr::empty() + .with(Protocol::Ip4(Ipv4Addr::new(192, 168, 1, 173))) + .with(Protocol::Tcp(8888)) + .with(Protocol::P2p( + Multihash::from_bytes(&peer.to_bytes()).unwrap(), + )); + + // remote peer connected to local node + let emit_event = manager + .on_connection_established( + peer, + &Endpoint::listener(address1.clone(), ConnectionId::from(0usize)), + ) + .unwrap(); + assert!(std::matches!( + emit_event, + ConnectionEstablishedResult::Accept + )); + + // verify that the peer state is `Connected` with no seconary connection + { + let peers = manager.peers.read(); + let peer = peers.get(&peer).unwrap(); + + match &peer.state { + PeerState::Connected { + dial_record: None, .. + } => { + assert!(peer.secondary_connection.is_none()); + } + state => panic!("invalid state: {state:?}"), + } + } + + // second connection is established, verify that the seconary connection is tracked + let emit_event = manager + .on_connection_established( + peer, + &Endpoint::dialer(address2.clone(), ConnectionId::from(1usize)), + ) + .unwrap(); + assert!(std::matches!( + emit_event, + ConnectionEstablishedResult::Accept + )); + + let peers = manager.peers.read(); + let context = peers.get(&peer).unwrap(); + + match &context.state { + PeerState::Connected { + dial_record: None, .. + } => { + let seconary_connection = context.secondary_connection.as_ref().unwrap(); + assert_eq!(seconary_connection.address(), &address2); + assert_eq!(seconary_connection.score(), SCORE_DIAL_SUCCESS); + } + state => panic!("invalid state: {state:?}"), + } + drop(peers); + + // close the primary connection and verify that the peer remains connected + // while the primary connection address is stored in peer addresses + let emit_event = manager.on_connection_closed(peer, ConnectionId::from(0usize)).unwrap(); + assert!(emit_event.is_none()); + + let peers = manager.peers.read(); + let context = peers.get(&peer).unwrap(); + + match &context.state { + PeerState::Connected { + dial_record: None, + record, + } => { + assert!(context.secondary_connection.is_none()); + assert!(context.addresses.contains(&address1)); + assert_eq!(record.connection_id(), &Some(ConnectionId::from(1usize))); + } + state => panic!("invalid state: {state:?}"), + } + } + + // two connections already exist and a third was opened which is ignored by + // `on_connection_established()`, when that connection is closed, verify that + // it's handled gracefully + #[tokio::test] + async fn tertiary_connection_closed() { + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (mut manager, _handle) = TransportManager::new( + Keypair::generate(), + HashSet::new(), + BandwidthSink::new(), + 8usize, + ); + manager.register_transport(SupportedTransport::Tcp, Box::new(DummyTransport::new())); + + let peer = PeerId::random(); + let address1 = Multiaddr::empty() + .with(Protocol::Ip4(Ipv4Addr::new(127, 0, 0, 1))) + .with(Protocol::Tcp(8888)) + .with(Protocol::P2p( + Multihash::from_bytes(&peer.to_bytes()).unwrap(), + )); + let address2 = Multiaddr::empty() + .with(Protocol::Ip4(Ipv4Addr::new(192, 168, 1, 173))) + .with(Protocol::Tcp(8888)) + .with(Protocol::P2p( + Multihash::from_bytes(&peer.to_bytes()).unwrap(), + )); + let address3 = Multiaddr::empty() + .with(Protocol::Ip4(Ipv4Addr::new(192, 168, 1, 173))) + .with(Protocol::Tcp(9999)) + .with(Protocol::P2p( + Multihash::from_bytes(&peer.to_bytes()).unwrap(), + )); + + // remote peer connected to local node + let emit_event = manager + .on_connection_established( + peer, + &Endpoint::listener(address1, ConnectionId::from(0usize)), + ) + .unwrap(); + assert!(std::matches!( + emit_event, + ConnectionEstablishedResult::Accept + )); + + // verify that the peer state is `Connected` with no seconary connection + { + let peers = manager.peers.read(); + let peer = peers.get(&peer).unwrap(); + + match &peer.state { + PeerState::Connected { + dial_record: None, .. + } => { + assert!(peer.secondary_connection.is_none()); + } + state => panic!("invalid state: {state:?}"), + } + } + + // second connection is established, verify that the seconary connection is tracked + let emit_event = manager + .on_connection_established( + peer, + &Endpoint::dialer(address2.clone(), ConnectionId::from(1usize)), + ) + .unwrap(); + assert!(std::matches!( + emit_event, + ConnectionEstablishedResult::Accept + )); + + let peers = manager.peers.read(); + let context = peers.get(&peer).unwrap(); + + match &context.state { + PeerState::Connected { + dial_record: None, .. + } => { + let seconary_connection = context.secondary_connection.as_ref().unwrap(); + assert_eq!(seconary_connection.address(), &address2); + assert_eq!(seconary_connection.score(), SCORE_DIAL_SUCCESS); + } + state => panic!("invalid state: {state:?}"), + } + drop(peers); + + // third connection is established, verify that it's discarded + let emit_event = manager + .on_connection_established( + peer, + &Endpoint::listener(address3.clone(), ConnectionId::from(2usize)), + ) + .unwrap(); + assert!(std::matches!( + emit_event, + ConnectionEstablishedResult::Reject + )); + + let peers = manager.peers.read(); + let context = peers.get(&peer).unwrap(); + assert!(context.addresses.contains(&address3)); + drop(peers); + + // close the tertiary connection that was ignored + let emit_event = manager.on_connection_closed(peer, ConnectionId::from(2usize)).unwrap(); + assert!(emit_event.is_none()); + + // verify that the state remains unchanged + let peers = manager.peers.read(); + let context = peers.get(&peer).unwrap(); + + match &context.state { + PeerState::Connected { + dial_record: None, .. + } => { + let seconary_connection = context.secondary_connection.as_ref().unwrap(); + assert_eq!(seconary_connection.address(), &address2); + assert_eq!(seconary_connection.score(), SCORE_DIAL_SUCCESS); + } + state => panic!("invalid state: {state:?}"), + } + drop(peers); + } + + #[tokio::test] + #[cfg(debug_assertions)] + #[should_panic] + async fn dial_failure_for_unknow_connection() { + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (mut manager, _handle) = TransportManager::new( + Keypair::generate(), + HashSet::new(), + BandwidthSink::new(), + 8usize, + ); + + manager.on_dial_failure(ConnectionId::random()).unwrap(); + } + + #[tokio::test] + #[cfg(debug_assertions)] + #[should_panic] + async fn dial_failure_for_unknow_peer() { + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (mut manager, _handle) = TransportManager::new( + Keypair::generate(), + HashSet::new(), + BandwidthSink::new(), + 8usize, + ); + let connection_id = ConnectionId::random(); + let peer = PeerId::random(); + manager.pending_connections.insert(connection_id, peer); + manager.on_dial_failure(connection_id).unwrap(); + } + + #[tokio::test] + #[cfg(debug_assertions)] + #[should_panic] + async fn connection_closed_for_unknown_peer() { + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (mut manager, _handle) = TransportManager::new( + Keypair::generate(), + HashSet::new(), + BandwidthSink::new(), + 8usize, + ); + manager.on_connection_closed(PeerId::random(), ConnectionId::random()).unwrap(); + } + + #[tokio::test] + #[cfg(debug_assertions)] + #[should_panic] + async fn unknown_connection_opened() { + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (mut manager, _handle) = TransportManager::new( + Keypair::generate(), + HashSet::new(), + BandwidthSink::new(), + 8usize, + ); + manager + .on_connection_opened( + SupportedTransport::Tcp, + ConnectionId::random(), + Multiaddr::empty(), + ) + .unwrap(); + } + + #[tokio::test] + #[cfg(debug_assertions)] + #[should_panic] + async fn connection_opened_for_unknown_peer() { + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (mut manager, _handle) = TransportManager::new( + Keypair::generate(), + HashSet::new(), + BandwidthSink::new(), + 8usize, + ); + let connection_id = ConnectionId::random(); + let peer = PeerId::random(); + + manager.pending_connections.insert(connection_id, peer); + manager + .on_connection_opened(SupportedTransport::Tcp, connection_id, Multiaddr::empty()) + .unwrap(); + } + + #[tokio::test] + #[cfg(debug_assertions)] + #[should_panic] + async fn connection_established_for_wrong_peer() { + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (mut manager, _handle) = TransportManager::new( + Keypair::generate(), + HashSet::new(), + BandwidthSink::new(), + 8usize, + ); + let connection_id = ConnectionId::random(); + let peer = PeerId::random(); + + manager.pending_connections.insert(connection_id, peer); + manager + .on_connection_established( + PeerId::random(), + &Endpoint::dialer(Multiaddr::empty(), connection_id), + ) + .unwrap(); + } + + #[tokio::test] + #[cfg(debug_assertions)] + #[should_panic] + async fn open_failure_unknown_connection() { + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (mut manager, _handle) = TransportManager::new( + Keypair::generate(), + HashSet::new(), + BandwidthSink::new(), + 8usize, + ); + + manager + .on_open_failure(SupportedTransport::Tcp, ConnectionId::random()) + .unwrap(); + } + + #[tokio::test] + #[cfg(debug_assertions)] + #[should_panic] + async fn open_failure_unknown_peer() { + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (mut manager, _handle) = TransportManager::new( + Keypair::generate(), + HashSet::new(), + BandwidthSink::new(), + 8usize, + ); + let connection_id = ConnectionId::random(); + let peer = PeerId::random(); + + manager.pending_connections.insert(connection_id, peer); + manager.on_open_failure(SupportedTransport::Tcp, connection_id).unwrap(); + } + + #[tokio::test] + async fn no_transports() { + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (mut manager, _handle) = TransportManager::new( + Keypair::generate(), + HashSet::new(), + BandwidthSink::new(), + 8usize, + ); + + assert!(manager.next().await.is_none()); + } + + #[tokio::test] + async fn dial_already_connected_peer() { + let (mut manager, _handle) = TransportManager::new( + Keypair::generate(), + HashSet::new(), + BandwidthSink::new(), + 8usize, + ); + + let peer = { + let peer = PeerId::random(); + let mut peers = manager.peers.write(); + + peers.insert( + peer, + PeerContext { + state: PeerState::Connected { + record: AddressRecord::from_multiaddr( + Multiaddr::empty() + .with(Protocol::Ip4(std::net::Ipv4Addr::new(127, 0, 0, 1))) + .with(Protocol::Tcp(8888)) + .with(Protocol::P2p(Multihash::from(peer))), + ) + .unwrap(), + dial_record: None, + }, + secondary_connection: None, + addresses: AddressStore::from_iter( + vec![Multiaddr::empty() + .with(Protocol::Ip4(std::net::Ipv4Addr::new(127, 0, 0, 1))) + .with(Protocol::Tcp(8888)) + .with(Protocol::P2p(Multihash::from(peer)))] + .into_iter(), + ), + }, + ); + drop(peers); + + peer + }; + + match manager.dial(peer).await { + Err(Error::AlreadyConnected) => {} + _ => panic!("invalid return value"), + } + } + + #[tokio::test] + async fn peer_already_being_dialed() { + let (mut manager, _handle) = TransportManager::new( + Keypair::generate(), + HashSet::new(), + BandwidthSink::new(), + 8usize, + ); + + let peer = { + let peer = PeerId::random(); + let mut peers = manager.peers.write(); + + peers.insert( + peer, + PeerContext { + state: PeerState::Dialing { + record: AddressRecord::from_multiaddr( + Multiaddr::empty() + .with(Protocol::Ip4(std::net::Ipv4Addr::new(127, 0, 0, 1))) + .with(Protocol::Tcp(8888)) + .with(Protocol::P2p(Multihash::from(peer))), + ) + .unwrap(), + }, + secondary_connection: None, + addresses: AddressStore::from_iter( + vec![Multiaddr::empty() + .with(Protocol::Ip4(std::net::Ipv4Addr::new(127, 0, 0, 1))) + .with(Protocol::Tcp(8888)) + .with(Protocol::P2p(Multihash::from(peer)))] + .into_iter(), + ), + }, + ); + drop(peers); + + peer + }; + + manager.dial(peer).await.unwrap(); + } + + #[tokio::test] + async fn pending_connection_for_disconnected_peer() { + let (mut manager, _handle) = TransportManager::new( + Keypair::generate(), + HashSet::new(), + BandwidthSink::new(), + 8usize, + ); + + let peer = { + let peer = PeerId::random(); + let mut peers = manager.peers.write(); + + peers.insert( + peer, + PeerContext { + state: PeerState::Disconnected { + dial_record: Some( + AddressRecord::from_multiaddr( + Multiaddr::empty() + .with(Protocol::Ip4(std::net::Ipv4Addr::new(127, 0, 0, 1))) + .with(Protocol::Tcp(8888)) + .with(Protocol::P2p(Multihash::from(peer))), + ) + .unwrap(), + ), + }, + secondary_connection: None, + addresses: AddressStore::new(), + }, + ); + drop(peers); + + peer + }; + + manager.dial(peer).await.unwrap(); + } + + #[tokio::test] + async fn dial_address_invalid_transport() { + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (mut manager, _handle) = TransportManager::new( + Keypair::generate(), + HashSet::new(), + BandwidthSink::new(), + 8usize, + ); + + // transport doesn't start with ip/dns + { + let address = Multiaddr::empty().with(Protocol::P2p(Multihash::from(PeerId::random()))); + match manager.dial_address(address.clone()).await { + Err(Error::TransportNotSupported(dial_address)) => { + assert_eq!(dial_address, address); + } + _ => panic!("invalid return value"), + } + } + + { + // upd-based protocol but not quic + let address = Multiaddr::empty() + .with(Protocol::Ip4(std::net::Ipv4Addr::new(127, 0, 0, 1))) + .with(Protocol::Udp(8888)) + .with(Protocol::Utp) + .with(Protocol::P2p(Multihash::from(PeerId::random()))); + match manager.dial_address(address.clone()).await { + Err(Error::TransportNotSupported(dial_address)) => { + assert_eq!(dial_address, address); + } + res => panic!("invalid return value: {res:?}"), + } + } + + // not tcp nor udp + { + let address = Multiaddr::empty() + .with(Protocol::Ip4(std::net::Ipv4Addr::new(127, 0, 0, 1))) + .with(Protocol::Sctp(8888)) + .with(Protocol::P2p(Multihash::from(PeerId::random()))); + match manager.dial_address(address.clone()).await { + Err(Error::TransportNotSupported(dial_address)) => { + assert_eq!(dial_address, address); + } + _ => panic!("invalid return value"), + } + } + + // random protocol after tcp + { + let address = Multiaddr::empty() + .with(Protocol::Ip4(std::net::Ipv4Addr::new(127, 0, 0, 1))) + .with(Protocol::Tcp(8888)) + .with(Protocol::Utp) + .with(Protocol::P2p(Multihash::from(PeerId::random()))); + match manager.dial_address(address.clone()).await { + Err(Error::TransportNotSupported(dial_address)) => { + assert_eq!(dial_address, address); + } + _ => panic!("invalid return value"), + } + } + } + + #[tokio::test] + async fn dial_address_peer_id_missing() { + let (mut manager, _handle) = TransportManager::new( + Keypair::generate(), + HashSet::new(), + BandwidthSink::new(), + 8usize, + ); + + async fn call_manager(manager: &mut TransportManager, address: Multiaddr) { + match manager.dial_address(address).await { + Err(Error::AddressError(AddressError::PeerIdMissing)) => {} + _ => panic!("invalid return value"), + } + } + + { + call_manager( + &mut manager, + Multiaddr::empty() + .with(Protocol::Ip4(std::net::Ipv4Addr::new(127, 0, 0, 1))) + .with(Protocol::Tcp(8888)), + ) + .await; + } + + { + call_manager( + &mut manager, + Multiaddr::empty() + .with(Protocol::Ip4(std::net::Ipv4Addr::new(127, 0, 0, 1))) + .with(Protocol::Tcp(8888)) + .with(Protocol::Wss(std::borrow::Cow::Owned("".to_string()))), + ) + .await; + } + + { + call_manager( + &mut manager, + Multiaddr::empty() + .with(Protocol::Ip4(std::net::Ipv4Addr::new(127, 0, 0, 1))) + .with(Protocol::Udp(8888)) + .with(Protocol::QuicV1), + ) + .await; + } + } + + #[tokio::test] + async fn inbound_connection_while_dialing() { + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (mut manager, _handle) = TransportManager::new( + Keypair::generate(), + HashSet::new(), + BandwidthSink::new(), + 8usize, + ); + let peer = PeerId::random(); + let dial_address = Multiaddr::empty() + .with(Protocol::Ip4(Ipv4Addr::new(127, 0, 0, 1))) + .with(Protocol::Tcp(8888)) + .with(Protocol::P2p( + Multihash::from_bytes(&peer.to_bytes()).unwrap(), + )); + + let connection_id = ConnectionId::random(); + let transport = Box::new({ + let mut transport = DummyTransport::new(); + transport.inject_event(TransportEvent::ConnectionEstablished { + peer, + endpoint: Endpoint::listener(dial_address.clone(), connection_id), + }); + transport + }); + manager.register_transport(SupportedTransport::Tcp, transport); + manager.add_known_address( + peer, + vec![Multiaddr::empty() + .with(Protocol::Ip4(Ipv4Addr::new(192, 168, 1, 5))) + .with(Protocol::Tcp(8888)) + .with(Protocol::P2p(Multihash::from(peer)))] + .into_iter(), + ); + + assert!(manager.dial(peer).await.is_ok()); + assert!(!manager.pending_connections.is_empty()); + + { + let peers = manager.peers.read(); + + match peers.get(&peer) { + Some(PeerContext { + state: PeerState::Opening { .. }, + .. + }) => {} + state => panic!("invalid state for peer: {state:?}"), + } + } + + match manager.next().await.unwrap() { + TransportEvent::ConnectionEstablished { + peer: event_peer, + endpoint: event_endpoint, + .. + } => { + assert_eq!(peer, event_peer); + assert_eq!( + event_endpoint, + Endpoint::listener(dial_address.clone(), connection_id), + ); + } + event => panic!("invalid event: {event:?}"), + } + assert!(manager.pending_connections.is_empty()); + + let peers = manager.peers.read(); + match peers.get(&peer).unwrap() { + PeerContext { + state: + PeerState::Connected { + record, + dial_record, + }, + secondary_connection, + addresses, + } => { + assert!(!addresses.contains(record.address())); + assert!(dial_record.is_none()); + assert!(secondary_connection.is_none()); + assert_eq!(record.address(), &dial_address); + assert_eq!(record.connection_id(), &Some(connection_id)); + } + state => panic!("invalid peer state: {state:?}"), + } + } + + #[tokio::test] + async fn inbound_connection_for_same_address_while_dialing() { + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (mut manager, _handle) = TransportManager::new( + Keypair::generate(), + HashSet::new(), + BandwidthSink::new(), + 8usize, + ); + let peer = PeerId::random(); + let dial_address = Multiaddr::empty() + .with(Protocol::Ip4(Ipv4Addr::new(127, 0, 0, 1))) + .with(Protocol::Tcp(8888)) + .with(Protocol::P2p( + Multihash::from_bytes(&peer.to_bytes()).unwrap(), + )); + + let connection_id = ConnectionId::random(); + let transport = Box::new({ + let mut transport = DummyTransport::new(); + transport.inject_event(TransportEvent::ConnectionEstablished { + peer, + endpoint: Endpoint::listener(dial_address.clone(), connection_id), + }); + transport + }); + manager.register_transport(SupportedTransport::Tcp, transport); + manager.add_known_address( + peer, + vec![Multiaddr::empty() + .with(Protocol::Ip4(Ipv4Addr::new(127, 0, 0, 1))) + .with(Protocol::Tcp(8888)) + .with(Protocol::P2p(Multihash::from(peer)))] + .into_iter(), + ); + + assert!(manager.dial(peer).await.is_ok()); + assert!(!manager.pending_connections.is_empty()); + + { + let peers = manager.peers.read(); + + match peers.get(&peer) { + Some(PeerContext { + state: PeerState::Opening { .. }, + .. + }) => {} + state => panic!("invalid state for peer: {state:?}"), + } + } + + match manager.next().await.unwrap() { + TransportEvent::ConnectionEstablished { + peer: event_peer, + endpoint: event_endpoint, + .. + } => { + assert_eq!(peer, event_peer); + assert_eq!( + event_endpoint, + Endpoint::listener(dial_address.clone(), connection_id), + ); + } + event => panic!("invalid event: {event:?}"), + } + assert!(manager.pending_connections.is_empty()); + + let peers = manager.peers.read(); + match peers.get(&peer).unwrap() { + PeerContext { + state: + PeerState::Connected { + record, + dial_record, + }, + secondary_connection, + addresses, + } => { + assert!(addresses.is_empty()); + assert!(dial_record.is_none()); + assert!(secondary_connection.is_none()); + assert_eq!(record.address(), &dial_address); + assert_eq!(record.connection_id(), &Some(connection_id)); + } + state => panic!("invalid peer state: {state:?}"), + } + } } diff --git a/src/transport/manager/types.rs b/src/transport/manager/types.rs index 887dc5fb..8bcdcdb9 100644 --- a/src/transport/manager/types.rs +++ b/src/transport/manager/types.rs @@ -19,8 +19,8 @@ // DEALINGS IN THE SOFTWARE. use crate::{ - transport::manager::address::{AddressRecord, AddressStore}, - types::ConnectionId, + transport::manager::address::{AddressRecord, AddressStore}, + types::ConnectionId, }; use multiaddr::Multiaddr; @@ -30,76 +30,76 @@ use std::collections::{HashMap, HashSet}; /// Supported protocols. #[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] pub enum SupportedTransport { - /// TCP. - Tcp, + /// TCP. + Tcp, - /// QUIC. - Quic, + /// QUIC. + Quic, - /// WebRTC - WebRtc, + /// WebRTC + WebRtc, - /// WebSocket - WebSocket, + /// WebSocket + WebSocket, } /// Peer state. #[derive(Debug)] pub enum PeerState { - /// `Litep2p` is connected to peer. - Connected { - /// Address record. - record: AddressRecord, - - /// Dial address, if it exists. - /// - /// While the local node was dialing a remote peer, the remote peer might've dialed - /// the local node and connection was established successfully. This dial address - /// is stored for processing later when the dial attempt conclused as either - /// successful/failed. - dial_record: Option, - }, - - /// Connection to peer is opening over one or more addresses. - Opening { - /// Address records used for dialing. - records: HashMap, - - /// Connection ID. - connection_id: ConnectionId, - - /// Active transports. - transports: HashSet, - }, - - /// Peer is being dialed. - Dialing { - /// Address record. - record: AddressRecord, - }, - - /// `Litep2p` is not connected to peer. - Disconnected { - /// Dial address, if it exists. - /// - /// While the local node was dialing a remote peer, the remote peer might've dialed - /// the local node and connection was established successfully. The connection might've - /// been closed before the dial concluded which means that - /// [`crate::transport::manager::TransportManager`] must be prepared to handle the dial - /// failure even after the connection has been closed. - dial_record: Option, - }, + /// `Litep2p` is connected to peer. + Connected { + /// Address record. + record: AddressRecord, + + /// Dial address, if it exists. + /// + /// While the local node was dialing a remote peer, the remote peer might've dialed + /// the local node and connection was established successfully. This dial address + /// is stored for processing later when the dial attempt conclused as either + /// successful/failed. + dial_record: Option, + }, + + /// Connection to peer is opening over one or more addresses. + Opening { + /// Address records used for dialing. + records: HashMap, + + /// Connection ID. + connection_id: ConnectionId, + + /// Active transports. + transports: HashSet, + }, + + /// Peer is being dialed. + Dialing { + /// Address record. + record: AddressRecord, + }, + + /// `Litep2p` is not connected to peer. + Disconnected { + /// Dial address, if it exists. + /// + /// While the local node was dialing a remote peer, the remote peer might've dialed + /// the local node and connection was established successfully. The connection might've + /// been closed before the dial concluded which means that + /// [`crate::transport::manager::TransportManager`] must be prepared to handle the dial + /// failure even after the connection has been closed. + dial_record: Option, + }, } /// Peer context. #[derive(Debug)] pub struct PeerContext { - /// Peer state. - pub state: PeerState, + /// Peer state. + pub state: PeerState, - /// Seconary connection, if it's open. - pub secondary_connection: Option, + /// Seconary connection, if it's open. + pub secondary_connection: Option, - /// Known addresses of peer. - pub addresses: AddressStore, + /// Known addresses of peer. + pub addresses: AddressStore, } diff --git a/src/transport/mod.rs b/src/transport/mod.rs index 860b936a..16f53c63 100644 --- a/src/transport/mod.rs +++ b/src/transport/mod.rs @@ -47,141 +47,147 @@ pub(crate) const MAX_PARALLEL_DIALS: usize = 8; /// Connection endpoint. #[derive(Debug, Clone, PartialEq, Eq)] pub enum Endpoint { - /// Successfully established outbound connection. - Dialer { - /// Address that was dialed. - address: Multiaddr, - - /// Connection ID. - connection_id: ConnectionId, - }, - - /// Successfully established inbound connection. - Listener { - /// Local connection address. - address: Multiaddr, - - /// Connection ID. - connection_id: ConnectionId, - }, + /// Successfully established outbound connection. + Dialer { + /// Address that was dialed. + address: Multiaddr, + + /// Connection ID. + connection_id: ConnectionId, + }, + + /// Successfully established inbound connection. + Listener { + /// Local connection address. + address: Multiaddr, + + /// Connection ID. + connection_id: ConnectionId, + }, } impl Endpoint { - /// Get `Multiaddr` of the [`Endpoint`]. - pub fn address(&self) -> &Multiaddr { - match self { - Self::Dialer { address, .. } => &address, - Self::Listener { address, .. } => &address, - } - } - - /// Crate dialer. - pub(crate) fn dialer(address: Multiaddr, connection_id: ConnectionId) -> Self { - Endpoint::Dialer { address, connection_id } - } - - /// Create listener. - pub(crate) fn listener(address: Multiaddr, connection_id: ConnectionId) -> Self { - Endpoint::Listener { address, connection_id } - } - - /// Get `ConnectionId` of the `Endpoint`. - pub fn connection_id(&self) -> ConnectionId { - match self { - Self::Dialer { connection_id, .. } => *connection_id, - Self::Listener { connection_id, .. } => *connection_id, - } - } - - /// Is this a listener endpoint? - pub fn is_listener(&self) -> bool { - return std::matches!(self, Self::Listener { .. }); - } + /// Get `Multiaddr` of the [`Endpoint`]. + pub fn address(&self) -> &Multiaddr { + match self { + Self::Dialer { address, .. } => &address, + Self::Listener { address, .. } => &address, + } + } + + /// Crate dialer. + pub(crate) fn dialer(address: Multiaddr, connection_id: ConnectionId) -> Self { + Endpoint::Dialer { + address, + connection_id, + } + } + + /// Create listener. + pub(crate) fn listener(address: Multiaddr, connection_id: ConnectionId) -> Self { + Endpoint::Listener { + address, + connection_id, + } + } + + /// Get `ConnectionId` of the `Endpoint`. + pub fn connection_id(&self) -> ConnectionId { + match self { + Self::Dialer { connection_id, .. } => *connection_id, + Self::Listener { connection_id, .. } => *connection_id, + } + } + + /// Is this a listener endpoint? + pub fn is_listener(&self) -> bool { + return std::matches!(self, Self::Listener { .. }); + } } /// Transport event. #[derive(Debug)] pub(crate) enum TransportEvent { - /// Fully negotiated connection established to remote peer. - ConnectionEstablished { - /// Peer ID. - peer: PeerId, - - /// Endpoint. - endpoint: Endpoint, - }, - - /// Connection opened to remote but not yet negotiated. - ConnectionOpened { - /// Connection ID. - connection_id: ConnectionId, - - /// Address that was dialed. - address: Multiaddr, - }, - - /// Connection closed to remote peer. - #[allow(unused)] - ConnectionClosed { - /// Peer ID. - peer: PeerId, - - /// Connection ID. - connection_id: ConnectionId, - }, - - /// Failed to dial remote peer. - DialFailure { - /// Connection ID. - connection_id: ConnectionId, - - /// Dialed address. - address: Multiaddr, - - /// Error. - error: Error, - }, - - /// Open failure for an unnegotiated set of connections. - OpenFailure { - /// Connection ID. - connection_id: ConnectionId, - }, + /// Fully negotiated connection established to remote peer. + ConnectionEstablished { + /// Peer ID. + peer: PeerId, + + /// Endpoint. + endpoint: Endpoint, + }, + + /// Connection opened to remote but not yet negotiated. + ConnectionOpened { + /// Connection ID. + connection_id: ConnectionId, + + /// Address that was dialed. + address: Multiaddr, + }, + + /// Connection closed to remote peer. + #[allow(unused)] + ConnectionClosed { + /// Peer ID. + peer: PeerId, + + /// Connection ID. + connection_id: ConnectionId, + }, + + /// Failed to dial remote peer. + DialFailure { + /// Connection ID. + connection_id: ConnectionId, + + /// Dialed address. + address: Multiaddr, + + /// Error. + error: Error, + }, + + /// Open failure for an unnegotiated set of connections. + OpenFailure { + /// Connection ID. + connection_id: ConnectionId, + }, } pub(crate) trait TransportBuilder { - type Config: Debug; - type Transport: Transport; + type Config: Debug; + type Transport: Transport; - /// Create new [`Transport`] object. - fn new(context: TransportHandle, config: Self::Config) -> crate::Result<(Self, Vec)> - where - Self: Sized; + /// Create new [`Transport`] object. + fn new(context: TransportHandle, config: Self::Config) -> crate::Result<(Self, Vec)> + where + Self: Sized; } pub(crate) trait Transport: Stream + Unpin + Send { - /// Dial `address` and negotiate connection. - fn dial(&mut self, connection_id: ConnectionId, address: Multiaddr) -> crate::Result<()>; - - /// Accept negotiated connection. - fn accept(&mut self, connection_id: ConnectionId) -> crate::Result<()>; - - /// Reject negotiated connection. - fn reject(&mut self, connection_id: ConnectionId) -> crate::Result<()>; - - /// Attempt to open connection to remote peer over one or more addresses. - /// - /// TODO: documentation - fn open(&mut self, connection_id: ConnectionId, addresses: Vec) - -> crate::Result<()>; - - /// Negotiate opened connection. - /// - /// TODO: documentation - fn negotiate(&mut self, connection_id: ConnectionId) -> crate::Result<()>; - - /// Cancel opening connections. - /// - /// This is a no-op for connections that have already succeeded/canceled. - fn cancel(&mut self, connection_id: ConnectionId); + /// Dial `address` and negotiate connection. + fn dial(&mut self, connection_id: ConnectionId, address: Multiaddr) -> crate::Result<()>; + + /// Accept negotiated connection. + fn accept(&mut self, connection_id: ConnectionId) -> crate::Result<()>; + + /// Reject negotiated connection. + fn reject(&mut self, connection_id: ConnectionId) -> crate::Result<()>; + + /// Attempt to open connection to remote peer over one or more addresses. + /// + /// TODO: documentation + fn open(&mut self, connection_id: ConnectionId, addresses: Vec) + -> crate::Result<()>; + + /// Negotiate opened connection. + /// + /// TODO: documentation + fn negotiate(&mut self, connection_id: ConnectionId) -> crate::Result<()>; + + /// Cancel opening connections. + /// + /// This is a no-op for connections that have already succeeded/canceled. + fn cancel(&mut self, connection_id: ConnectionId); } diff --git a/src/transport/quic/config.rs b/src/transport/quic/config.rs index 98fe1dd7..8ed30fce 100644 --- a/src/transport/quic/config.rs +++ b/src/transport/quic/config.rs @@ -29,30 +29,30 @@ use std::time::Duration; /// QUIC transport configuration. #[derive(Debug)] pub struct Config { - /// Listen address for the transport. - /// - /// Default listen addres is `/ip4/127.0.0.1/udp/0/quic-v1`. - pub listen_addresses: Vec, - - /// Connection open timeout. - /// - /// How long should litep2p wait for a connection to be opend before the host - /// is deemed unreachable. - pub connection_open_timeout: Duration, - - /// Substream open timeout. - /// - /// How long should litep2p wait for a substream to be opened before considering - /// the substream rejected. - pub substream_open_timeout: Duration, + /// Listen address for the transport. + /// + /// Default listen addres is `/ip4/127.0.0.1/udp/0/quic-v1`. + pub listen_addresses: Vec, + + /// Connection open timeout. + /// + /// How long should litep2p wait for a connection to be opend before the host + /// is deemed unreachable. + pub connection_open_timeout: Duration, + + /// Substream open timeout. + /// + /// How long should litep2p wait for a substream to be opened before considering + /// the substream rejected. + pub substream_open_timeout: Duration, } impl Default for Config { - fn default() -> Self { - Self { - listen_addresses: vec!["/ip4/127.0.0.1/udp/0/quic-v1".parse().expect("valid address")], - connection_open_timeout: CONNECTION_OPEN_TIMEOUT, - substream_open_timeout: SUBSTREAM_OPEN_TIMEOUT, - } - } + fn default() -> Self { + Self { + listen_addresses: vec!["/ip4/127.0.0.1/udp/0/quic-v1".parse().expect("valid address")], + connection_open_timeout: CONNECTION_OPEN_TIMEOUT, + substream_open_timeout: SUBSTREAM_OPEN_TIMEOUT, + } + } } diff --git a/src/transport/quic/connection.rs b/src/transport/quic/connection.rs index 69def7a9..ec14d772 100644 --- a/src/transport/quic/connection.rs +++ b/src/transport/quic/connection.rs @@ -23,17 +23,17 @@ use std::time::Duration; use crate::{ - config::Role, - error::Error, - multistream_select::{dialer_select_proto, listener_select_proto, Negotiated, Version}, - protocol::{Direction, Permit, ProtocolCommand, ProtocolSet}, - substream, - transport::{ - quic::substream::{NegotiatingSubstream, Substream}, - Endpoint, - }, - types::{protocol::ProtocolName, SubstreamId}, - BandwidthSink, PeerId, + config::Role, + error::Error, + multistream_select::{dialer_select_proto, listener_select_proto, Negotiated, Version}, + protocol::{Direction, Permit, ProtocolCommand, ProtocolSet}, + substream, + transport::{ + quic::substream::{NegotiatingSubstream, Substream}, + Endpoint, + }, + types::{protocol::ProtocolName, SubstreamId}, + BandwidthSink, PeerId, }; use futures::{future::BoxFuture, stream::FuturesUnordered, AsyncRead, AsyncWrite, StreamExt}; @@ -45,340 +45,340 @@ const LOG_TARGET: &str = "litep2p::quic::connection"; /// QUIC connection error. #[derive(Debug)] enum ConnectionError { - /// Timeout - Timeout { - /// Protocol. - protocol: Option, - - /// Substream ID. - substream_id: Option, - }, - - /// Failed to negotiate connection/substream. - FailedToNegotiate { - /// Protocol. - protocol: Option, - - /// Substream ID. - substream_id: Option, - - /// Error. - error: Error, - }, + /// Timeout + Timeout { + /// Protocol. + protocol: Option, + + /// Substream ID. + substream_id: Option, + }, + + /// Failed to negotiate connection/substream. + FailedToNegotiate { + /// Protocol. + protocol: Option, + + /// Substream ID. + substream_id: Option, + + /// Error. + error: Error, + }, } struct NegotiatedSubstream { - /// Substream direction. - direction: Direction, + /// Substream direction. + direction: Direction, - /// Substream ID. - substream_id: SubstreamId, + /// Substream ID. + substream_id: SubstreamId, - /// Protocol name. - protocol: ProtocolName, + /// Protocol name. + protocol: ProtocolName, - /// Substream used to send data. - sender: SendStream, + /// Substream used to send data. + sender: SendStream, - /// Substream used to receive data. - receiver: RecvStream, + /// Substream used to receive data. + receiver: RecvStream, - /// Permit. - permit: Permit, + /// Permit. + permit: Permit, } /// QUIC connection. pub struct QuicConnection { - /// Remote peer ID. - peer: PeerId, + /// Remote peer ID. + peer: PeerId, - /// Endpoint. - endpoint: Endpoint, + /// Endpoint. + endpoint: Endpoint, - /// Substream open timeout. - substream_open_timeout: Duration, + /// Substream open timeout. + substream_open_timeout: Duration, - /// QUIC connection. - connection: QuinnConnection, + /// QUIC connection. + connection: QuinnConnection, - /// Protocol set. - protocol_set: ProtocolSet, + /// Protocol set. + protocol_set: ProtocolSet, - /// Bandwidth sink. - bandwidth_sink: BandwidthSink, + /// Bandwidth sink. + bandwidth_sink: BandwidthSink, - /// Pending substreams. - pending_substreams: - FuturesUnordered>>, + /// Pending substreams. + pending_substreams: + FuturesUnordered>>, } impl QuicConnection { - /// Creates a new [`QuicConnection`]. - pub fn new( - peer: PeerId, - endpoint: Endpoint, - connection: QuinnConnection, - protocol_set: ProtocolSet, - bandwidth_sink: BandwidthSink, - substream_open_timeout: Duration, - ) -> Self { - Self { - peer, - endpoint, - connection, - protocol_set, - bandwidth_sink, - substream_open_timeout, - pending_substreams: FuturesUnordered::new(), - } - } - - /// Negotiate protocol. - async fn negotiate_protocol( - stream: S, - role: &Role, - protocols: Vec<&str>, - ) -> crate::Result<(Negotiated, ProtocolName)> { - tracing::trace!(target: LOG_TARGET, ?protocols, "negotiating protocols"); - - let (protocol, socket) = match role { - Role::Dialer => dialer_select_proto(stream, protocols, Version::V1).await?, - Role::Listener => listener_select_proto(stream, protocols).await?, - }; - - tracing::trace!(target: LOG_TARGET, ?protocol, "protocol negotiated"); - - Ok((socket, ProtocolName::from(protocol.to_string()))) - } - - /// Open substream for `protocol`. - async fn open_substream( - handle: QuinnConnection, - permit: Permit, - substream_id: SubstreamId, - protocol: ProtocolName, - fallback_names: Vec, - ) -> crate::Result { - tracing::debug!(target: LOG_TARGET, ?protocol, ?substream_id, "open substream"); - - let stream = match handle.open_bi().await { - Ok((send_stream, recv_stream)) => NegotiatingSubstream::new(send_stream, recv_stream), - Err(error) => return Err(Error::Quinn(error)), - }; - - // TODO: protocols don't change after they've been initialized so this should be done only - // once - let protocols = std::iter::once(&*protocol) - .chain(fallback_names.iter().map(|protocol| &**protocol)) - .collect(); - - let (io, protocol) = Self::negotiate_protocol(stream, &Role::Dialer, protocols).await?; - - tracing::trace!( - target: LOG_TARGET, - ?protocol, - ?substream_id, - "substream accepted and negotiated" - ); - - let stream = io.inner(); - let (sender, receiver) = stream.into_parts(); - - Ok(NegotiatedSubstream { - sender, - receiver, - substream_id, - direction: Direction::Outbound(substream_id), - permit, - protocol, - }) - } - - /// Accept bidirectional substream from rmeote peer. - async fn accept_substream( - stream: NegotiatingSubstream, - protocols: Vec, - substream_id: SubstreamId, - permit: Permit, - ) -> crate::Result { - tracing::trace!( - target: LOG_TARGET, - ?substream_id, - "accept inbound substream" - ); - - let protocols = protocols.iter().map(|protocol| &**protocol).collect::>(); - let (io, protocol) = Self::negotiate_protocol(stream, &Role::Listener, protocols).await?; - - tracing::trace!( - target: LOG_TARGET, - ?substream_id, - ?protocol, - "substream accepted and negotiated" - ); - - let stream = io.inner(); - let (sender, receiver) = stream.into_parts(); - - Ok(NegotiatedSubstream { - permit, - sender, - receiver, - protocol, - substream_id, - direction: Direction::Inbound, - }) - } - - /// Start event loop for [`QuicConnection`]. - pub async fn start(mut self) -> crate::Result<()> { - self.protocol_set - .report_connection_established(self.peer, self.endpoint.clone()) - .await?; - - loop { - tokio::select! { - event = self.connection.accept_bi() => match event { - Ok((send_stream, receive_stream)) => { - - let substream = self.protocol_set.next_substream_id(); - let protocols = self.protocol_set.protocols(); - let permit = self.protocol_set.try_get_permit().ok_or(Error::ConnectionClosed)?; - let stream = NegotiatingSubstream::new(send_stream, receive_stream); - let substream_open_timeout = self.substream_open_timeout; - - self.pending_substreams.push(Box::pin(async move { - match tokio::time::timeout( - substream_open_timeout, - Self::accept_substream(stream, protocols, substream, permit), - ) - .await - { - Ok(Ok(substream)) => Ok(substream), - Ok(Err(error)) => Err(ConnectionError::FailedToNegotiate { - protocol: None, - substream_id: None, - error, - }), - Err(_) => Err(ConnectionError::Timeout { - protocol: None, - substream_id: None - }), - } - })); - } - Err(error) => { - tracing::debug!(target: LOG_TARGET, peer = ?self.peer, ?error, "failed to accept substream"); - return self.protocol_set.report_connection_closed(self.peer, self.endpoint.connection_id()).await; - } - }, - substream = self.pending_substreams.select_next_some(), if !self.pending_substreams.is_empty() => { - match substream { - Err(error) => { - tracing::debug!( - target: LOG_TARGET, - ?error, - "failed to accept/open substream", - ); - - let (protocol, substream_id, error) = match error { - ConnectionError::Timeout { protocol, substream_id } => { - (protocol, substream_id, Error::Timeout) - } - ConnectionError::FailedToNegotiate { protocol, substream_id, error } => { - (protocol, substream_id, error) - } - }; - - if let (Some(protocol), Some(substream_id)) = (protocol, substream_id) { - self.protocol_set - .report_substream_open_failure(protocol, substream_id, error) - .await?; - } - } - Ok(substream) => { - let protocol = substream.protocol.clone(); - let substream_id = substream.substream_id; - let direction = substream.direction; - let bandwidth_sink = self.bandwidth_sink.clone(); - let substream = substream::Substream::new_quic( - self.peer, - substream_id, - Substream::new( - substream.permit, - substream.sender, - substream.receiver, - bandwidth_sink - ), - self.protocol_set.protocol_codec(&protocol) - ); - - self.protocol_set - .report_substream_open(self.peer, protocol, direction, substream) - .await?; - } - } - } - command = self.protocol_set.next() => match command { - None => { - tracing::debug!( - target: LOG_TARGET, - peer = ?self.peer, - connection_id = ?self.endpoint.connection_id(), - "protocols have dropped connection" - ); - return self.protocol_set.report_connection_closed(self.peer, self.endpoint.connection_id()).await; - } - Some(ProtocolCommand::OpenSubstream { protocol, fallback_names, substream_id, permit }) => { - let connection = self.connection.clone(); - let substream_open_timeout = self.substream_open_timeout; - - tracing::trace!( - target: LOG_TARGET, - ?protocol, - ?fallback_names, - ?substream_id, - "open substream" - ); - - self.pending_substreams.push(Box::pin(async move { - match tokio::time::timeout( - substream_open_timeout, - Self::open_substream( - connection, - permit, - substream_id, - protocol.clone(), - fallback_names, - ), - ) - .await - { - Ok(Ok(substream)) => Ok(substream), - Ok(Err(error)) => Err(ConnectionError::FailedToNegotiate { - protocol: Some(protocol), - substream_id: Some(substream_id), - error, - }), - Err(_) => Err(ConnectionError::Timeout { - protocol: None, - substream_id: None - }), - } - })); - } - Some(ProtocolCommand::ForceClose) => { - tracing::debug!( - target: LOG_TARGET, - peer = ?self.peer, - connection_id = ?self.endpoint.connection_id(), - "force closing connection", - ); - - return self.protocol_set.report_connection_closed(self.peer, self.endpoint.connection_id()).await; - } - } - } - } - } + /// Creates a new [`QuicConnection`]. + pub fn new( + peer: PeerId, + endpoint: Endpoint, + connection: QuinnConnection, + protocol_set: ProtocolSet, + bandwidth_sink: BandwidthSink, + substream_open_timeout: Duration, + ) -> Self { + Self { + peer, + endpoint, + connection, + protocol_set, + bandwidth_sink, + substream_open_timeout, + pending_substreams: FuturesUnordered::new(), + } + } + + /// Negotiate protocol. + async fn negotiate_protocol( + stream: S, + role: &Role, + protocols: Vec<&str>, + ) -> crate::Result<(Negotiated, ProtocolName)> { + tracing::trace!(target: LOG_TARGET, ?protocols, "negotiating protocols"); + + let (protocol, socket) = match role { + Role::Dialer => dialer_select_proto(stream, protocols, Version::V1).await?, + Role::Listener => listener_select_proto(stream, protocols).await?, + }; + + tracing::trace!(target: LOG_TARGET, ?protocol, "protocol negotiated"); + + Ok((socket, ProtocolName::from(protocol.to_string()))) + } + + /// Open substream for `protocol`. + async fn open_substream( + handle: QuinnConnection, + permit: Permit, + substream_id: SubstreamId, + protocol: ProtocolName, + fallback_names: Vec, + ) -> crate::Result { + tracing::debug!(target: LOG_TARGET, ?protocol, ?substream_id, "open substream"); + + let stream = match handle.open_bi().await { + Ok((send_stream, recv_stream)) => NegotiatingSubstream::new(send_stream, recv_stream), + Err(error) => return Err(Error::Quinn(error)), + }; + + // TODO: protocols don't change after they've been initialized so this should be done only + // once + let protocols = std::iter::once(&*protocol) + .chain(fallback_names.iter().map(|protocol| &**protocol)) + .collect(); + + let (io, protocol) = Self::negotiate_protocol(stream, &Role::Dialer, protocols).await?; + + tracing::trace!( + target: LOG_TARGET, + ?protocol, + ?substream_id, + "substream accepted and negotiated" + ); + + let stream = io.inner(); + let (sender, receiver) = stream.into_parts(); + + Ok(NegotiatedSubstream { + sender, + receiver, + substream_id, + direction: Direction::Outbound(substream_id), + permit, + protocol, + }) + } + + /// Accept bidirectional substream from rmeote peer. + async fn accept_substream( + stream: NegotiatingSubstream, + protocols: Vec, + substream_id: SubstreamId, + permit: Permit, + ) -> crate::Result { + tracing::trace!( + target: LOG_TARGET, + ?substream_id, + "accept inbound substream" + ); + + let protocols = protocols.iter().map(|protocol| &**protocol).collect::>(); + let (io, protocol) = Self::negotiate_protocol(stream, &Role::Listener, protocols).await?; + + tracing::trace!( + target: LOG_TARGET, + ?substream_id, + ?protocol, + "substream accepted and negotiated" + ); + + let stream = io.inner(); + let (sender, receiver) = stream.into_parts(); + + Ok(NegotiatedSubstream { + permit, + sender, + receiver, + protocol, + substream_id, + direction: Direction::Inbound, + }) + } + + /// Start event loop for [`QuicConnection`]. + pub async fn start(mut self) -> crate::Result<()> { + self.protocol_set + .report_connection_established(self.peer, self.endpoint.clone()) + .await?; + + loop { + tokio::select! { + event = self.connection.accept_bi() => match event { + Ok((send_stream, receive_stream)) => { + + let substream = self.protocol_set.next_substream_id(); + let protocols = self.protocol_set.protocols(); + let permit = self.protocol_set.try_get_permit().ok_or(Error::ConnectionClosed)?; + let stream = NegotiatingSubstream::new(send_stream, receive_stream); + let substream_open_timeout = self.substream_open_timeout; + + self.pending_substreams.push(Box::pin(async move { + match tokio::time::timeout( + substream_open_timeout, + Self::accept_substream(stream, protocols, substream, permit), + ) + .await + { + Ok(Ok(substream)) => Ok(substream), + Ok(Err(error)) => Err(ConnectionError::FailedToNegotiate { + protocol: None, + substream_id: None, + error, + }), + Err(_) => Err(ConnectionError::Timeout { + protocol: None, + substream_id: None + }), + } + })); + } + Err(error) => { + tracing::debug!(target: LOG_TARGET, peer = ?self.peer, ?error, "failed to accept substream"); + return self.protocol_set.report_connection_closed(self.peer, self.endpoint.connection_id()).await; + } + }, + substream = self.pending_substreams.select_next_some(), if !self.pending_substreams.is_empty() => { + match substream { + Err(error) => { + tracing::debug!( + target: LOG_TARGET, + ?error, + "failed to accept/open substream", + ); + + let (protocol, substream_id, error) = match error { + ConnectionError::Timeout { protocol, substream_id } => { + (protocol, substream_id, Error::Timeout) + } + ConnectionError::FailedToNegotiate { protocol, substream_id, error } => { + (protocol, substream_id, error) + } + }; + + if let (Some(protocol), Some(substream_id)) = (protocol, substream_id) { + self.protocol_set + .report_substream_open_failure(protocol, substream_id, error) + .await?; + } + } + Ok(substream) => { + let protocol = substream.protocol.clone(); + let substream_id = substream.substream_id; + let direction = substream.direction; + let bandwidth_sink = self.bandwidth_sink.clone(); + let substream = substream::Substream::new_quic( + self.peer, + substream_id, + Substream::new( + substream.permit, + substream.sender, + substream.receiver, + bandwidth_sink + ), + self.protocol_set.protocol_codec(&protocol) + ); + + self.protocol_set + .report_substream_open(self.peer, protocol, direction, substream) + .await?; + } + } + } + command = self.protocol_set.next() => match command { + None => { + tracing::debug!( + target: LOG_TARGET, + peer = ?self.peer, + connection_id = ?self.endpoint.connection_id(), + "protocols have dropped connection" + ); + return self.protocol_set.report_connection_closed(self.peer, self.endpoint.connection_id()).await; + } + Some(ProtocolCommand::OpenSubstream { protocol, fallback_names, substream_id, permit }) => { + let connection = self.connection.clone(); + let substream_open_timeout = self.substream_open_timeout; + + tracing::trace!( + target: LOG_TARGET, + ?protocol, + ?fallback_names, + ?substream_id, + "open substream" + ); + + self.pending_substreams.push(Box::pin(async move { + match tokio::time::timeout( + substream_open_timeout, + Self::open_substream( + connection, + permit, + substream_id, + protocol.clone(), + fallback_names, + ), + ) + .await + { + Ok(Ok(substream)) => Ok(substream), + Ok(Err(error)) => Err(ConnectionError::FailedToNegotiate { + protocol: Some(protocol), + substream_id: Some(substream_id), + error, + }), + Err(_) => Err(ConnectionError::Timeout { + protocol: None, + substream_id: None + }), + } + })); + } + Some(ProtocolCommand::ForceClose) => { + tracing::debug!( + target: LOG_TARGET, + peer = ?self.peer, + connection_id = ?self.endpoint.connection_id(), + "force closing connection", + ); + + return self.protocol_set.report_connection_closed(self.peer, self.endpoint.connection_id()).await; + } + } + } + } + } } diff --git a/src/transport/quic/listener.rs b/src/transport/quic/listener.rs index 6b507244..7f6c3ad0 100644 --- a/src/transport/quic/listener.rs +++ b/src/transport/quic/listener.rs @@ -19,9 +19,9 @@ // DEALINGS IN THE SOFTWARE. use crate::{ - crypto::{ed25519::Keypair, tls::make_server_config}, - error::{AddressError, Error}, - PeerId, + crypto::{ed25519::Keypair, tls::make_server_config}, + error::{AddressError, Error}, + PeerId, }; use futures::{future::BoxFuture, stream::FuturesUnordered, FutureExt, Stream, StreamExt}; @@ -29,10 +29,10 @@ use multiaddr::{Multiaddr, Protocol}; use quinn::{Connecting, Endpoint, ServerConfig}; use std::{ - net::{IpAddr, SocketAddr}, - pin::Pin, - sync::Arc, - task::{Context, Poll}, + net::{IpAddr, SocketAddr}, + pin::Pin, + sync::Arc, + task::{Context, Poll}, }; /// Logging target for the file. @@ -40,384 +40,400 @@ const LOG_TARGET: &str = "litep2p::quic::listener"; /// QUIC listener. pub struct QuicListener { - /// Listen addresses. - _listen_addresses: Vec, + /// Listen addresses. + _listen_addresses: Vec, - /// Listeners. - listeners: Vec, + /// Listeners. + listeners: Vec, - /// Incoming connections. - incoming: FuturesUnordered>>, + /// Incoming connections. + incoming: FuturesUnordered>>, } impl QuicListener { - /// Create new [`QuicListener`]. - pub fn new( - keypair: &Keypair, - addresses: Vec, - ) -> crate::Result<(Self, Vec)> { - let mut listeners: Vec = Vec::new(); - let mut listen_addresses = Vec::new(); - - for address in addresses.into_iter() { - let (listen_address, _) = Self::get_socket_address(&address)?; - let crypto_config = Arc::new(make_server_config(keypair).expect("to succeed")); - let server_config = ServerConfig::with_crypto(crypto_config); - let listener = Endpoint::server(server_config, listen_address).unwrap(); - - let listen_address = listener.local_addr()?; - listen_addresses.push(listen_address); - listeners.push(listener); - // ); - } - - let listen_multi_addresses = listen_addresses - .iter() - .cloned() - .map(|address| { - Multiaddr::empty() - .with(Protocol::from(address.ip())) - .with(Protocol::Udp(address.port())) - .with(Protocol::QuicV1) - }) - .collect(); - - Ok(( - Self { - incoming: listeners - .iter_mut() - .enumerate() - .map(|(i, listener)| { - let inner = listener.clone(); - async move { inner.accept().await.map(|connecting| (i, connecting)) } - .boxed() - }) - .collect(), - listeners, - _listen_addresses: listen_addresses, - }, - listen_multi_addresses, - )) - } - - /// Extract socket address and `PeerId`, if found, from `address`. - pub fn get_socket_address(address: &Multiaddr) -> crate::Result<(SocketAddr, Option)> { - tracing::trace!(target: LOG_TARGET, ?address, "parse multi address"); - - let mut iter = address.iter(); - let socket_address = match iter.next() { - Some(Protocol::Ip6(address)) => match iter.next() { - Some(Protocol::Udp(port)) => SocketAddr::new(IpAddr::V6(address), port), - protocol => { - tracing::error!( - target: LOG_TARGET, - ?protocol, - "invalid transport protocol, expected `QuicV1`", - ); - return Err(Error::AddressError(AddressError::InvalidProtocol)); - }, - }, - Some(Protocol::Ip4(address)) => match iter.next() { - Some(Protocol::Udp(port)) => SocketAddr::new(IpAddr::V4(address), port), - protocol => { - tracing::error!( - target: LOG_TARGET, - ?protocol, - "invalid transport protocol, expected `QuicV1`", - ); - return Err(Error::AddressError(AddressError::InvalidProtocol)); - }, - }, - protocol => { - tracing::error!(target: LOG_TARGET, ?protocol, "invalid transport protocol"); - return Err(Error::AddressError(AddressError::InvalidProtocol)); - }, - }; - - // verify that quic exists - match iter.next() { - Some(Protocol::QuicV1) => {}, - _ => return Err(Error::AddressError(AddressError::InvalidProtocol)), - } - - let maybe_peer = match iter.next() { - Some(Protocol::P2p(multihash)) => Some(PeerId::from_multihash(multihash)?), - None => None, - protocol => { - tracing::error!( - target: LOG_TARGET, - ?protocol, - "invalid protocol, expected `P2p` or `None`" - ); - return Err(Error::AddressError(AddressError::InvalidProtocol)); - }, - }; - - Ok((socket_address, maybe_peer)) - } + /// Create new [`QuicListener`]. + pub fn new( + keypair: &Keypair, + addresses: Vec, + ) -> crate::Result<(Self, Vec)> { + let mut listeners: Vec = Vec::new(); + let mut listen_addresses = Vec::new(); + + for address in addresses.into_iter() { + let (listen_address, _) = Self::get_socket_address(&address)?; + let crypto_config = Arc::new(make_server_config(keypair).expect("to succeed")); + let server_config = ServerConfig::with_crypto(crypto_config); + let listener = Endpoint::server(server_config, listen_address).unwrap(); + + let listen_address = listener.local_addr()?; + listen_addresses.push(listen_address); + listeners.push(listener); + // ); + } + + let listen_multi_addresses = listen_addresses + .iter() + .cloned() + .map(|address| { + Multiaddr::empty() + .with(Protocol::from(address.ip())) + .with(Protocol::Udp(address.port())) + .with(Protocol::QuicV1) + }) + .collect(); + + Ok(( + Self { + incoming: listeners + .iter_mut() + .enumerate() + .map(|(i, listener)| { + let inner = listener.clone(); + async move { inner.accept().await.map(|connecting| (i, connecting)) } + .boxed() + }) + .collect(), + listeners, + _listen_addresses: listen_addresses, + }, + listen_multi_addresses, + )) + } + + /// Extract socket address and `PeerId`, if found, from `address`. + pub fn get_socket_address(address: &Multiaddr) -> crate::Result<(SocketAddr, Option)> { + tracing::trace!(target: LOG_TARGET, ?address, "parse multi address"); + + let mut iter = address.iter(); + let socket_address = match iter.next() { + Some(Protocol::Ip6(address)) => match iter.next() { + Some(Protocol::Udp(port)) => SocketAddr::new(IpAddr::V6(address), port), + protocol => { + tracing::error!( + target: LOG_TARGET, + ?protocol, + "invalid transport protocol, expected `QuicV1`", + ); + return Err(Error::AddressError(AddressError::InvalidProtocol)); + } + }, + Some(Protocol::Ip4(address)) => match iter.next() { + Some(Protocol::Udp(port)) => SocketAddr::new(IpAddr::V4(address), port), + protocol => { + tracing::error!( + target: LOG_TARGET, + ?protocol, + "invalid transport protocol, expected `QuicV1`", + ); + return Err(Error::AddressError(AddressError::InvalidProtocol)); + } + }, + protocol => { + tracing::error!(target: LOG_TARGET, ?protocol, "invalid transport protocol"); + return Err(Error::AddressError(AddressError::InvalidProtocol)); + } + }; + + // verify that quic exists + match iter.next() { + Some(Protocol::QuicV1) => {} + _ => return Err(Error::AddressError(AddressError::InvalidProtocol)), + } + + let maybe_peer = match iter.next() { + Some(Protocol::P2p(multihash)) => Some(PeerId::from_multihash(multihash)?), + None => None, + protocol => { + tracing::error!( + target: LOG_TARGET, + ?protocol, + "invalid protocol, expected `P2p` or `None`" + ); + return Err(Error::AddressError(AddressError::InvalidProtocol)); + } + }; + + Ok((socket_address, maybe_peer)) + } } impl Stream for QuicListener { - type Item = Connecting; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - if self.incoming.is_empty() { - return Poll::Pending; - } - - match futures::ready!(self.incoming.poll_next_unpin(cx)) { - None => Poll::Ready(None), - Some(None) => Poll::Ready(None), - Some(Some((listener, future))) => { - let inner = self.listeners[listener].clone(); - self.incoming.push( - async move { inner.accept().await.map(|connecting| (listener, connecting)) } - .boxed(), - ); - - Poll::Ready(Some(future)) - }, - } - } + type Item = Connecting; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + if self.incoming.is_empty() { + return Poll::Pending; + } + + match futures::ready!(self.incoming.poll_next_unpin(cx)) { + None => Poll::Ready(None), + Some(None) => Poll::Ready(None), + Some(Some((listener, future))) => { + let inner = self.listeners[listener].clone(); + self.incoming.push( + async move { inner.accept().await.map(|connecting| (listener, connecting)) } + .boxed(), + ); + + Poll::Ready(Some(future)) + } + } + } } #[cfg(test)] mod tests { - use crate::crypto::tls::make_client_config; - - use super::*; - use quinn::ClientConfig; - use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; - - #[test] - fn parse_multiaddresses() { - assert!(QuicListener::get_socket_address( - &"/ip6/::1/udp/8888/quic-v1".parse().expect("valid multiaddress") - ) - .is_ok()); - assert!(QuicListener::get_socket_address( - &"/ip4/127.0.0.1/udp/8888/quic-v1".parse().expect("valid multiaddress") - ) - .is_ok()); - assert!(QuicListener::get_socket_address( - &"/ip6/::1/udp/8888/quic-v1/p2p/12D3KooWT2ouvz5uMmCvHJGzAGRHiqDts5hzXR7NdoQ27pGdzp9Q" - .parse() - .expect("valid multiaddress") - ) - .is_ok()); - assert!(QuicListener::get_socket_address( + use crate::crypto::tls::make_client_config; + + use super::*; + use quinn::ClientConfig; + use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; + + #[test] + fn parse_multiaddresses() { + assert!(QuicListener::get_socket_address( + &"/ip6/::1/udp/8888/quic-v1".parse().expect("valid multiaddress") + ) + .is_ok()); + assert!(QuicListener::get_socket_address( + &"/ip4/127.0.0.1/udp/8888/quic-v1".parse().expect("valid multiaddress") + ) + .is_ok()); + assert!(QuicListener::get_socket_address( + &"/ip6/::1/udp/8888/quic-v1/p2p/12D3KooWT2ouvz5uMmCvHJGzAGRHiqDts5hzXR7NdoQ27pGdzp9Q" + .parse() + .expect("valid multiaddress") + ) + .is_ok()); + assert!(QuicListener::get_socket_address( &"/ip4/127.0.0.1/udp/8888/quic-v1/p2p/12D3KooWT2ouvz5uMmCvHJGzAGRHiqDts5hzXR7NdoQ27pGdzp9Q" .parse() .expect("valid multiaddress") ) .is_ok()); - assert!(QuicListener::get_socket_address( - &"/ip6/::1/tcp/8888/quic-v1/p2p/12D3KooWT2ouvz5uMmCvHJGzAGRHiqDts5hzXR7NdoQ27pGdzp9Q" - .parse() - .expect("valid multiaddress") - ) - .is_err()); - assert!(QuicListener::get_socket_address( - &"/ip4/127.0.0.1/udp/8888/p2p/12D3KooWT2ouvz5uMmCvHJGzAGRHiqDts5hzXR7NdoQ27pGdzp9Q" - .parse() - .expect("valid multiaddress") - ) - .is_err()); - assert!(QuicListener::get_socket_address( - &"/ip4/127.0.0.1/tcp/8888/p2p/12D3KooWT2ouvz5uMmCvHJGzAGRHiqDts5hzXR7NdoQ27pGdzp9Q" - .parse() - .expect("valid multiaddress") - ) - .is_err()); - assert!(QuicListener::get_socket_address( - &"/dns/google.com/tcp/8888/p2p/12D3KooWT2ouvz5uMmCvHJGzAGRHiqDts5hzXR7NdoQ27pGdzp9Q" - .parse() - .expect("valid multiaddress") - ) - .is_err()); - assert!(QuicListener::get_socket_address( - &"/ip6/::1/udp/8888/quic-v1/utp".parse().expect("valid multiaddress") - ) - .is_err()); - } - - #[tokio::test] - async fn no_listeners() { - let (mut listener, _) = QuicListener::new(&Keypair::generate(), Vec::new()).unwrap(); - - futures::future::poll_fn(|cx| match listener.poll_next_unpin(cx) { - Poll::Pending => Poll::Ready(()), - event => panic!("unexpected event: {event:?}"), - }) - .await; - } - - #[tokio::test] - async fn one_listener() { - let address: Multiaddr = "/ip6/::1/udp/0/quic-v1".parse().unwrap(); - let keypair = Keypair::generate(); - let peer = PeerId::from_public_key(&keypair.public().into()); - let (mut listener, listen_addresses) = - QuicListener::new(&keypair, vec![address.clone()]).unwrap(); - let Some(Protocol::Udp(port)) = - listen_addresses.iter().next().unwrap().clone().iter().skip(1).next() - else { - panic!("invalid address"); - }; - - let crypto_config = - Arc::new(make_client_config(&Keypair::generate(), Some(peer)).expect("to succeed")); - let client_config = ClientConfig::new(crypto_config); - let client = Endpoint::client(SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0)) - .map_err(|error| Error::Other(error.to_string())) - .unwrap(); - let connection = client - .connect_with(client_config, format!("[::1]:{port}").parse().unwrap(), "l") - .map_err(|error| Error::Other(error.to_string())) - .unwrap(); - - let (res1, res2) = tokio::join!( - listener.next(), - Box::pin(async move { - match connection.await { - Ok(connection) => Ok(connection), - Err(error) => Err(error), - } - }) - ); - - assert!(res1.is_some() && res2.is_ok()); - } - - #[tokio::test] - async fn two_listeners() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let address1: Multiaddr = "/ip6/::1/udp/0/quic-v1".parse().unwrap(); - let address2: Multiaddr = "/ip4/127.0.0.1/udp/0/quic-v1".parse().unwrap(); - let keypair = Keypair::generate(); - let peer = PeerId::from_public_key(&keypair.public().into()); - - let (mut listener, listen_addresses) = - QuicListener::new(&keypair, vec![address1, address2]).unwrap(); - - let Some(Protocol::Udp(port1)) = - listen_addresses.iter().next().unwrap().clone().iter().skip(1).next() - else { - panic!("invalid address"); - }; - - let Some(Protocol::Udp(port2)) = - listen_addresses.iter().skip(1).next().unwrap().clone().iter().skip(1).next() - else { - panic!("invalid address"); - }; - - let crypto_config1 = - Arc::new(make_client_config(&Keypair::generate(), Some(peer)).expect("to succeed")); - let client_config1 = ClientConfig::new(crypto_config1); - let client1 = Endpoint::client(SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0)) - .map_err(|error| Error::Other(error.to_string())) - .unwrap(); - let connection1 = client1 - .connect_with(client_config1, format!("[::1]:{port1}").parse().unwrap(), "l") - .map_err(|error| Error::Other(error.to_string())) - .unwrap(); - - let crypto_config2 = - Arc::new(make_client_config(&Keypair::generate(), Some(peer)).expect("to succeed")); - let client_config2 = ClientConfig::new(crypto_config2); - let client2 = Endpoint::client(SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0)) - .map_err(|error| Error::Other(error.to_string())) - .unwrap(); - let connection2 = client2 - .connect_with(client_config2, format!("127.0.0.1:{port2}").parse().unwrap(), "l") - .map_err(|error| Error::Other(error.to_string())) - .unwrap(); - - tokio::spawn(async move { - match connection1.await { - Ok(connection) => Ok(connection), - Err(error) => Err(error), - } - }); - - tokio::spawn(async move { - match connection2.await { - Ok(connection) => Ok(connection), - Err(error) => Err(error), - } - }); - - for _ in 0..2 { - let _ = listener.next().await; - } - } - - #[tokio::test] - async fn two_clients_dialing_same_address() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let keypair = Keypair::generate(); - let peer = PeerId::from_public_key(&keypair.public().into()); - - let (mut listener, listen_addresses) = QuicListener::new( - &keypair, - vec![ - "/ip6/::1/udp/0/quic-v1".parse().unwrap(), - "/ip4/127.0.0.1/udp/0/quic-v1".parse().unwrap(), - ], - ) - .unwrap(); - - let Some(Protocol::Udp(port)) = - listen_addresses.iter().next().unwrap().clone().iter().skip(1).next() - else { - panic!("invalid address"); - }; - - let crypto_config1 = - Arc::new(make_client_config(&Keypair::generate(), Some(peer)).expect("to succeed")); - let client_config1 = ClientConfig::new(crypto_config1); - let client1 = Endpoint::client(SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0)) - .map_err(|error| Error::Other(error.to_string())) - .unwrap(); - let connection1 = client1 - .connect_with(client_config1, format!("[::1]:{port}").parse().unwrap(), "l") - .map_err(|error| Error::Other(error.to_string())) - .unwrap(); - - let crypto_config2 = - Arc::new(make_client_config(&Keypair::generate(), Some(peer)).expect("to succeed")); - let client_config2 = ClientConfig::new(crypto_config2); - let client2 = Endpoint::client(SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0)) - .map_err(|error| Error::Other(error.to_string())) - .unwrap(); - let connection2 = client2 - .connect_with(client_config2, format!("[::1]:{port}").parse().unwrap(), "l") - .map_err(|error| Error::Other(error.to_string())) - .unwrap(); - - tokio::spawn(async move { - match connection1.await { - Ok(connection) => Ok(connection), - Err(error) => Err(error), - } - }); - - tokio::spawn(async move { - match connection2.await { - Ok(connection) => Ok(connection), - Err(error) => Err(error), - } - }); - - for _ in 0..2 { - let _ = listener.next().await; - } - } + assert!(QuicListener::get_socket_address( + &"/ip6/::1/tcp/8888/quic-v1/p2p/12D3KooWT2ouvz5uMmCvHJGzAGRHiqDts5hzXR7NdoQ27pGdzp9Q" + .parse() + .expect("valid multiaddress") + ) + .is_err()); + assert!(QuicListener::get_socket_address( + &"/ip4/127.0.0.1/udp/8888/p2p/12D3KooWT2ouvz5uMmCvHJGzAGRHiqDts5hzXR7NdoQ27pGdzp9Q" + .parse() + .expect("valid multiaddress") + ) + .is_err()); + assert!(QuicListener::get_socket_address( + &"/ip4/127.0.0.1/tcp/8888/p2p/12D3KooWT2ouvz5uMmCvHJGzAGRHiqDts5hzXR7NdoQ27pGdzp9Q" + .parse() + .expect("valid multiaddress") + ) + .is_err()); + assert!(QuicListener::get_socket_address( + &"/dns/google.com/tcp/8888/p2p/12D3KooWT2ouvz5uMmCvHJGzAGRHiqDts5hzXR7NdoQ27pGdzp9Q" + .parse() + .expect("valid multiaddress") + ) + .is_err()); + assert!(QuicListener::get_socket_address( + &"/ip6/::1/udp/8888/quic-v1/utp".parse().expect("valid multiaddress") + ) + .is_err()); + } + + #[tokio::test] + async fn no_listeners() { + let (mut listener, _) = QuicListener::new(&Keypair::generate(), Vec::new()).unwrap(); + + futures::future::poll_fn(|cx| match listener.poll_next_unpin(cx) { + Poll::Pending => Poll::Ready(()), + event => panic!("unexpected event: {event:?}"), + }) + .await; + } + + #[tokio::test] + async fn one_listener() { + let address: Multiaddr = "/ip6/::1/udp/0/quic-v1".parse().unwrap(); + let keypair = Keypair::generate(); + let peer = PeerId::from_public_key(&keypair.public().into()); + let (mut listener, listen_addresses) = + QuicListener::new(&keypair, vec![address.clone()]).unwrap(); + let Some(Protocol::Udp(port)) = + listen_addresses.iter().next().unwrap().clone().iter().skip(1).next() + else { + panic!("invalid address"); + }; + + let crypto_config = + Arc::new(make_client_config(&Keypair::generate(), Some(peer)).expect("to succeed")); + let client_config = ClientConfig::new(crypto_config); + let client = Endpoint::client(SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0)) + .map_err(|error| Error::Other(error.to_string())) + .unwrap(); + let connection = client + .connect_with(client_config, format!("[::1]:{port}").parse().unwrap(), "l") + .map_err(|error| Error::Other(error.to_string())) + .unwrap(); + + let (res1, res2) = tokio::join!( + listener.next(), + Box::pin(async move { + match connection.await { + Ok(connection) => Ok(connection), + Err(error) => Err(error), + } + }) + ); + + assert!(res1.is_some() && res2.is_ok()); + } + + #[tokio::test] + async fn two_listeners() { + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let address1: Multiaddr = "/ip6/::1/udp/0/quic-v1".parse().unwrap(); + let address2: Multiaddr = "/ip4/127.0.0.1/udp/0/quic-v1".parse().unwrap(); + let keypair = Keypair::generate(); + let peer = PeerId::from_public_key(&keypair.public().into()); + + let (mut listener, listen_addresses) = + QuicListener::new(&keypair, vec![address1, address2]).unwrap(); + + let Some(Protocol::Udp(port1)) = + listen_addresses.iter().next().unwrap().clone().iter().skip(1).next() + else { + panic!("invalid address"); + }; + + let Some(Protocol::Udp(port2)) = + listen_addresses.iter().skip(1).next().unwrap().clone().iter().skip(1).next() + else { + panic!("invalid address"); + }; + + let crypto_config1 = + Arc::new(make_client_config(&Keypair::generate(), Some(peer)).expect("to succeed")); + let client_config1 = ClientConfig::new(crypto_config1); + let client1 = Endpoint::client(SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0)) + .map_err(|error| Error::Other(error.to_string())) + .unwrap(); + let connection1 = client1 + .connect_with( + client_config1, + format!("[::1]:{port1}").parse().unwrap(), + "l", + ) + .map_err(|error| Error::Other(error.to_string())) + .unwrap(); + + let crypto_config2 = + Arc::new(make_client_config(&Keypair::generate(), Some(peer)).expect("to succeed")); + let client_config2 = ClientConfig::new(crypto_config2); + let client2 = Endpoint::client(SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0)) + .map_err(|error| Error::Other(error.to_string())) + .unwrap(); + let connection2 = client2 + .connect_with( + client_config2, + format!("127.0.0.1:{port2}").parse().unwrap(), + "l", + ) + .map_err(|error| Error::Other(error.to_string())) + .unwrap(); + + tokio::spawn(async move { + match connection1.await { + Ok(connection) => Ok(connection), + Err(error) => Err(error), + } + }); + + tokio::spawn(async move { + match connection2.await { + Ok(connection) => Ok(connection), + Err(error) => Err(error), + } + }); + + for _ in 0..2 { + let _ = listener.next().await; + } + } + + #[tokio::test] + async fn two_clients_dialing_same_address() { + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let keypair = Keypair::generate(); + let peer = PeerId::from_public_key(&keypair.public().into()); + + let (mut listener, listen_addresses) = QuicListener::new( + &keypair, + vec![ + "/ip6/::1/udp/0/quic-v1".parse().unwrap(), + "/ip4/127.0.0.1/udp/0/quic-v1".parse().unwrap(), + ], + ) + .unwrap(); + + let Some(Protocol::Udp(port)) = + listen_addresses.iter().next().unwrap().clone().iter().skip(1).next() + else { + panic!("invalid address"); + }; + + let crypto_config1 = + Arc::new(make_client_config(&Keypair::generate(), Some(peer)).expect("to succeed")); + let client_config1 = ClientConfig::new(crypto_config1); + let client1 = Endpoint::client(SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0)) + .map_err(|error| Error::Other(error.to_string())) + .unwrap(); + let connection1 = client1 + .connect_with( + client_config1, + format!("[::1]:{port}").parse().unwrap(), + "l", + ) + .map_err(|error| Error::Other(error.to_string())) + .unwrap(); + + let crypto_config2 = + Arc::new(make_client_config(&Keypair::generate(), Some(peer)).expect("to succeed")); + let client_config2 = ClientConfig::new(crypto_config2); + let client2 = Endpoint::client(SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0)) + .map_err(|error| Error::Other(error.to_string())) + .unwrap(); + let connection2 = client2 + .connect_with( + client_config2, + format!("[::1]:{port}").parse().unwrap(), + "l", + ) + .map_err(|error| Error::Other(error.to_string())) + .unwrap(); + + tokio::spawn(async move { + match connection1.await { + Ok(connection) => Ok(connection), + Err(error) => Err(error), + } + }); + + tokio::spawn(async move { + match connection2.await { + Ok(connection) => Ok(connection), + Err(error) => Err(error), + } + }); + + for _ in 0..2 { + let _ = listener.next().await; + } + } } diff --git a/src/transport/quic/mod.rs b/src/transport/quic/mod.rs index 81bd3ac0..6c015bfc 100644 --- a/src/transport/quic/mod.rs +++ b/src/transport/quic/mod.rs @@ -23,15 +23,15 @@ //! QUIC transport. use crate::{ - crypto::tls::make_client_config, - error::{AddressError, Error}, - transport::{ - manager::TransportHandle, - quic::{config::Config as QuicConfig, connection::QuicConnection, listener::QuicListener}, - Endpoint as Litep2pEndpoint, Transport, TransportBuilder, TransportEvent, - }, - types::ConnectionId, - PeerId, + crypto::tls::make_client_config, + error::{AddressError, Error}, + transport::{ + manager::TransportHandle, + quic::{config::Config as QuicConfig, connection::QuicConnection, listener::QuicListener}, + Endpoint as Litep2pEndpoint, Transport, TransportBuilder, TransportEvent, + }, + types::ConnectionId, + PeerId, }; use futures::{future::BoxFuture, stream::FuturesUnordered, Stream, StreamExt}; @@ -39,11 +39,11 @@ use multiaddr::{Multiaddr, Protocol}; use quinn::{ClientConfig, Connection, Endpoint, IdleTimeout}; use std::{ - collections::{HashMap, HashSet}, - net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, - pin::Pin, - sync::Arc, - task::{Context, Poll}, + collections::{HashMap, HashSet}, + net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, + pin::Pin, + sync::Arc, + task::{Context, Poll}, }; pub(crate) use substream::Substream; @@ -59,489 +59,503 @@ const LOG_TARGET: &str = "litep2p::quic"; #[derive(Debug)] struct NegotiatedConnection { - /// Remote peer ID. - peer: PeerId, + /// Remote peer ID. + peer: PeerId, - /// QUIC connection. - connection: Connection, + /// QUIC connection. + connection: Connection, } /// QUIC transport object. pub(crate) struct QuicTransport { - /// Transport handle. - context: TransportHandle, + /// Transport handle. + context: TransportHandle, - /// Transport config. - config: QuicConfig, + /// Transport config. + config: QuicConfig, - /// QUIC listener. - listener: QuicListener, + /// QUIC listener. + listener: QuicListener, - /// Pending dials. - pending_dials: HashMap, + /// Pending dials. + pending_dials: HashMap, - /// Pending connections. - pending_connections: - FuturesUnordered)>>, + /// Pending connections. + pending_connections: + FuturesUnordered)>>, - /// Negotiated connections waiting for validation. - pending_open: HashMap, + /// Negotiated connections waiting for validation. + pending_open: HashMap, - /// Pending raw, unnegotiated connections. - pending_raw_connections: FuturesUnordered< - BoxFuture<'static, Result<(ConnectionId, Multiaddr, NegotiatedConnection), ConnectionId>>, - >, + /// Pending raw, unnegotiated connections. + pending_raw_connections: FuturesUnordered< + BoxFuture<'static, Result<(ConnectionId, Multiaddr, NegotiatedConnection), ConnectionId>>, + >, - /// Opened raw connection, waiting for approval/rejection from `TransportManager`. - opened_raw: HashMap, + /// Opened raw connection, waiting for approval/rejection from `TransportManager`. + opened_raw: HashMap, - /// Canceled raw connections. - canceled: HashSet, + /// Canceled raw connections. + canceled: HashSet, } impl QuicTransport { - /// Attempt to extract `PeerId` from connection certificates. - fn extract_peer_id(connection: &Connection) -> Option { - let certificates: Box> = - connection.peer_identity()?.downcast().ok()?; - let p2p_cert = crate::crypto::tls::certificate::parse(certificates.get(0)?) - .expect("the certificate was validated during TLS handshake; qed"); - - Some(p2p_cert.peer_id()) - } - - /// Handle established connection. - fn on_connection_established( - &mut self, - connection_id: ConnectionId, - result: crate::Result, - ) -> Option { - tracing::debug!(target: LOG_TARGET, ?connection_id, success = result.is_ok(), "connection established"); - - // `on_connection_established()` is called for both inbound and outbound connections - // but `pending_dials` will only contain entries for outbound connections. - let maybe_address = self.pending_dials.remove(&connection_id); - - match result { - Ok(connection) => { - let peer = connection.peer; - let endpoint = maybe_address.map_or( - { - let address = connection.connection.remote_address(); - Litep2pEndpoint::listener( - Multiaddr::empty() - .with(Protocol::from(address.ip())) - .with(Protocol::Udp(address.port())) - .with(Protocol::QuicV1), - connection_id, - ) - }, - |address| Litep2pEndpoint::dialer(address, connection_id), - ); - self.pending_open.insert(connection_id, (connection, endpoint.clone())); - - return Some(TransportEvent::ConnectionEstablished { peer, endpoint }); - }, - Err(error) => { - tracing::debug!(target: LOG_TARGET, ?connection_id, ?error, "failed to establish connection"); - - // since the address was found from `pending_dials`, - // report the error to protocols and `TransportManager` - if let Some(address) = maybe_address { - return Some(TransportEvent::DialFailure { connection_id, address, error }); - } - }, - } - - None - } + /// Attempt to extract `PeerId` from connection certificates. + fn extract_peer_id(connection: &Connection) -> Option { + let certificates: Box> = + connection.peer_identity()?.downcast().ok()?; + let p2p_cert = crate::crypto::tls::certificate::parse(certificates.get(0)?) + .expect("the certificate was validated during TLS handshake; qed"); + + Some(p2p_cert.peer_id()) + } + + /// Handle established connection. + fn on_connection_established( + &mut self, + connection_id: ConnectionId, + result: crate::Result, + ) -> Option { + tracing::debug!(target: LOG_TARGET, ?connection_id, success = result.is_ok(), "connection established"); + + // `on_connection_established()` is called for both inbound and outbound connections + // but `pending_dials` will only contain entries for outbound connections. + let maybe_address = self.pending_dials.remove(&connection_id); + + match result { + Ok(connection) => { + let peer = connection.peer; + let endpoint = maybe_address.map_or( + { + let address = connection.connection.remote_address(); + Litep2pEndpoint::listener( + Multiaddr::empty() + .with(Protocol::from(address.ip())) + .with(Protocol::Udp(address.port())) + .with(Protocol::QuicV1), + connection_id, + ) + }, + |address| Litep2pEndpoint::dialer(address, connection_id), + ); + self.pending_open.insert(connection_id, (connection, endpoint.clone())); + + return Some(TransportEvent::ConnectionEstablished { peer, endpoint }); + } + Err(error) => { + tracing::debug!(target: LOG_TARGET, ?connection_id, ?error, "failed to establish connection"); + + // since the address was found from `pending_dials`, + // report the error to protocols and `TransportManager` + if let Some(address) = maybe_address { + return Some(TransportEvent::DialFailure { + connection_id, + address, + error, + }); + } + } + } + + None + } } impl TransportBuilder for QuicTransport { - type Config = QuicConfig; - type Transport = QuicTransport; - - /// Create new [`QuicTransport`] object. - fn new( - context: TransportHandle, - mut config: Self::Config, - ) -> crate::Result<(Self, Vec)> - where - Self: Sized, - { - tracing::info!( - target: LOG_TARGET, - ?config, - "start quic transport", - ); - - let (listener, listen_addresses) = QuicListener::new( - &context.keypair, - std::mem::replace(&mut config.listen_addresses, Vec::new()), - )?; - - Ok(( - Self { - context, - config, - listener, - canceled: HashSet::new(), - opened_raw: HashMap::new(), - pending_open: HashMap::new(), - pending_dials: HashMap::new(), - pending_raw_connections: FuturesUnordered::new(), - pending_connections: FuturesUnordered::new(), - }, - listen_addresses, - )) - } + type Config = QuicConfig; + type Transport = QuicTransport; + + /// Create new [`QuicTransport`] object. + fn new( + context: TransportHandle, + mut config: Self::Config, + ) -> crate::Result<(Self, Vec)> + where + Self: Sized, + { + tracing::info!( + target: LOG_TARGET, + ?config, + "start quic transport", + ); + + let (listener, listen_addresses) = QuicListener::new( + &context.keypair, + std::mem::replace(&mut config.listen_addresses, Vec::new()), + )?; + + Ok(( + Self { + context, + config, + listener, + canceled: HashSet::new(), + opened_raw: HashMap::new(), + pending_open: HashMap::new(), + pending_dials: HashMap::new(), + pending_raw_connections: FuturesUnordered::new(), + pending_connections: FuturesUnordered::new(), + }, + listen_addresses, + )) + } } impl Transport for QuicTransport { - fn dial(&mut self, connection_id: ConnectionId, address: Multiaddr) -> crate::Result<()> { - let Ok((socket_address, Some(peer))) = QuicListener::get_socket_address(&address) else { - return Err(Error::AddressError(AddressError::PeerIdMissing)); - }; - - let crypto_config = - Arc::new(make_client_config(&self.context.keypair, Some(peer)).expect("to succeed")); - let mut transport_config = quinn::TransportConfig::default(); - let timeout = - IdleTimeout::try_from(self.config.connection_open_timeout).expect("to succeed"); - transport_config.max_idle_timeout(Some(timeout)); - let mut client_config = ClientConfig::new(crypto_config); - client_config.transport_config(Arc::new(transport_config)); - - let client_listen_address = match address.iter().next() { - Some(Protocol::Ip6(_)) => SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0), - Some(Protocol::Ip4(_)) => SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0), - _ => return Err(Error::AddressError(AddressError::InvalidProtocol)), - }; - - let client = Endpoint::client(client_listen_address) - .map_err(|error| Error::Other(error.to_string()))?; - let connection = client - .connect_with(client_config, socket_address, "l") - .map_err(|error| Error::Other(error.to_string()))?; - - tracing::trace!( - target: LOG_TARGET, - ?address, - ?peer, - ?client_listen_address, - "dial peer", - ); - - self.pending_dials.insert(connection_id, address); - self.pending_connections.push(Box::pin(async move { - let connection = match connection.await { - Ok(connection) => connection, - Err(error) => return (connection_id, Err(error.into())), - }; - - let Some(peer) = Self::extract_peer_id(&connection) else { - return (connection_id, Err(Error::InvalidCertificate)); - }; - - (connection_id, Ok(NegotiatedConnection { peer, connection })) - })); - - Ok(()) - } - - fn accept(&mut self, connection_id: ConnectionId) -> crate::Result<()> { - let (connection, endpoint) = self - .pending_open - .remove(&connection_id) - .ok_or(Error::ConnectionDoesntExist(connection_id))?; - let bandwidth_sink = self.context.bandwidth_sink.clone(); - let protocol_set = self.context.protocol_set(connection_id); - let substream_open_timeout = self.config.substream_open_timeout; - - tracing::trace!( - target: LOG_TARGET, - ?connection_id, - "start connection", - ); - - self.context.executor.run(Box::pin(async move { - let _ = QuicConnection::new( - connection.peer, - endpoint, - connection.connection, - protocol_set, - bandwidth_sink, - substream_open_timeout, - ) - .start() - .await; - })); - - Ok(()) - } - - fn reject(&mut self, connection_id: ConnectionId) -> crate::Result<()> { - self.canceled.insert(connection_id); - self.pending_open - .remove(&connection_id) - .map_or(Err(Error::ConnectionDoesntExist(connection_id)), |_| Ok(())) - } - - fn open( - &mut self, - connection_id: ConnectionId, - addresses: Vec, - ) -> crate::Result<()> { - let mut futures: FuturesUnordered<_> = addresses - .into_iter() - .map(|address| { - let keypair = self.context.keypair.clone(); - let connection_open_timeout = self.config.connection_open_timeout; - - async move { - let Ok((socket_address, Some(peer))) = - QuicListener::get_socket_address(&address) - else { - return ( - connection_id, - Err(Error::AddressError(AddressError::PeerIdMissing)), - ); - }; - - let crypto_config = - Arc::new(make_client_config(&keypair, Some(peer)).expect("to succeed")); - let mut transport_config = quinn::TransportConfig::default(); - let timeout = - IdleTimeout::try_from(connection_open_timeout).expect("to succeed"); - transport_config.max_idle_timeout(Some(timeout)); - let mut client_config = ClientConfig::new(crypto_config); - client_config.transport_config(Arc::new(transport_config)); - - let client_listen_address = match address.iter().next() { - Some(Protocol::Ip6(_)) => - SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0), - Some(Protocol::Ip4(_)) => - SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0), - _ => - return ( - connection_id, - Err(Error::AddressError(AddressError::InvalidProtocol)), - ), - }; - - let client = match Endpoint::client(client_listen_address) { - Ok(client) => client, - Err(error) => { - return (connection_id, Err(Error::Other(error.to_string()))); - }, - }; - let connection = match client.connect_with(client_config, socket_address, "l") { - Ok(connection) => connection, - Err(error) => { - return (connection_id, Err(Error::Other(error.to_string()))); - }, - }; - - let connection = match connection.await { - Ok(connection) => connection, - Err(error) => return (connection_id, Err(error.into())), - }; - - let Some(peer) = Self::extract_peer_id(&connection) else { - return (connection_id, Err(Error::InvalidCertificate)); - }; - - (connection_id, Ok((address, NegotiatedConnection { peer, connection }))) - } - }) - .collect(); - - self.pending_raw_connections.push(Box::pin(async move { - while let Some(result) = futures.next().await { - let (connection_id, result) = result; - - match result { - Ok((address, connection)) => return Ok((connection_id, address, connection)), - Err(error) => tracing::debug!( - target: LOG_TARGET, - ?connection_id, - ?error, - "failed to open connection", - ), - } - } - - Err(connection_id) - })); - - Ok(()) - } - - fn negotiate(&mut self, connection_id: ConnectionId) -> crate::Result<()> { - let (connection, _address) = self - .opened_raw - .remove(&connection_id) - .ok_or(Error::ConnectionDoesntExist(connection_id))?; - - self.pending_connections - .push(Box::pin(async move { (connection_id, Ok(connection)) })); - - Ok(()) - } - - /// Cancel opening connections. - fn cancel(&mut self, connection_id: ConnectionId) { - self.canceled.insert(connection_id); - } + fn dial(&mut self, connection_id: ConnectionId, address: Multiaddr) -> crate::Result<()> { + let Ok((socket_address, Some(peer))) = QuicListener::get_socket_address(&address) else { + return Err(Error::AddressError(AddressError::PeerIdMissing)); + }; + + let crypto_config = + Arc::new(make_client_config(&self.context.keypair, Some(peer)).expect("to succeed")); + let mut transport_config = quinn::TransportConfig::default(); + let timeout = + IdleTimeout::try_from(self.config.connection_open_timeout).expect("to succeed"); + transport_config.max_idle_timeout(Some(timeout)); + let mut client_config = ClientConfig::new(crypto_config); + client_config.transport_config(Arc::new(transport_config)); + + let client_listen_address = match address.iter().next() { + Some(Protocol::Ip6(_)) => SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0), + Some(Protocol::Ip4(_)) => SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0), + _ => return Err(Error::AddressError(AddressError::InvalidProtocol)), + }; + + let client = Endpoint::client(client_listen_address) + .map_err(|error| Error::Other(error.to_string()))?; + let connection = client + .connect_with(client_config, socket_address, "l") + .map_err(|error| Error::Other(error.to_string()))?; + + tracing::trace!( + target: LOG_TARGET, + ?address, + ?peer, + ?client_listen_address, + "dial peer", + ); + + self.pending_dials.insert(connection_id, address); + self.pending_connections.push(Box::pin(async move { + let connection = match connection.await { + Ok(connection) => connection, + Err(error) => return (connection_id, Err(error.into())), + }; + + let Some(peer) = Self::extract_peer_id(&connection) else { + return (connection_id, Err(Error::InvalidCertificate)); + }; + + (connection_id, Ok(NegotiatedConnection { peer, connection })) + })); + + Ok(()) + } + + fn accept(&mut self, connection_id: ConnectionId) -> crate::Result<()> { + let (connection, endpoint) = self + .pending_open + .remove(&connection_id) + .ok_or(Error::ConnectionDoesntExist(connection_id))?; + let bandwidth_sink = self.context.bandwidth_sink.clone(); + let protocol_set = self.context.protocol_set(connection_id); + let substream_open_timeout = self.config.substream_open_timeout; + + tracing::trace!( + target: LOG_TARGET, + ?connection_id, + "start connection", + ); + + self.context.executor.run(Box::pin(async move { + let _ = QuicConnection::new( + connection.peer, + endpoint, + connection.connection, + protocol_set, + bandwidth_sink, + substream_open_timeout, + ) + .start() + .await; + })); + + Ok(()) + } + + fn reject(&mut self, connection_id: ConnectionId) -> crate::Result<()> { + self.canceled.insert(connection_id); + self.pending_open + .remove(&connection_id) + .map_or(Err(Error::ConnectionDoesntExist(connection_id)), |_| Ok(())) + } + + fn open( + &mut self, + connection_id: ConnectionId, + addresses: Vec, + ) -> crate::Result<()> { + let mut futures: FuturesUnordered<_> = addresses + .into_iter() + .map(|address| { + let keypair = self.context.keypair.clone(); + let connection_open_timeout = self.config.connection_open_timeout; + + async move { + let Ok((socket_address, Some(peer))) = + QuicListener::get_socket_address(&address) + else { + return ( + connection_id, + Err(Error::AddressError(AddressError::PeerIdMissing)), + ); + }; + + let crypto_config = + Arc::new(make_client_config(&keypair, Some(peer)).expect("to succeed")); + let mut transport_config = quinn::TransportConfig::default(); + let timeout = + IdleTimeout::try_from(connection_open_timeout).expect("to succeed"); + transport_config.max_idle_timeout(Some(timeout)); + let mut client_config = ClientConfig::new(crypto_config); + client_config.transport_config(Arc::new(transport_config)); + + let client_listen_address = match address.iter().next() { + Some(Protocol::Ip6(_)) => + SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0), + Some(Protocol::Ip4(_)) => + SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0), + _ => + return ( + connection_id, + Err(Error::AddressError(AddressError::InvalidProtocol)), + ), + }; + + let client = match Endpoint::client(client_listen_address) { + Ok(client) => client, + Err(error) => { + return (connection_id, Err(Error::Other(error.to_string()))); + } + }; + let connection = match client.connect_with(client_config, socket_address, "l") { + Ok(connection) => connection, + Err(error) => { + return (connection_id, Err(Error::Other(error.to_string()))); + } + }; + + let connection = match connection.await { + Ok(connection) => connection, + Err(error) => return (connection_id, Err(error.into())), + }; + + let Some(peer) = Self::extract_peer_id(&connection) else { + return (connection_id, Err(Error::InvalidCertificate)); + }; + + ( + connection_id, + Ok((address, NegotiatedConnection { peer, connection })), + ) + } + }) + .collect(); + + self.pending_raw_connections.push(Box::pin(async move { + while let Some(result) = futures.next().await { + let (connection_id, result) = result; + + match result { + Ok((address, connection)) => return Ok((connection_id, address, connection)), + Err(error) => tracing::debug!( + target: LOG_TARGET, + ?connection_id, + ?error, + "failed to open connection", + ), + } + } + + Err(connection_id) + })); + + Ok(()) + } + + fn negotiate(&mut self, connection_id: ConnectionId) -> crate::Result<()> { + let (connection, _address) = self + .opened_raw + .remove(&connection_id) + .ok_or(Error::ConnectionDoesntExist(connection_id))?; + + self.pending_connections + .push(Box::pin(async move { (connection_id, Ok(connection)) })); + + Ok(()) + } + + /// Cancel opening connections. + fn cancel(&mut self, connection_id: ConnectionId) { + self.canceled.insert(connection_id); + } } impl Stream for QuicTransport { - type Item = TransportEvent; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - while let Poll::Ready(Some(connection)) = self.listener.poll_next_unpin(cx) { - let connection_id = self.context.next_connection_id(); - - tracing::trace!( - target: LOG_TARGET, - ?connection_id, - "accept connection", - ); - - self.pending_connections.push(Box::pin(async move { - let connection = match connection.await { - Ok(connection) => connection, - Err(error) => return (connection_id, Err(error.into())), - }; - - let Some(peer) = Self::extract_peer_id(&connection) else { - return (connection_id, Err(Error::InvalidCertificate)); - }; - - (connection_id, Ok(NegotiatedConnection { peer, connection })) - })); - } - - while let Poll::Ready(Some(result)) = self.pending_raw_connections.poll_next_unpin(cx) { - match result { - Ok((connection_id, address, stream)) => { - tracing::trace!( - target: LOG_TARGET, - ?connection_id, - ?address, - canceled = self.canceled.contains(&connection_id), - "connection opened", - ); - - if !self.canceled.remove(&connection_id) { - self.opened_raw.insert(connection_id, (stream, address.clone())); - - return Poll::Ready(Some(TransportEvent::ConnectionOpened { - connection_id, - address, - })); - } - }, - Err(connection_id) => - if !self.canceled.remove(&connection_id) { - return Poll::Ready(Some(TransportEvent::OpenFailure { connection_id })); - }, - } - } - - while let Poll::Ready(Some(connection)) = self.pending_connections.poll_next_unpin(cx) { - let (connection_id, result) = connection; - - match self.on_connection_established(connection_id, result) { - Some(event) => return Poll::Ready(Some(event)), - None => {}, - } - } - - Poll::Pending - } + type Item = TransportEvent; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + while let Poll::Ready(Some(connection)) = self.listener.poll_next_unpin(cx) { + let connection_id = self.context.next_connection_id(); + + tracing::trace!( + target: LOG_TARGET, + ?connection_id, + "accept connection", + ); + + self.pending_connections.push(Box::pin(async move { + let connection = match connection.await { + Ok(connection) => connection, + Err(error) => return (connection_id, Err(error.into())), + }; + + let Some(peer) = Self::extract_peer_id(&connection) else { + return (connection_id, Err(Error::InvalidCertificate)); + }; + + (connection_id, Ok(NegotiatedConnection { peer, connection })) + })); + } + + while let Poll::Ready(Some(result)) = self.pending_raw_connections.poll_next_unpin(cx) { + match result { + Ok((connection_id, address, stream)) => { + tracing::trace!( + target: LOG_TARGET, + ?connection_id, + ?address, + canceled = self.canceled.contains(&connection_id), + "connection opened", + ); + + if !self.canceled.remove(&connection_id) { + self.opened_raw.insert(connection_id, (stream, address.clone())); + + return Poll::Ready(Some(TransportEvent::ConnectionOpened { + connection_id, + address, + })); + } + } + Err(connection_id) => + if !self.canceled.remove(&connection_id) { + return Poll::Ready(Some(TransportEvent::OpenFailure { connection_id })); + }, + } + } + + while let Poll::Ready(Some(connection)) = self.pending_connections.poll_next_unpin(cx) { + let (connection_id, result) = connection; + + match self.on_connection_established(connection_id, result) { + Some(event) => return Poll::Ready(Some(event)), + None => {} + } + } + + Poll::Pending + } } #[cfg(test)] mod tests { - use super::*; - use crate::{ - codec::ProtocolCodec, - crypto::ed25519::Keypair, - executor::DefaultExecutor, - transport::manager::{ProtocolContext, TransportHandle}, - types::protocol::ProtocolName, - BandwidthSink, - }; - use multihash::Multihash; - use tokio::sync::mpsc::channel; - - #[tokio::test] - async fn test_quinn() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let keypair1 = Keypair::generate(); - let (tx1, _rx1) = channel(64); - let (event_tx1, _event_rx1) = channel(64); - - let handle1 = TransportHandle { - executor: Arc::new(DefaultExecutor {}), - protocol_names: Vec::new(), - next_substream_id: Default::default(), - next_connection_id: Default::default(), - keypair: keypair1.clone(), - tx: event_tx1, - bandwidth_sink: BandwidthSink::new(), - - protocols: HashMap::from_iter([( - ProtocolName::from("/notif/1"), - ProtocolContext { - tx: tx1, - codec: ProtocolCodec::Identity(32), - fallback_names: Vec::new(), - }, - )]), - }; - - let (mut transport1, listen_addresses) = - QuicTransport::new(handle1, Default::default()).unwrap(); - let listen_address = listen_addresses[0].clone(); - - let keypair2 = Keypair::generate(); - let (tx2, _rx2) = channel(64); - let (event_tx2, _event_rx2) = channel(64); - - let handle2 = TransportHandle { - executor: Arc::new(DefaultExecutor {}), - protocol_names: Vec::new(), - next_substream_id: Default::default(), - next_connection_id: Default::default(), - keypair: keypair2.clone(), - tx: event_tx2, - bandwidth_sink: BandwidthSink::new(), - - protocols: HashMap::from_iter([( - ProtocolName::from("/notif/1"), - ProtocolContext { - tx: tx2, - codec: ProtocolCodec::Identity(32), - fallback_names: Vec::new(), - }, - )]), - }; - - let (mut transport2, _) = QuicTransport::new(handle2, Default::default()).unwrap(); - let peer1: PeerId = PeerId::from_public_key(&keypair1.public().into()); - let _peer2: PeerId = PeerId::from_public_key(&keypair2.public().into()); - let listen_address = - listen_address.with(Protocol::P2p(Multihash::from_bytes(&peer1.to_bytes()).unwrap())); - - transport2.dial(ConnectionId::new(), listen_address).unwrap(); - let (res1, res2) = tokio::join!(transport1.next(), transport2.next()); - - assert!(std::matches!(res1, Some(TransportEvent::ConnectionEstablished { .. }))); - assert!(std::matches!(res2, Some(TransportEvent::ConnectionEstablished { .. }))); - } + use super::*; + use crate::{ + codec::ProtocolCodec, + crypto::ed25519::Keypair, + executor::DefaultExecutor, + transport::manager::{ProtocolContext, TransportHandle}, + types::protocol::ProtocolName, + BandwidthSink, + }; + use multihash::Multihash; + use tokio::sync::mpsc::channel; + + #[tokio::test] + async fn test_quinn() { + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let keypair1 = Keypair::generate(); + let (tx1, _rx1) = channel(64); + let (event_tx1, _event_rx1) = channel(64); + + let handle1 = TransportHandle { + executor: Arc::new(DefaultExecutor {}), + protocol_names: Vec::new(), + next_substream_id: Default::default(), + next_connection_id: Default::default(), + keypair: keypair1.clone(), + tx: event_tx1, + bandwidth_sink: BandwidthSink::new(), + + protocols: HashMap::from_iter([( + ProtocolName::from("/notif/1"), + ProtocolContext { + tx: tx1, + codec: ProtocolCodec::Identity(32), + fallback_names: Vec::new(), + }, + )]), + }; + + let (mut transport1, listen_addresses) = + QuicTransport::new(handle1, Default::default()).unwrap(); + let listen_address = listen_addresses[0].clone(); + + let keypair2 = Keypair::generate(); + let (tx2, _rx2) = channel(64); + let (event_tx2, _event_rx2) = channel(64); + + let handle2 = TransportHandle { + executor: Arc::new(DefaultExecutor {}), + protocol_names: Vec::new(), + next_substream_id: Default::default(), + next_connection_id: Default::default(), + keypair: keypair2.clone(), + tx: event_tx2, + bandwidth_sink: BandwidthSink::new(), + + protocols: HashMap::from_iter([( + ProtocolName::from("/notif/1"), + ProtocolContext { + tx: tx2, + codec: ProtocolCodec::Identity(32), + fallback_names: Vec::new(), + }, + )]), + }; + + let (mut transport2, _) = QuicTransport::new(handle2, Default::default()).unwrap(); + let peer1: PeerId = PeerId::from_public_key(&keypair1.public().into()); + let _peer2: PeerId = PeerId::from_public_key(&keypair2.public().into()); + let listen_address = listen_address.with(Protocol::P2p( + Multihash::from_bytes(&peer1.to_bytes()).unwrap(), + )); + + transport2.dial(ConnectionId::new(), listen_address).unwrap(); + let (res1, res2) = tokio::join!(transport1.next(), transport2.next()); + + assert!(std::matches!( + res1, + Some(TransportEvent::ConnectionEstablished { .. }) + )); + assert!(std::matches!( + res2, + Some(TransportEvent::ConnectionEstablished { .. }) + )); + } } diff --git a/src/transport/quic/substream.rs b/src/transport/quic/substream.rs index 3dc8ed83..826888b3 100644 --- a/src/transport/quic/substream.rs +++ b/src/transport/quic/substream.rs @@ -19,8 +19,8 @@ // DEALINGS IN THE SOFTWARE. use crate::{ - error::{Error, SubstreamError}, - BandwidthSink, + error::{Error, SubstreamError}, + BandwidthSink, }; use bytes::Bytes; @@ -30,9 +30,9 @@ use tokio::io::{AsyncRead as TokioAsyncRead, AsyncWrite as TokioAsyncWrite}; use tokio_util::compat::{Compat, TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt}; use std::{ - io, - pin::Pin, - task::{Context, Poll}, + io, + pin::Pin, + task::{Context, Poll}, }; use crate::protocol::Permit; @@ -40,133 +40,138 @@ use crate::protocol::Permit; /// QUIC substream. #[derive(Debug)] pub struct Substream { - _permit: Permit, - bandwidth_sink: BandwidthSink, - send_stream: SendStream, - recv_stream: RecvStream, + _permit: Permit, + bandwidth_sink: BandwidthSink, + send_stream: SendStream, + recv_stream: RecvStream, } impl Substream { - /// Create new [`Substream`]. - pub fn new( - _permit: Permit, - send_stream: SendStream, - recv_stream: RecvStream, - bandwidth_sink: BandwidthSink, - ) -> Self { - Self { _permit, send_stream, recv_stream, bandwidth_sink } - } - - /// Write `buffers` to the underlying socket. - pub async fn write_all_chunks(&mut self, buffers: &mut [Bytes]) -> crate::Result<()> { - let nwritten = buffers.iter().fold(0usize, |acc, buffer| acc + buffer.len()); - - match self - .send_stream - .write_all_chunks(buffers) - .await - .map_err(|_| Error::SubstreamError(SubstreamError::ConnectionClosed)) - { - Ok(()) => { - self.bandwidth_sink.increase_outbound(nwritten); - Ok(()) - }, - Err(error) => return Err(error), - } - } + /// Create new [`Substream`]. + pub fn new( + _permit: Permit, + send_stream: SendStream, + recv_stream: RecvStream, + bandwidth_sink: BandwidthSink, + ) -> Self { + Self { + _permit, + send_stream, + recv_stream, + bandwidth_sink, + } + } + + /// Write `buffers` to the underlying socket. + pub async fn write_all_chunks(&mut self, buffers: &mut [Bytes]) -> crate::Result<()> { + let nwritten = buffers.iter().fold(0usize, |acc, buffer| acc + buffer.len()); + + match self + .send_stream + .write_all_chunks(buffers) + .await + .map_err(|_| Error::SubstreamError(SubstreamError::ConnectionClosed)) + { + Ok(()) => { + self.bandwidth_sink.increase_outbound(nwritten); + Ok(()) + } + Err(error) => return Err(error), + } + } } impl TokioAsyncRead for Substream { - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut tokio::io::ReadBuf<'_>, - ) -> Poll> { - match futures::ready!(Pin::new(&mut self.recv_stream).poll_read(cx, buf)) { - Err(error) => Poll::Ready(Err(error)), - Ok(res) => { - self.bandwidth_sink.increase_inbound(buf.filled().len()); - Poll::Ready(Ok(res)) - }, - } - } + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll> { + match futures::ready!(Pin::new(&mut self.recv_stream).poll_read(cx, buf)) { + Err(error) => Poll::Ready(Err(error)), + Ok(res) => { + self.bandwidth_sink.increase_inbound(buf.filled().len()); + Poll::Ready(Ok(res)) + } + } + } } impl TokioAsyncWrite for Substream { - fn poll_write( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - match futures::ready!(Pin::new(&mut self.send_stream).poll_write(cx, buf)) { - Err(error) => Poll::Ready(Err(error)), - Ok(nwritten) => { - self.bandwidth_sink.increase_outbound(nwritten); - Poll::Ready(Ok(nwritten)) - }, - } - } - - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.send_stream).poll_flush(cx) - } - - fn poll_shutdown( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - Pin::new(&mut self.send_stream).poll_shutdown(cx) - } + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + match futures::ready!(Pin::new(&mut self.send_stream).poll_write(cx, buf)) { + Err(error) => Poll::Ready(Err(error)), + Ok(nwritten) => { + self.bandwidth_sink.increase_outbound(nwritten); + Poll::Ready(Ok(nwritten)) + } + } + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.send_stream).poll_flush(cx) + } + + fn poll_shutdown( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + Pin::new(&mut self.send_stream).poll_shutdown(cx) + } } /// Substream pair used to negotiate a protocol for the connection. pub struct NegotiatingSubstream { - recv_stream: Compat, - send_stream: Compat, + recv_stream: Compat, + send_stream: Compat, } impl NegotiatingSubstream { - /// Create new [`NegotiatingSubstream`]. - pub fn new(send_stream: SendStream, recv_stream: RecvStream) -> Self { - Self { - recv_stream: TokioAsyncReadCompatExt::compat(recv_stream), - send_stream: TokioAsyncWriteCompatExt::compat_write(send_stream), - } - } - - /// Deconstruct [`NegotiatingSubstream`] into parts. - pub fn into_parts(self) -> (SendStream, RecvStream) { - let sender = self.send_stream.into_inner(); - let receiver = self.recv_stream.into_inner(); - - (sender, receiver) - } + /// Create new [`NegotiatingSubstream`]. + pub fn new(send_stream: SendStream, recv_stream: RecvStream) -> Self { + Self { + recv_stream: TokioAsyncReadCompatExt::compat(recv_stream), + send_stream: TokioAsyncWriteCompatExt::compat_write(send_stream), + } + } + + /// Deconstruct [`NegotiatingSubstream`] into parts. + pub fn into_parts(self) -> (SendStream, RecvStream) { + let sender = self.send_stream.into_inner(); + let receiver = self.recv_stream.into_inner(); + + (sender, receiver) + } } impl AsyncRead for NegotiatingSubstream { - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { - Pin::new(&mut self.recv_stream).poll_read(cx, buf) - } + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + Pin::new(&mut self.recv_stream).poll_read(cx, buf) + } } impl AsyncWrite for NegotiatingSubstream { - fn poll_write( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - Pin::new(&mut self.send_stream).poll_write(cx, buf) - } - - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.send_stream).poll_flush(cx) - } - - fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.send_stream).poll_close(cx) - } + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut self.send_stream).poll_write(cx, buf) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.send_stream).poll_flush(cx) + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.send_stream).poll_close(cx) + } } diff --git a/src/transport/tcp/config.rs b/src/transport/tcp/config.rs index c146dd8a..cd4926b2 100644 --- a/src/transport/tcp/config.rs +++ b/src/transport/tcp/config.rs @@ -21,75 +21,75 @@ //! TCP transport configuration. use crate::{ - crypto::noise::{MAX_READ_AHEAD_FACTOR, MAX_WRITE_BUFFER_SIZE}, - transport::{CONNECTION_OPEN_TIMEOUT, SUBSTREAM_OPEN_TIMEOUT}, + crypto::noise::{MAX_READ_AHEAD_FACTOR, MAX_WRITE_BUFFER_SIZE}, + transport::{CONNECTION_OPEN_TIMEOUT, SUBSTREAM_OPEN_TIMEOUT}, }; /// TCP transport configuration. #[derive(Debug, Clone)] pub struct Config { - /// Listen address for the transport. - /// - /// Default listen addresses are ["/ip4/0.0.0.0/tcp/0", "/ip6/::/tcp/0"]. - pub listen_addresses: Vec, + /// Listen address for the transport. + /// + /// Default listen addresses are ["/ip4/0.0.0.0/tcp/0", "/ip6/::/tcp/0"]. + pub listen_addresses: Vec, - /// Whether to set `SO_REUSEPORT` and bind a socket to the listen address port for outbound - /// connections. - /// - /// Note that `SO_REUSEADDR` is always set on listening sockets. - /// - /// Defaults to `true`. - pub reuse_port: bool, + /// Whether to set `SO_REUSEPORT` and bind a socket to the listen address port for outbound + /// connections. + /// + /// Note that `SO_REUSEADDR` is always set on listening sockets. + /// + /// Defaults to `true`. + pub reuse_port: bool, - /// Yamux configuration. - pub yamux_config: crate::yamux::Config, + /// Yamux configuration. + pub yamux_config: crate::yamux::Config, - /// Noise read-ahead frame count. - /// - /// Specifies how many Noise frames are read per call to the underlying socket. - /// - /// By default this is configured to `5` so each call to the underlying socket can read up - /// to `5` Noise frame per call. Fewer frames may be read if there isn't enough data in the - /// socket. Each Noise frame is `65 KB` so the default setting allocates `65 KB * 5 = 325 KB` - /// per connection. - pub noise_read_ahead_frame_count: usize, + /// Noise read-ahead frame count. + /// + /// Specifies how many Noise frames are read per call to the underlying socket. + /// + /// By default this is configured to `5` so each call to the underlying socket can read up + /// to `5` Noise frame per call. Fewer frames may be read if there isn't enough data in the + /// socket. Each Noise frame is `65 KB` so the default setting allocates `65 KB * 5 = 325 KB` + /// per connection. + pub noise_read_ahead_frame_count: usize, - /// Noise write buffer size. - /// - /// Specifes how many Noise frames are tried to be coalesced into a single system call. - /// By default the value is set to `2` which means that the `NoiseSocket` will allocate - /// `130 KB` for each outgoing connection. - /// - /// The write buffer size is separate from the read-ahead frame count so by default - /// the Noise code will allocate `2 * 65 KB + 5 * 65 KB = 455 KB` per connection. - pub noise_write_buffer_size: usize, + /// Noise write buffer size. + /// + /// Specifes how many Noise frames are tried to be coalesced into a single system call. + /// By default the value is set to `2` which means that the `NoiseSocket` will allocate + /// `130 KB` for each outgoing connection. + /// + /// The write buffer size is separate from the read-ahead frame count so by default + /// the Noise code will allocate `2 * 65 KB + 5 * 65 KB = 455 KB` per connection. + pub noise_write_buffer_size: usize, - /// Connection open timeout. - /// - /// How long should litep2p wait for a connection to be opend before the host - /// is deemed unreachable. - pub connection_open_timeout: std::time::Duration, + /// Connection open timeout. + /// + /// How long should litep2p wait for a connection to be opend before the host + /// is deemed unreachable. + pub connection_open_timeout: std::time::Duration, - /// Substream open timeout. - /// - /// How long should litep2p wait for a substream to be opened before considering - /// the substream rejected. - pub substream_open_timeout: std::time::Duration, + /// Substream open timeout. + /// + /// How long should litep2p wait for a substream to be opened before considering + /// the substream rejected. + pub substream_open_timeout: std::time::Duration, } impl Default for Config { - fn default() -> Self { - Self { - listen_addresses: vec![ - "/ip4/0.0.0.0/tcp/0".parse().expect("valid address"), - "/ip6/::/tcp/0".parse().expect("valid address"), - ], - reuse_port: true, - yamux_config: Default::default(), - noise_read_ahead_frame_count: MAX_READ_AHEAD_FACTOR, - noise_write_buffer_size: MAX_WRITE_BUFFER_SIZE, - connection_open_timeout: CONNECTION_OPEN_TIMEOUT, - substream_open_timeout: SUBSTREAM_OPEN_TIMEOUT, - } - } + fn default() -> Self { + Self { + listen_addresses: vec![ + "/ip4/0.0.0.0/tcp/0".parse().expect("valid address"), + "/ip6/::/tcp/0".parse().expect("valid address"), + ], + reuse_port: true, + yamux_config: Default::default(), + noise_read_ahead_frame_count: MAX_READ_AHEAD_FACTOR, + noise_write_buffer_size: MAX_WRITE_BUFFER_SIZE, + connection_open_timeout: CONNECTION_OPEN_TIMEOUT, + substream_open_timeout: SUBSTREAM_OPEN_TIMEOUT, + } + } } diff --git a/src/transport/tcp/connection.rs b/src/transport/tcp/connection.rs index b0d9f247..50fc0550 100644 --- a/src/transport/tcp/connection.rs +++ b/src/transport/tcp/connection.rs @@ -19,43 +19,43 @@ // DEALINGS IN THE SOFTWARE. use crate::{ - config::Role, - crypto::{ - ed25519::Keypair, - noise::{self, NoiseSocket}, - }, - error::{Error, NegotiationError}, - multistream_select::{dialer_select_proto, listener_select_proto, Negotiated, Version}, - protocol::{Direction, Permit, ProtocolCommand, ProtocolSet}, - substream, - transport::{ - tcp::{listener::AddressType, substream::Substream}, - Endpoint, - }, - types::{protocol::ProtocolName, ConnectionId, SubstreamId}, - BandwidthSink, PeerId, + config::Role, + crypto::{ + ed25519::Keypair, + noise::{self, NoiseSocket}, + }, + error::{Error, NegotiationError}, + multistream_select::{dialer_select_proto, listener_select_proto, Negotiated, Version}, + protocol::{Direction, Permit, ProtocolCommand, ProtocolSet}, + substream, + transport::{ + tcp::{listener::AddressType, substream::Substream}, + Endpoint, + }, + types::{protocol::ProtocolName, ConnectionId, SubstreamId}, + BandwidthSink, PeerId, }; use futures::{ - future::BoxFuture, - stream::{FuturesUnordered, StreamExt}, - AsyncRead, AsyncWrite, + future::BoxFuture, + stream::{FuturesUnordered, StreamExt}, + AsyncRead, AsyncWrite, }; use multiaddr::{Multiaddr, Protocol}; use tokio::net::TcpStream; use tokio_util::compat::{ - Compat, FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt, + Compat, FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt, }; use std::{ - borrow::Cow, - fmt, - net::SocketAddr, - sync::{ - atomic::{AtomicUsize, Ordering}, - Arc, - }, - time::Duration, + borrow::Cow, + fmt, + net::SocketAddr, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }, + time::Duration, }; /// Logging target for the file. @@ -63,1176 +63,1188 @@ const LOG_TARGET: &str = "litep2p::tcp::connection"; #[derive(Debug)] pub struct NegotiatedSubstream { - /// Substream direction. - direction: Direction, + /// Substream direction. + direction: Direction, - /// Substream ID. - substream_id: SubstreamId, + /// Substream ID. + substream_id: SubstreamId, - /// Protocol name. - protocol: ProtocolName, + /// Protocol name. + protocol: ProtocolName, - /// Yamux substream. - io: crate::yamux::Stream, + /// Yamux substream. + io: crate::yamux::Stream, - /// Permit. - permit: Permit, + /// Permit. + permit: Permit, } /// TCP connection error. #[derive(Debug)] enum ConnectionError { - /// Timeout - Timeout { - /// Protocol. - protocol: Option, - - /// Substream ID. - substream_id: Option, - }, - - /// Failed to negotiate connection/substream. - FailedToNegotiate { - /// Protocol. - protocol: Option, - - /// Substream ID. - substream_id: Option, - - /// Error. - error: Error, - }, + /// Timeout + Timeout { + /// Protocol. + protocol: Option, + + /// Substream ID. + substream_id: Option, + }, + + /// Failed to negotiate connection/substream. + FailedToNegotiate { + /// Protocol. + protocol: Option, + + /// Substream ID. + substream_id: Option, + + /// Error. + error: Error, + }, } /// Connection context for an opened connection that hasn't yet started its event loop. pub struct NegotiatedConnection { - /// Yamux connection. - connection: crate::yamux::ControlledConnection>>, + /// Yamux connection. + connection: crate::yamux::ControlledConnection>>, - /// Yamux control. - control: crate::yamux::Control, + /// Yamux control. + control: crate::yamux::Control, - /// Remote peer ID. - peer: PeerId, + /// Remote peer ID. + peer: PeerId, - /// Endpoint. - endpoint: Endpoint, + /// Endpoint. + endpoint: Endpoint, - /// Substream open timeout. - substream_open_timeout: Duration, + /// Substream open timeout. + substream_open_timeout: Duration, } impl NegotiatedConnection { - /// Get `ConnectionId` of the negotiated connection. - pub fn connection_id(&self) -> ConnectionId { - self.endpoint.connection_id() - } - - /// Get `PeerId` of the negotiated connection. - pub fn peer(&self) -> PeerId { - self.peer - } - - /// Get `Endpoint` of the negotiated connection. - pub fn endpoint(&self) -> Endpoint { - self.endpoint.clone() - } + /// Get `ConnectionId` of the negotiated connection. + pub fn connection_id(&self) -> ConnectionId { + self.endpoint.connection_id() + } + + /// Get `PeerId` of the negotiated connection. + pub fn peer(&self) -> PeerId { + self.peer + } + + /// Get `Endpoint` of the negotiated connection. + pub fn endpoint(&self) -> Endpoint { + self.endpoint.clone() + } } /// TCP connection. pub struct TcpConnection { - /// Protocol context. - protocol_set: ProtocolSet, + /// Protocol context. + protocol_set: ProtocolSet, - /// Yamux connection. - connection: crate::yamux::ControlledConnection>>, + /// Yamux connection. + connection: crate::yamux::ControlledConnection>>, - /// Yamux control. - control: crate::yamux::Control, + /// Yamux control. + control: crate::yamux::Control, - /// Remote peer ID. - peer: PeerId, + /// Remote peer ID. + peer: PeerId, - /// Endpoint. - endpoint: Endpoint, + /// Endpoint. + endpoint: Endpoint, - /// Substream open timeout. - substream_open_timeout: Duration, + /// Substream open timeout. + substream_open_timeout: Duration, - /// Next substream ID. - next_substream_id: Arc, + /// Next substream ID. + next_substream_id: Arc, - // Bandwidth sink. - bandwidth_sink: BandwidthSink, + // Bandwidth sink. + bandwidth_sink: BandwidthSink, - /// Pending substreams. - pending_substreams: - FuturesUnordered>>, + /// Pending substreams. + pending_substreams: + FuturesUnordered>>, } impl fmt::Debug for TcpConnection { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("TcpConnection") - .field("peer", &self.peer) - .field("next_substream_id", &self.next_substream_id) - .finish() - } + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("TcpConnection") + .field("peer", &self.peer) + .field("next_substream_id", &self.next_substream_id) + .finish() + } } impl TcpConnection { - /// Create new [`TcpConnection`] from [`NegotiatedConnection`]. - pub(super) fn new( - context: NegotiatedConnection, - protocol_set: ProtocolSet, - bandwidth_sink: BandwidthSink, - next_substream_id: Arc, - ) -> Self { - let NegotiatedConnection { connection, control, peer, endpoint, substream_open_timeout } = - context; - - Self { - protocol_set, - connection, - control, - peer, - endpoint, - bandwidth_sink, - next_substream_id, - pending_substreams: FuturesUnordered::new(), - substream_open_timeout, - } - } - - /// Open connection to remote peer at `address`. - // TODO: this function can be removed - pub(super) async fn open_connection( - connection_id: ConnectionId, - keypair: Keypair, - stream: TcpStream, - address: AddressType, - peer: Option, - yamux_config: crate::yamux::Config, - max_read_ahead_factor: usize, - max_write_buffer_size: usize, - connection_open_timeout: Duration, - substream_open_timeout: Duration, - ) -> crate::Result { - tracing::debug!( - target: LOG_TARGET, - ?address, - ?peer, - "open connection to remote peer", - ); - - match tokio::time::timeout(connection_open_timeout, async move { - Self::negotiate_connection( - stream, - peer, - connection_id, - keypair, - Role::Dialer, - address, - yamux_config, - max_read_ahead_factor, - max_write_buffer_size, - substream_open_timeout, - ) - .await - }) - .await - { - Err(_) => Err(Error::Timeout), - Ok(result) => result, - } - } - - /// Open substream for `protocol`. - pub(super) async fn open_substream( - mut control: crate::yamux::Control, - substream_id: SubstreamId, - permit: Permit, - protocol: ProtocolName, - fallback_names: Vec, - open_timeout: Duration, - ) -> crate::Result { - tracing::debug!(target: LOG_TARGET, ?protocol, ?substream_id, "open substream"); - - let stream = match control.open_stream().await { - Ok(stream) => { - tracing::trace!(target: LOG_TARGET, ?substream_id, "substream opened"); - stream - }, - Err(error) => { - tracing::debug!( - target: LOG_TARGET, - ?substream_id, - ?error, - "failed to open substream" - ); - return Err(Error::YamuxError(Direction::Outbound(substream_id), error)); - }, - }; - - // TODO: protocols don't change after they've been initialized so this should be done only - // once - let protocols = std::iter::once(&*protocol) - .chain(fallback_names.iter().map(|protocol| &**protocol)) - .collect(); - - let (io, protocol) = - Self::negotiate_protocol(stream, &Role::Dialer, protocols, open_timeout).await?; - - Ok(NegotiatedSubstream { - io: io.inner(), - substream_id, - direction: Direction::Outbound(substream_id), - protocol, - permit, - }) - } - - /// Accept a new connection. - pub(super) async fn accept_connection( - stream: TcpStream, - connection_id: ConnectionId, - keypair: Keypair, - address: SocketAddr, - yamux_config: crate::yamux::Config, - max_read_ahead_factor: usize, - max_write_buffer_size: usize, - connection_open_timeout: Duration, - substream_open_timeout: Duration, - ) -> crate::Result { - tracing::debug!(target: LOG_TARGET, ?address, "accept connection"); - - match tokio::time::timeout(connection_open_timeout, async move { - Self::negotiate_connection( - stream, - None, - connection_id, - keypair, - Role::Listener, - AddressType::Socket(address), - yamux_config, - max_read_ahead_factor, - max_write_buffer_size, - substream_open_timeout, - ) - .await - }) - .await - { - Err(_) => return Err(Error::Timeout), - Ok(result) => result, - } - } - - /// Accept substream. - pub(super) async fn accept_substream( - stream: crate::yamux::Stream, - permit: Permit, - substream_id: SubstreamId, - protocols: Vec, - open_timeout: Duration, - ) -> crate::Result { - tracing::trace!( - target: LOG_TARGET, - ?substream_id, - "accept inbound substream", - ); - - let protocols = protocols.iter().map(|protocol| &**protocol).collect::>(); - let (io, protocol) = - Self::negotiate_protocol(stream, &Role::Listener, protocols, open_timeout).await?; - - tracing::trace!( - target: LOG_TARGET, - ?substream_id, - "substream accepted and negotiated", - ); - - Ok(NegotiatedSubstream { - io: io.inner(), - substream_id, - direction: Direction::Inbound, - protocol, - permit, - }) - } - - /// Negotiate protocol. - async fn negotiate_protocol( - stream: S, - role: &Role, - protocols: Vec<&str>, - substream_open_timeout: Duration, - ) -> crate::Result<(Negotiated, ProtocolName)> { - tracing::trace!(target: LOG_TARGET, ?protocols, "negotiating protocols"); - - match tokio::time::timeout(substream_open_timeout, async move { - match role { - Role::Dialer => dialer_select_proto(stream, protocols, Version::V1).await, - Role::Listener => listener_select_proto(stream, protocols).await, - } - }) - .await - { - Err(_) => Err(Error::Timeout), - Ok(Err(error)) => - Err(Error::NegotiationError(NegotiationError::MultistreamSelectError(error))), - Ok(Ok((protocol, socket))) => { - tracing::trace!(target: LOG_TARGET, ?protocol, "protocol negotiated"); - - Ok((socket, ProtocolName::from(protocol.to_string()))) - }, - } - } - - /// Negotiate noise + yamux for the connection. - pub(super) async fn negotiate_connection( - stream: TcpStream, - dialed_peer: Option, - connection_id: ConnectionId, - keypair: Keypair, - role: Role, - address: AddressType, - yamux_config: crate::yamux::Config, - max_read_ahead_factor: usize, - max_write_buffer_size: usize, - substream_open_timeout: Duration, - ) -> crate::Result { - tracing::trace!( - target: LOG_TARGET, - ?role, - "negotiate connection", - ); - - let stream = TokioAsyncReadCompatExt::compat(stream).into_inner(); - let stream = TokioAsyncWriteCompatExt::compat_write(stream); - - // negotiate `noise` - let (stream, _) = - Self::negotiate_protocol(stream, &role, vec!["/noise"], substream_open_timeout).await?; - - tracing::trace!( - target: LOG_TARGET, - "`multistream-select` and `noise` negotiated", - ); - - // perform noise handshake - let (stream, peer) = noise::handshake( - stream.inner(), - &keypair, - role, - max_read_ahead_factor, - max_write_buffer_size, - ) - .await?; - - if let Some(dialed_peer) = dialed_peer { - if dialed_peer != peer { - tracing::debug!(target: LOG_TARGET, ?dialed_peer, ?peer, "peer id mismatch"); - return Err(Error::PeerIdMismatch(dialed_peer, peer)); - } - } - - tracing::trace!(target: LOG_TARGET, "noise handshake done"); - let stream: NoiseSocket> = stream; - - // negotiate `yamux` - let (stream, _) = - Self::negotiate_protocol(stream, &role, vec!["/yamux/1.0.0"], substream_open_timeout) - .await?; - tracing::trace!(target: LOG_TARGET, "`yamux` negotiated"); - - let connection = crate::yamux::Connection::new(stream.inner(), yamux_config, role.into()); - let (control, connection) = crate::yamux::Control::new(connection); - - let address = match address { - AddressType::Socket(address) => Multiaddr::empty() - .with(Protocol::from(address.ip())) - .with(Protocol::Tcp(address.port())), - AddressType::Dns(address, port) => Multiaddr::empty() - .with(Protocol::Dns(Cow::Owned(address))) - .with(Protocol::Tcp(port)), - }; - let endpoint = match role { - Role::Dialer => Endpoint::dialer(address, connection_id), - Role::Listener => Endpoint::listener(address, connection_id), - }; - - Ok(NegotiatedConnection { peer, control, connection, endpoint, substream_open_timeout }) - } - - /// Start connection event loop. - pub(crate) async fn start(mut self) -> crate::Result<()> { - self.protocol_set - .report_connection_established(self.peer, self.endpoint.clone()) - .await?; - - loop { - tokio::select! { - substream = self.connection.next() => match substream { - Some(Ok(stream)) => { - let substream_id = { - let substream_id = self.next_substream_id.fetch_add(1usize, Ordering::Relaxed); - SubstreamId::from(substream_id) - }; - let protocols = self.protocol_set.protocols(); - let permit = self.protocol_set.try_get_permit().ok_or(Error::ConnectionClosed)?; - let open_timeout = self.substream_open_timeout; - - self.pending_substreams.push(Box::pin(async move { - match tokio::time::timeout( - open_timeout, - Self::accept_substream(stream, permit, substream_id, protocols, open_timeout), - ) - .await - { - Ok(Ok(substream)) => Ok(substream), - Ok(Err(error)) => Err(ConnectionError::FailedToNegotiate { - protocol: None, - substream_id: None, - error, - }), - Err(_) => Err(ConnectionError::Timeout { - protocol: None, - substream_id: None - }), - } - })); - }, - Some(Err(error)) => { - tracing::debug!( - target: LOG_TARGET, - peer = ?self.peer, - ?error, - "connection closed with error", - ); - self.protocol_set.report_connection_closed(self.peer, self.endpoint.connection_id()).await?; - - return Ok(()) - } - None => { - tracing::debug!(target: LOG_TARGET, peer = ?self.peer, "connection closed"); - self.protocol_set.report_connection_closed(self.peer, self.endpoint.connection_id()).await?; - - return Ok(()) - } - }, - // TODO: move this to a function - substream = self.pending_substreams.select_next_some(), if !self.pending_substreams.is_empty() => { - match substream { - // TODO: return error to protocol - Err(error) => { - tracing::debug!( - target: LOG_TARGET, - ?error, - "failed to accept/open substream", - ); - - let (protocol, substream_id, error) = match error { - ConnectionError::Timeout { protocol, substream_id } => { - (protocol, substream_id, Error::Timeout) - } - ConnectionError::FailedToNegotiate { protocol, substream_id, error } => { - (protocol, substream_id, error) - } - }; - - match (protocol, substream_id) { - (Some(protocol), Some(substream_id)) => { - if let Err(error) = self.protocol_set - .report_substream_open_failure(protocol, substream_id, error) - .await - { - tracing::error!( - target: LOG_TARGET, - ?error, - "failed to register opened substream to protocol" - ); - } - } - _ => {} - } - } - Ok(substream) => { - let protocol = substream.protocol.clone(); - let direction = substream.direction; - let substream_id = substream.substream_id; - let socket = FuturesAsyncReadCompatExt::compat(substream.io); - let bandwidth_sink = self.bandwidth_sink.clone(); - - let substream = substream::Substream::new_tcp( - self.peer, - substream_id, - Substream::new(socket, bandwidth_sink, substream.permit), - self.protocol_set.protocol_codec(&protocol) - ); - - if let Err(error) = self.protocol_set - .report_substream_open(self.peer, protocol, direction, substream) - .await - { - tracing::error!( - target: LOG_TARGET, - ?error, - "failed to register opened substream to protocol", - ); - } - } - } - } - protocol = self.protocol_set.next() => match protocol { - Some(ProtocolCommand::OpenSubstream { protocol, fallback_names, substream_id, permit }) => { - let control = self.control.clone(); - let open_timeout = self.substream_open_timeout; - - tracing::trace!( - target: LOG_TARGET, - ?protocol, - ?substream_id, - "open substream", - ); - - self.pending_substreams.push(Box::pin(async move { - match tokio::time::timeout( - open_timeout, - Self::open_substream( - control, - substream_id, - permit, - protocol.clone(), - fallback_names, - open_timeout, - ), - ) - .await - { - Ok(Ok(substream)) => Ok(substream), - Ok(Err(error)) => Err(ConnectionError::FailedToNegotiate { - protocol: Some(protocol), - substream_id: Some(substream_id), - error, - }), - Err(_) => Err(ConnectionError::Timeout { - protocol: Some(protocol), - substream_id: Some(substream_id) - }), - } - })); - } - Some(ProtocolCommand::ForceClose) => { - tracing::debug!( - target: LOG_TARGET, - peer = ?self.peer, - connection_id = ?self.endpoint.connection_id(), - "force closing connection", - ); - - return self.protocol_set.report_connection_closed(self.peer, self.endpoint.connection_id()).await - } - None => { - tracing::debug!(target: LOG_TARGET, "protocols have disconnected, closing connection"); - return self.protocol_set.report_connection_closed(self.peer, self.endpoint.connection_id()).await - } - } - } - } - } + /// Create new [`TcpConnection`] from [`NegotiatedConnection`]. + pub(super) fn new( + context: NegotiatedConnection, + protocol_set: ProtocolSet, + bandwidth_sink: BandwidthSink, + next_substream_id: Arc, + ) -> Self { + let NegotiatedConnection { + connection, + control, + peer, + endpoint, + substream_open_timeout, + } = context; + + Self { + protocol_set, + connection, + control, + peer, + endpoint, + bandwidth_sink, + next_substream_id, + pending_substreams: FuturesUnordered::new(), + substream_open_timeout, + } + } + + /// Open connection to remote peer at `address`. + // TODO: this function can be removed + pub(super) async fn open_connection( + connection_id: ConnectionId, + keypair: Keypair, + stream: TcpStream, + address: AddressType, + peer: Option, + yamux_config: crate::yamux::Config, + max_read_ahead_factor: usize, + max_write_buffer_size: usize, + connection_open_timeout: Duration, + substream_open_timeout: Duration, + ) -> crate::Result { + tracing::debug!( + target: LOG_TARGET, + ?address, + ?peer, + "open connection to remote peer", + ); + + match tokio::time::timeout(connection_open_timeout, async move { + Self::negotiate_connection( + stream, + peer, + connection_id, + keypair, + Role::Dialer, + address, + yamux_config, + max_read_ahead_factor, + max_write_buffer_size, + substream_open_timeout, + ) + .await + }) + .await + { + Err(_) => Err(Error::Timeout), + Ok(result) => result, + } + } + + /// Open substream for `protocol`. + pub(super) async fn open_substream( + mut control: crate::yamux::Control, + substream_id: SubstreamId, + permit: Permit, + protocol: ProtocolName, + fallback_names: Vec, + open_timeout: Duration, + ) -> crate::Result { + tracing::debug!(target: LOG_TARGET, ?protocol, ?substream_id, "open substream"); + + let stream = match control.open_stream().await { + Ok(stream) => { + tracing::trace!(target: LOG_TARGET, ?substream_id, "substream opened"); + stream + } + Err(error) => { + tracing::debug!( + target: LOG_TARGET, + ?substream_id, + ?error, + "failed to open substream" + ); + return Err(Error::YamuxError(Direction::Outbound(substream_id), error)); + } + }; + + // TODO: protocols don't change after they've been initialized so this should be done only + // once + let protocols = std::iter::once(&*protocol) + .chain(fallback_names.iter().map(|protocol| &**protocol)) + .collect(); + + let (io, protocol) = + Self::negotiate_protocol(stream, &Role::Dialer, protocols, open_timeout).await?; + + Ok(NegotiatedSubstream { + io: io.inner(), + substream_id, + direction: Direction::Outbound(substream_id), + protocol, + permit, + }) + } + + /// Accept a new connection. + pub(super) async fn accept_connection( + stream: TcpStream, + connection_id: ConnectionId, + keypair: Keypair, + address: SocketAddr, + yamux_config: crate::yamux::Config, + max_read_ahead_factor: usize, + max_write_buffer_size: usize, + connection_open_timeout: Duration, + substream_open_timeout: Duration, + ) -> crate::Result { + tracing::debug!(target: LOG_TARGET, ?address, "accept connection"); + + match tokio::time::timeout(connection_open_timeout, async move { + Self::negotiate_connection( + stream, + None, + connection_id, + keypair, + Role::Listener, + AddressType::Socket(address), + yamux_config, + max_read_ahead_factor, + max_write_buffer_size, + substream_open_timeout, + ) + .await + }) + .await + { + Err(_) => return Err(Error::Timeout), + Ok(result) => result, + } + } + + /// Accept substream. + pub(super) async fn accept_substream( + stream: crate::yamux::Stream, + permit: Permit, + substream_id: SubstreamId, + protocols: Vec, + open_timeout: Duration, + ) -> crate::Result { + tracing::trace!( + target: LOG_TARGET, + ?substream_id, + "accept inbound substream", + ); + + let protocols = protocols.iter().map(|protocol| &**protocol).collect::>(); + let (io, protocol) = + Self::negotiate_protocol(stream, &Role::Listener, protocols, open_timeout).await?; + + tracing::trace!( + target: LOG_TARGET, + ?substream_id, + "substream accepted and negotiated", + ); + + Ok(NegotiatedSubstream { + io: io.inner(), + substream_id, + direction: Direction::Inbound, + protocol, + permit, + }) + } + + /// Negotiate protocol. + async fn negotiate_protocol( + stream: S, + role: &Role, + protocols: Vec<&str>, + substream_open_timeout: Duration, + ) -> crate::Result<(Negotiated, ProtocolName)> { + tracing::trace!(target: LOG_TARGET, ?protocols, "negotiating protocols"); + + match tokio::time::timeout(substream_open_timeout, async move { + match role { + Role::Dialer => dialer_select_proto(stream, protocols, Version::V1).await, + Role::Listener => listener_select_proto(stream, protocols).await, + } + }) + .await + { + Err(_) => Err(Error::Timeout), + Ok(Err(error)) => Err(Error::NegotiationError( + NegotiationError::MultistreamSelectError(error), + )), + Ok(Ok((protocol, socket))) => { + tracing::trace!(target: LOG_TARGET, ?protocol, "protocol negotiated"); + + Ok((socket, ProtocolName::from(protocol.to_string()))) + } + } + } + + /// Negotiate noise + yamux for the connection. + pub(super) async fn negotiate_connection( + stream: TcpStream, + dialed_peer: Option, + connection_id: ConnectionId, + keypair: Keypair, + role: Role, + address: AddressType, + yamux_config: crate::yamux::Config, + max_read_ahead_factor: usize, + max_write_buffer_size: usize, + substream_open_timeout: Duration, + ) -> crate::Result { + tracing::trace!( + target: LOG_TARGET, + ?role, + "negotiate connection", + ); + + let stream = TokioAsyncReadCompatExt::compat(stream).into_inner(); + let stream = TokioAsyncWriteCompatExt::compat_write(stream); + + // negotiate `noise` + let (stream, _) = + Self::negotiate_protocol(stream, &role, vec!["/noise"], substream_open_timeout).await?; + + tracing::trace!( + target: LOG_TARGET, + "`multistream-select` and `noise` negotiated", + ); + + // perform noise handshake + let (stream, peer) = noise::handshake( + stream.inner(), + &keypair, + role, + max_read_ahead_factor, + max_write_buffer_size, + ) + .await?; + + if let Some(dialed_peer) = dialed_peer { + if dialed_peer != peer { + tracing::debug!(target: LOG_TARGET, ?dialed_peer, ?peer, "peer id mismatch"); + return Err(Error::PeerIdMismatch(dialed_peer, peer)); + } + } + + tracing::trace!(target: LOG_TARGET, "noise handshake done"); + let stream: NoiseSocket> = stream; + + // negotiate `yamux` + let (stream, _) = + Self::negotiate_protocol(stream, &role, vec!["/yamux/1.0.0"], substream_open_timeout) + .await?; + tracing::trace!(target: LOG_TARGET, "`yamux` negotiated"); + + let connection = crate::yamux::Connection::new(stream.inner(), yamux_config, role.into()); + let (control, connection) = crate::yamux::Control::new(connection); + + let address = match address { + AddressType::Socket(address) => Multiaddr::empty() + .with(Protocol::from(address.ip())) + .with(Protocol::Tcp(address.port())), + AddressType::Dns(address, port) => Multiaddr::empty() + .with(Protocol::Dns(Cow::Owned(address))) + .with(Protocol::Tcp(port)), + }; + let endpoint = match role { + Role::Dialer => Endpoint::dialer(address, connection_id), + Role::Listener => Endpoint::listener(address, connection_id), + }; + + Ok(NegotiatedConnection { + peer, + control, + connection, + endpoint, + substream_open_timeout, + }) + } + + /// Start connection event loop. + pub(crate) async fn start(mut self) -> crate::Result<()> { + self.protocol_set + .report_connection_established(self.peer, self.endpoint.clone()) + .await?; + + loop { + tokio::select! { + substream = self.connection.next() => match substream { + Some(Ok(stream)) => { + let substream_id = { + let substream_id = self.next_substream_id.fetch_add(1usize, Ordering::Relaxed); + SubstreamId::from(substream_id) + }; + let protocols = self.protocol_set.protocols(); + let permit = self.protocol_set.try_get_permit().ok_or(Error::ConnectionClosed)?; + let open_timeout = self.substream_open_timeout; + + self.pending_substreams.push(Box::pin(async move { + match tokio::time::timeout( + open_timeout, + Self::accept_substream(stream, permit, substream_id, protocols, open_timeout), + ) + .await + { + Ok(Ok(substream)) => Ok(substream), + Ok(Err(error)) => Err(ConnectionError::FailedToNegotiate { + protocol: None, + substream_id: None, + error, + }), + Err(_) => Err(ConnectionError::Timeout { + protocol: None, + substream_id: None + }), + } + })); + }, + Some(Err(error)) => { + tracing::debug!( + target: LOG_TARGET, + peer = ?self.peer, + ?error, + "connection closed with error", + ); + self.protocol_set.report_connection_closed(self.peer, self.endpoint.connection_id()).await?; + + return Ok(()) + } + None => { + tracing::debug!(target: LOG_TARGET, peer = ?self.peer, "connection closed"); + self.protocol_set.report_connection_closed(self.peer, self.endpoint.connection_id()).await?; + + return Ok(()) + } + }, + // TODO: move this to a function + substream = self.pending_substreams.select_next_some(), if !self.pending_substreams.is_empty() => { + match substream { + // TODO: return error to protocol + Err(error) => { + tracing::debug!( + target: LOG_TARGET, + ?error, + "failed to accept/open substream", + ); + + let (protocol, substream_id, error) = match error { + ConnectionError::Timeout { protocol, substream_id } => { + (protocol, substream_id, Error::Timeout) + } + ConnectionError::FailedToNegotiate { protocol, substream_id, error } => { + (protocol, substream_id, error) + } + }; + + match (protocol, substream_id) { + (Some(protocol), Some(substream_id)) => { + if let Err(error) = self.protocol_set + .report_substream_open_failure(protocol, substream_id, error) + .await + { + tracing::error!( + target: LOG_TARGET, + ?error, + "failed to register opened substream to protocol" + ); + } + } + _ => {} + } + } + Ok(substream) => { + let protocol = substream.protocol.clone(); + let direction = substream.direction; + let substream_id = substream.substream_id; + let socket = FuturesAsyncReadCompatExt::compat(substream.io); + let bandwidth_sink = self.bandwidth_sink.clone(); + + let substream = substream::Substream::new_tcp( + self.peer, + substream_id, + Substream::new(socket, bandwidth_sink, substream.permit), + self.protocol_set.protocol_codec(&protocol) + ); + + if let Err(error) = self.protocol_set + .report_substream_open(self.peer, protocol, direction, substream) + .await + { + tracing::error!( + target: LOG_TARGET, + ?error, + "failed to register opened substream to protocol", + ); + } + } + } + } + protocol = self.protocol_set.next() => match protocol { + Some(ProtocolCommand::OpenSubstream { protocol, fallback_names, substream_id, permit }) => { + let control = self.control.clone(); + let open_timeout = self.substream_open_timeout; + + tracing::trace!( + target: LOG_TARGET, + ?protocol, + ?substream_id, + "open substream", + ); + + self.pending_substreams.push(Box::pin(async move { + match tokio::time::timeout( + open_timeout, + Self::open_substream( + control, + substream_id, + permit, + protocol.clone(), + fallback_names, + open_timeout, + ), + ) + .await + { + Ok(Ok(substream)) => Ok(substream), + Ok(Err(error)) => Err(ConnectionError::FailedToNegotiate { + protocol: Some(protocol), + substream_id: Some(substream_id), + error, + }), + Err(_) => Err(ConnectionError::Timeout { + protocol: Some(protocol), + substream_id: Some(substream_id) + }), + } + })); + } + Some(ProtocolCommand::ForceClose) => { + tracing::debug!( + target: LOG_TARGET, + peer = ?self.peer, + connection_id = ?self.endpoint.connection_id(), + "force closing connection", + ); + + return self.protocol_set.report_connection_closed(self.peer, self.endpoint.connection_id()).await + } + None => { + tracing::debug!(target: LOG_TARGET, "protocols have disconnected, closing connection"); + return self.protocol_set.report_connection_closed(self.peer, self.endpoint.connection_id()).await + } + } + } + } + } } #[cfg(test)] mod tests { - use crate::transport::tcp::TcpTransport; - - use super::*; - use tokio::{io::AsyncWriteExt, net::TcpListener}; - - #[tokio::test] - async fn multistream_select_not_supported_dialer() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let listener = TcpListener::bind("[::1]:0").await.unwrap(); - let address = listener.local_addr().unwrap(); - - tokio::spawn(async move { - let (mut stream, _) = listener.accept().await.unwrap(); - let _ = stream.write_all(&vec![0x12u8; 256]).await; - }); - - let (_, stream) = TcpTransport::dial_peer( - Multiaddr::empty() - .with(Protocol::from(address.ip())) - .with(Protocol::Tcp(address.port())), - Default::default(), - Duration::from_secs(10), - ) - .await - .unwrap(); - - match TcpConnection::open_connection( - ConnectionId::from(0usize), - Keypair::generate(), - stream, - AddressType::Socket(address), - None, - Default::default(), - 5, - 2, - Duration::from_secs(10), - Duration::from_secs(10), - ) - .await - { - Ok(_) => panic!("connection was supposed to fail"), - Err(Error::NegotiationError(NegotiationError::MultistreamSelectError( - crate::multistream_select::NegotiationError::ProtocolError( - crate::multistream_select::ProtocolError::InvalidMessage, - ), - ))) => {}, - Err(error) => panic!("invalid error: {error:?}"), - } - } - - #[tokio::test] - async fn multistream_select_not_supported_listener() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let listener = TcpListener::bind("[::1]:0").await.unwrap(); - let address = listener.local_addr().unwrap(); - - let (Ok(mut dialer), Ok((stream, dialer_address))) = - tokio::join!(TcpStream::connect(address.clone()), listener.accept(),) - else { - panic!("failed to establish connection"); - }; - - tokio::spawn(async move { - let _ = dialer.write_all(&vec![0x12u8; 256]).await; - }); - - match TcpConnection::accept_connection( - stream, - ConnectionId::from(0usize), - Keypair::generate(), - dialer_address, - Default::default(), - 5, - 2, - Duration::from_secs(10), - Duration::from_secs(10), - ) - .await - { - Ok(_) => panic!("connection was supposed to fail"), - Err(Error::NegotiationError(NegotiationError::MultistreamSelectError( - crate::multistream_select::NegotiationError::ProtocolError( - crate::multistream_select::ProtocolError::InvalidMessage, - ), - ))) => {}, - Err(error) => panic!("invalid error: {error:?}"), - } - } - - #[tokio::test] - async fn noise_not_supported_dialer() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let listener = TcpListener::bind("[::1]:0").await.unwrap(); - let address = listener.local_addr().unwrap(); - - tokio::spawn(async move { - let (stream, _) = listener.accept().await.unwrap(); - let stream = TokioAsyncReadCompatExt::compat(stream).into_inner(); - let stream = TokioAsyncWriteCompatExt::compat_write(stream); - - // attempt to negotiate yamux, skipping noise entirely - assert!(listener_select_proto(stream, vec!["/yamux/1.0.0"]).await.is_err()); - }); - - let (_, stream) = TcpTransport::dial_peer( - Multiaddr::empty() - .with(Protocol::from(address.ip())) - .with(Protocol::Tcp(address.port())), - Default::default(), - Duration::from_secs(10), - ) - .await - .unwrap(); - - match TcpConnection::open_connection( - ConnectionId::from(0usize), - Keypair::generate(), - stream, - AddressType::Socket(address), - None, - Default::default(), - 5, - 2, - Duration::from_secs(10), - Duration::from_secs(10), - ) - .await - { - Ok(_) => panic!("connection was supposed to fail"), - Err(Error::NegotiationError(NegotiationError::MultistreamSelectError( - crate::multistream_select::NegotiationError::Failed, - ))) => {}, - Err(error) => panic!("invalid error: {error:?}"), - } - } - - #[tokio::test] - async fn noise_not_supported_listener() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let listener = TcpListener::bind("[::1]:0").await.unwrap(); - let address = listener.local_addr().unwrap(); - - let (Ok(dialer), Ok((listener, dialer_address))) = - tokio::join!(TcpStream::connect(address.clone()), listener.accept(),) - else { - panic!("failed to establish connection"); - }; - - tokio::spawn(async move { - let dialer = TokioAsyncReadCompatExt::compat(dialer).into_inner(); - let dialer = TokioAsyncWriteCompatExt::compat_write(dialer); - - // attempt to negotiate yamux, skipping noise entirely - assert!(dialer_select_proto(dialer, vec!["/yamux/1.0.0"], Version::V1).await.is_err()); - }); - - match TcpConnection::accept_connection( - listener, - ConnectionId::from(0usize), - Keypair::generate(), - dialer_address, - Default::default(), - 5, - 2, - Duration::from_secs(10), - Duration::from_secs(10), - ) - .await - { - Ok(_) => panic!("connection was supposed to fail"), - Err(Error::NegotiationError(NegotiationError::MultistreamSelectError( - crate::multistream_select::NegotiationError::Failed, - ))) => {}, - Err(error) => panic!("invalid error: {error:?}"), - } - } - - #[tokio::test] - async fn noise_timeout_listener() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let listener = TcpListener::bind("[::1]:0").await.unwrap(); - let address = listener.local_addr().unwrap(); - - let (Ok(dialer), Ok((listener, dialer_address))) = - tokio::join!(TcpStream::connect(address.clone()), listener.accept(),) - else { - panic!("failed to establish connection"); - }; - - tokio::spawn(async move { - let dialer = TokioAsyncReadCompatExt::compat(dialer).into_inner(); - let dialer = TokioAsyncWriteCompatExt::compat_write(dialer); - - // attempt to negotiate yamux, skipping noise entirely - let (_protocol, _socket) = - dialer_select_proto(dialer, vec!["/noise"], Version::V1).await.unwrap(); - - tokio::time::sleep(std::time::Duration::from_secs(60)).await; - }); - - match TcpConnection::accept_connection( - listener, - ConnectionId::from(0usize), - Keypair::generate(), - dialer_address, - Default::default(), - 5, - 2, - Duration::from_secs(10), - Duration::from_secs(10), - ) - .await - { - Ok(_) => panic!("connection was supposed to fail"), - Err(Error::Timeout) => {}, - Err(error) => panic!("invalid error: {error:?}"), - } - } - - #[tokio::test] - async fn noise_timeout_dialer() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let listener = TcpListener::bind("[::1]:0").await.unwrap(); - let address = listener.local_addr().unwrap(); - - tokio::spawn(async move { - let (stream, _) = listener.accept().await.unwrap(); - let stream = TokioAsyncReadCompatExt::compat(stream).into_inner(); - let stream = TokioAsyncWriteCompatExt::compat_write(stream); - - // negotiate noise but never actually send any handshake data - let (_protocol, _socket) = listener_select_proto(stream, vec!["/noise"]).await.unwrap(); - - tokio::time::sleep(std::time::Duration::from_secs(60)).await; - }); - - let (_, stream) = TcpTransport::dial_peer( - Multiaddr::empty() - .with(Protocol::from(address.ip())) - .with(Protocol::Tcp(address.port())), - Default::default(), - Duration::from_secs(10), - ) - .await - .unwrap(); - - match TcpConnection::open_connection( - ConnectionId::from(0usize), - Keypair::generate(), - stream, - AddressType::Socket(address), - None, - Default::default(), - 5, - 2, - Duration::from_secs(10), - Duration::from_secs(10), - ) - .await - { - Ok(_) => panic!("connection was supposed to fail"), - Err(Error::Timeout) => {}, - Err(error) => panic!("invalid error: {error:?}"), - } - } - - #[tokio::test] - async fn multistream_select_timeout_dialer() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let listener = TcpListener::bind("[::1]:0").await.unwrap(); - let address = listener.local_addr().unwrap(); - - tokio::spawn(async move { - let _stream = listener.accept().await.unwrap(); - - tokio::time::sleep(std::time::Duration::from_secs(60)).await; - }); - - let (_, stream) = TcpTransport::dial_peer( - Multiaddr::empty() - .with(Protocol::from(address.ip())) - .with(Protocol::Tcp(address.port())), - Default::default(), - Duration::from_secs(10), - ) - .await - .unwrap(); - - match TcpConnection::open_connection( - ConnectionId::from(0usize), - Keypair::generate(), - stream, - AddressType::Socket(address), - None, - Default::default(), - 5, - 2, - Duration::from_secs(10), - Duration::from_secs(10), - ) - .await - { - Ok(_) => panic!("connection was supposed to fail"), - Err(Error::Timeout) => {}, - Err(error) => panic!("invalid error: {error:?}"), - } - } - - #[tokio::test] - async fn multistream_select_timeout_listener() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let listener = TcpListener::bind("[::1]:0").await.unwrap(); - let address = listener.local_addr().unwrap(); - - let (Ok(_dialer), Ok((listener, dialer_address))) = - tokio::join!(TcpStream::connect(address.clone()), listener.accept(),) - else { - panic!("failed to establish connection"); - }; - - tokio::spawn(async move { - let _stream = TcpStream::connect(address).await.unwrap(); - - tokio::time::sleep(std::time::Duration::from_secs(60)).await; - }); - - match TcpConnection::accept_connection( - listener, - ConnectionId::from(0usize), - Keypair::generate(), - dialer_address, - Default::default(), - 5, - 2, - Duration::from_secs(10), - Duration::from_secs(10), - ) - .await - { - Ok(_) => panic!("connection was supposed to fail"), - Err(Error::Timeout) => {}, - Err(error) => panic!("invalid error: {error:?}"), - } - } - - #[tokio::test] - async fn yamux_not_supported_dialer() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let listener = TcpListener::bind("[::1]:0").await.unwrap(); - let address = listener.local_addr().unwrap(); - - let (Ok(dialer), Ok((listener, dialer_address))) = - tokio::join!(TcpStream::connect(address.clone()), listener.accept(),) - else { - panic!("failed to establish connection"); - }; - - tokio::spawn(async move { - let dialer = TokioAsyncReadCompatExt::compat(dialer).into_inner(); - let dialer = TokioAsyncWriteCompatExt::compat_write(dialer); - - // negotiate noise - let (_protocol, stream) = - dialer_select_proto(dialer, vec!["/noise"], Version::V1).await.unwrap(); - - let keypair = Keypair::generate(); - - // do a noise handshake - let (stream, _peer) = - noise::handshake(stream.inner(), &keypair, Role::Dialer, 5, 2).await.unwrap(); - let stream: NoiseSocket> = stream; - - // after the handshake, try to negotiate some random protocol instead of yamux - assert!(dialer_select_proto(stream, vec!["/unsupported/1"], Version::V1) - .await - .is_err()); - }); - - match TcpConnection::accept_connection( - listener, - ConnectionId::from(0usize), - Keypair::generate(), - dialer_address, - Default::default(), - 5, - 2, - Duration::from_secs(10), - Duration::from_secs(10), - ) - .await - { - Ok(_) => panic!("connection was supposed to fail"), - Err(Error::NegotiationError(NegotiationError::MultistreamSelectError( - crate::multistream_select::NegotiationError::Failed, - ))) => {}, - Err(error) => panic!("{error:?}"), - } - } - - #[tokio::test] - async fn yamux_not_supported_listener() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let listener = TcpListener::bind("[::1]:0").await.unwrap(); - let address = listener.local_addr().unwrap(); - - tokio::spawn(async move { - let (stream, _) = listener.accept().await.unwrap(); - let stream = TokioAsyncReadCompatExt::compat(stream).into_inner(); - let stream = TokioAsyncWriteCompatExt::compat_write(stream); - - // negotiate noise - let (_protocol, stream) = listener_select_proto(stream, vec!["/noise"]).await.unwrap(); - - // do a noise handshake - let keypair = Keypair::generate(); - let (stream, _peer) = - noise::handshake(stream.inner(), &keypair, Role::Listener, 5, 2).await.unwrap(); - let stream: NoiseSocket> = stream; - - // after the handshake, try to negotiate some random protocol instead of yamux - assert!(listener_select_proto(stream, vec!["/unsupported/1"]).await.is_err()); - }); - - let (_, stream) = TcpTransport::dial_peer( - Multiaddr::empty() - .with(Protocol::from(address.ip())) - .with(Protocol::Tcp(address.port())), - Default::default(), - Duration::from_secs(10), - ) - .await - .unwrap(); - - match TcpConnection::open_connection( - ConnectionId::from(0usize), - Keypair::generate(), - stream, - AddressType::Socket(address), - None, - Default::default(), - 5, - 2, - Duration::from_secs(10), - Duration::from_secs(10), - ) - .await - { - Ok(_) => panic!("connection was supposed to fail"), - Err(Error::NegotiationError(NegotiationError::MultistreamSelectError( - crate::multistream_select::NegotiationError::Failed, - ))) => {}, - Err(error) => panic!("{error:?}"), - } - } - - #[tokio::test] - async fn yamux_timeout_dialer() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let listener = TcpListener::bind("[::1]:0").await.unwrap(); - let address = listener.local_addr().unwrap(); - - let (Ok(dialer), Ok((listener, dialer_address))) = - tokio::join!(TcpStream::connect(address.clone()), listener.accept()) - else { - panic!("failed to establish connection"); - }; - - tokio::spawn(async move { - let dialer = TokioAsyncReadCompatExt::compat(dialer).into_inner(); - let dialer = TokioAsyncWriteCompatExt::compat_write(dialer); - - // negotiate noise - let (_protocol, stream) = - dialer_select_proto(dialer, vec!["/noise"], Version::V1).await.unwrap(); - - // do a noise handshake - let keypair = Keypair::generate(); - let (stream, _peer) = - noise::handshake(stream.inner(), &keypair, Role::Dialer, 5, 2).await.unwrap(); - let _stream: NoiseSocket> = stream; - - tokio::time::sleep(std::time::Duration::from_secs(60)).await; - }); - - match TcpConnection::accept_connection( - listener, - ConnectionId::from(0usize), - Keypair::generate(), - dialer_address, - Default::default(), - 5, - 2, - Duration::from_secs(10), - Duration::from_secs(10), - ) - .await - { - Ok(_) => panic!("connection was supposed to fail"), - Err(Error::Timeout) => {}, - Err(error) => panic!("invalid error: {error:?}"), - } - } - - #[tokio::test] - async fn yamux_timeout_listener() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let listener = TcpListener::bind("[::1]:0").await.unwrap(); - let address = listener.local_addr().unwrap(); - - tokio::spawn(async move { - let (stream, _) = listener.accept().await.unwrap(); - let stream = TokioAsyncReadCompatExt::compat(stream).into_inner(); - let stream = TokioAsyncWriteCompatExt::compat_write(stream); - - // negotiate noise - let (_protocol, stream) = listener_select_proto(stream, vec!["/noise"]).await.unwrap(); - - // do a noise handshake - let keypair = Keypair::generate(); - let (stream, _peer) = - noise::handshake(stream.inner(), &keypair, Role::Listener, 5, 2).await.unwrap(); - let _stream: NoiseSocket> = stream; - - tokio::time::sleep(std::time::Duration::from_secs(60)).await; - }); - - let (_, stream) = TcpTransport::dial_peer( - Multiaddr::empty() - .with(Protocol::from(address.ip())) - .with(Protocol::Tcp(address.port())), - Default::default(), - Duration::from_secs(10), - ) - .await - .unwrap(); - - match TcpConnection::open_connection( - ConnectionId::from(0usize), - Keypair::generate(), - stream, - AddressType::Socket(address), - None, - Default::default(), - 5, - 2, - Duration::from_secs(10), - Duration::from_secs(10), - ) - .await - { - Ok(_) => panic!("connection was supposed to fail"), - Err(Error::Timeout) => {}, - Err(error) => panic!("invalid error: {error:?}"), - } - } + use crate::transport::tcp::TcpTransport; + + use super::*; + use tokio::{io::AsyncWriteExt, net::TcpListener}; + + #[tokio::test] + async fn multistream_select_not_supported_dialer() { + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let listener = TcpListener::bind("[::1]:0").await.unwrap(); + let address = listener.local_addr().unwrap(); + + tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let _ = stream.write_all(&vec![0x12u8; 256]).await; + }); + + let (_, stream) = TcpTransport::dial_peer( + Multiaddr::empty() + .with(Protocol::from(address.ip())) + .with(Protocol::Tcp(address.port())), + Default::default(), + Duration::from_secs(10), + ) + .await + .unwrap(); + + match TcpConnection::open_connection( + ConnectionId::from(0usize), + Keypair::generate(), + stream, + AddressType::Socket(address), + None, + Default::default(), + 5, + 2, + Duration::from_secs(10), + Duration::from_secs(10), + ) + .await + { + Ok(_) => panic!("connection was supposed to fail"), + Err(Error::NegotiationError(NegotiationError::MultistreamSelectError( + crate::multistream_select::NegotiationError::ProtocolError( + crate::multistream_select::ProtocolError::InvalidMessage, + ), + ))) => {} + Err(error) => panic!("invalid error: {error:?}"), + } + } + + #[tokio::test] + async fn multistream_select_not_supported_listener() { + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let listener = TcpListener::bind("[::1]:0").await.unwrap(); + let address = listener.local_addr().unwrap(); + + let (Ok(mut dialer), Ok((stream, dialer_address))) = + tokio::join!(TcpStream::connect(address.clone()), listener.accept(),) + else { + panic!("failed to establish connection"); + }; + + tokio::spawn(async move { + let _ = dialer.write_all(&vec![0x12u8; 256]).await; + }); + + match TcpConnection::accept_connection( + stream, + ConnectionId::from(0usize), + Keypair::generate(), + dialer_address, + Default::default(), + 5, + 2, + Duration::from_secs(10), + Duration::from_secs(10), + ) + .await + { + Ok(_) => panic!("connection was supposed to fail"), + Err(Error::NegotiationError(NegotiationError::MultistreamSelectError( + crate::multistream_select::NegotiationError::ProtocolError( + crate::multistream_select::ProtocolError::InvalidMessage, + ), + ))) => {} + Err(error) => panic!("invalid error: {error:?}"), + } + } + + #[tokio::test] + async fn noise_not_supported_dialer() { + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let listener = TcpListener::bind("[::1]:0").await.unwrap(); + let address = listener.local_addr().unwrap(); + + tokio::spawn(async move { + let (stream, _) = listener.accept().await.unwrap(); + let stream = TokioAsyncReadCompatExt::compat(stream).into_inner(); + let stream = TokioAsyncWriteCompatExt::compat_write(stream); + + // attempt to negotiate yamux, skipping noise entirely + assert!(listener_select_proto(stream, vec!["/yamux/1.0.0"]).await.is_err()); + }); + + let (_, stream) = TcpTransport::dial_peer( + Multiaddr::empty() + .with(Protocol::from(address.ip())) + .with(Protocol::Tcp(address.port())), + Default::default(), + Duration::from_secs(10), + ) + .await + .unwrap(); + + match TcpConnection::open_connection( + ConnectionId::from(0usize), + Keypair::generate(), + stream, + AddressType::Socket(address), + None, + Default::default(), + 5, + 2, + Duration::from_secs(10), + Duration::from_secs(10), + ) + .await + { + Ok(_) => panic!("connection was supposed to fail"), + Err(Error::NegotiationError(NegotiationError::MultistreamSelectError( + crate::multistream_select::NegotiationError::Failed, + ))) => {} + Err(error) => panic!("invalid error: {error:?}"), + } + } + + #[tokio::test] + async fn noise_not_supported_listener() { + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let listener = TcpListener::bind("[::1]:0").await.unwrap(); + let address = listener.local_addr().unwrap(); + + let (Ok(dialer), Ok((listener, dialer_address))) = + tokio::join!(TcpStream::connect(address.clone()), listener.accept(),) + else { + panic!("failed to establish connection"); + }; + + tokio::spawn(async move { + let dialer = TokioAsyncReadCompatExt::compat(dialer).into_inner(); + let dialer = TokioAsyncWriteCompatExt::compat_write(dialer); + + // attempt to negotiate yamux, skipping noise entirely + assert!(dialer_select_proto(dialer, vec!["/yamux/1.0.0"], Version::V1).await.is_err()); + }); + + match TcpConnection::accept_connection( + listener, + ConnectionId::from(0usize), + Keypair::generate(), + dialer_address, + Default::default(), + 5, + 2, + Duration::from_secs(10), + Duration::from_secs(10), + ) + .await + { + Ok(_) => panic!("connection was supposed to fail"), + Err(Error::NegotiationError(NegotiationError::MultistreamSelectError( + crate::multistream_select::NegotiationError::Failed, + ))) => {} + Err(error) => panic!("invalid error: {error:?}"), + } + } + + #[tokio::test] + async fn noise_timeout_listener() { + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let listener = TcpListener::bind("[::1]:0").await.unwrap(); + let address = listener.local_addr().unwrap(); + + let (Ok(dialer), Ok((listener, dialer_address))) = + tokio::join!(TcpStream::connect(address.clone()), listener.accept(),) + else { + panic!("failed to establish connection"); + }; + + tokio::spawn(async move { + let dialer = TokioAsyncReadCompatExt::compat(dialer).into_inner(); + let dialer = TokioAsyncWriteCompatExt::compat_write(dialer); + + // attempt to negotiate yamux, skipping noise entirely + let (_protocol, _socket) = + dialer_select_proto(dialer, vec!["/noise"], Version::V1).await.unwrap(); + + tokio::time::sleep(std::time::Duration::from_secs(60)).await; + }); + + match TcpConnection::accept_connection( + listener, + ConnectionId::from(0usize), + Keypair::generate(), + dialer_address, + Default::default(), + 5, + 2, + Duration::from_secs(10), + Duration::from_secs(10), + ) + .await + { + Ok(_) => panic!("connection was supposed to fail"), + Err(Error::Timeout) => {} + Err(error) => panic!("invalid error: {error:?}"), + } + } + + #[tokio::test] + async fn noise_timeout_dialer() { + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let listener = TcpListener::bind("[::1]:0").await.unwrap(); + let address = listener.local_addr().unwrap(); + + tokio::spawn(async move { + let (stream, _) = listener.accept().await.unwrap(); + let stream = TokioAsyncReadCompatExt::compat(stream).into_inner(); + let stream = TokioAsyncWriteCompatExt::compat_write(stream); + + // negotiate noise but never actually send any handshake data + let (_protocol, _socket) = listener_select_proto(stream, vec!["/noise"]).await.unwrap(); + + tokio::time::sleep(std::time::Duration::from_secs(60)).await; + }); + + let (_, stream) = TcpTransport::dial_peer( + Multiaddr::empty() + .with(Protocol::from(address.ip())) + .with(Protocol::Tcp(address.port())), + Default::default(), + Duration::from_secs(10), + ) + .await + .unwrap(); + + match TcpConnection::open_connection( + ConnectionId::from(0usize), + Keypair::generate(), + stream, + AddressType::Socket(address), + None, + Default::default(), + 5, + 2, + Duration::from_secs(10), + Duration::from_secs(10), + ) + .await + { + Ok(_) => panic!("connection was supposed to fail"), + Err(Error::Timeout) => {} + Err(error) => panic!("invalid error: {error:?}"), + } + } + + #[tokio::test] + async fn multistream_select_timeout_dialer() { + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let listener = TcpListener::bind("[::1]:0").await.unwrap(); + let address = listener.local_addr().unwrap(); + + tokio::spawn(async move { + let _stream = listener.accept().await.unwrap(); + + tokio::time::sleep(std::time::Duration::from_secs(60)).await; + }); + + let (_, stream) = TcpTransport::dial_peer( + Multiaddr::empty() + .with(Protocol::from(address.ip())) + .with(Protocol::Tcp(address.port())), + Default::default(), + Duration::from_secs(10), + ) + .await + .unwrap(); + + match TcpConnection::open_connection( + ConnectionId::from(0usize), + Keypair::generate(), + stream, + AddressType::Socket(address), + None, + Default::default(), + 5, + 2, + Duration::from_secs(10), + Duration::from_secs(10), + ) + .await + { + Ok(_) => panic!("connection was supposed to fail"), + Err(Error::Timeout) => {} + Err(error) => panic!("invalid error: {error:?}"), + } + } + + #[tokio::test] + async fn multistream_select_timeout_listener() { + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let listener = TcpListener::bind("[::1]:0").await.unwrap(); + let address = listener.local_addr().unwrap(); + + let (Ok(_dialer), Ok((listener, dialer_address))) = + tokio::join!(TcpStream::connect(address.clone()), listener.accept(),) + else { + panic!("failed to establish connection"); + }; + + tokio::spawn(async move { + let _stream = TcpStream::connect(address).await.unwrap(); + + tokio::time::sleep(std::time::Duration::from_secs(60)).await; + }); + + match TcpConnection::accept_connection( + listener, + ConnectionId::from(0usize), + Keypair::generate(), + dialer_address, + Default::default(), + 5, + 2, + Duration::from_secs(10), + Duration::from_secs(10), + ) + .await + { + Ok(_) => panic!("connection was supposed to fail"), + Err(Error::Timeout) => {} + Err(error) => panic!("invalid error: {error:?}"), + } + } + + #[tokio::test] + async fn yamux_not_supported_dialer() { + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let listener = TcpListener::bind("[::1]:0").await.unwrap(); + let address = listener.local_addr().unwrap(); + + let (Ok(dialer), Ok((listener, dialer_address))) = + tokio::join!(TcpStream::connect(address.clone()), listener.accept(),) + else { + panic!("failed to establish connection"); + }; + + tokio::spawn(async move { + let dialer = TokioAsyncReadCompatExt::compat(dialer).into_inner(); + let dialer = TokioAsyncWriteCompatExt::compat_write(dialer); + + // negotiate noise + let (_protocol, stream) = + dialer_select_proto(dialer, vec!["/noise"], Version::V1).await.unwrap(); + + let keypair = Keypair::generate(); + + // do a noise handshake + let (stream, _peer) = + noise::handshake(stream.inner(), &keypair, Role::Dialer, 5, 2).await.unwrap(); + let stream: NoiseSocket> = stream; + + // after the handshake, try to negotiate some random protocol instead of yamux + assert!( + dialer_select_proto(stream, vec!["/unsupported/1"], Version::V1).await.is_err() + ); + }); + + match TcpConnection::accept_connection( + listener, + ConnectionId::from(0usize), + Keypair::generate(), + dialer_address, + Default::default(), + 5, + 2, + Duration::from_secs(10), + Duration::from_secs(10), + ) + .await + { + Ok(_) => panic!("connection was supposed to fail"), + Err(Error::NegotiationError(NegotiationError::MultistreamSelectError( + crate::multistream_select::NegotiationError::Failed, + ))) => {} + Err(error) => panic!("{error:?}"), + } + } + + #[tokio::test] + async fn yamux_not_supported_listener() { + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let listener = TcpListener::bind("[::1]:0").await.unwrap(); + let address = listener.local_addr().unwrap(); + + tokio::spawn(async move { + let (stream, _) = listener.accept().await.unwrap(); + let stream = TokioAsyncReadCompatExt::compat(stream).into_inner(); + let stream = TokioAsyncWriteCompatExt::compat_write(stream); + + // negotiate noise + let (_protocol, stream) = listener_select_proto(stream, vec!["/noise"]).await.unwrap(); + + // do a noise handshake + let keypair = Keypair::generate(); + let (stream, _peer) = + noise::handshake(stream.inner(), &keypair, Role::Listener, 5, 2).await.unwrap(); + let stream: NoiseSocket> = stream; + + // after the handshake, try to negotiate some random protocol instead of yamux + assert!(listener_select_proto(stream, vec!["/unsupported/1"]).await.is_err()); + }); + + let (_, stream) = TcpTransport::dial_peer( + Multiaddr::empty() + .with(Protocol::from(address.ip())) + .with(Protocol::Tcp(address.port())), + Default::default(), + Duration::from_secs(10), + ) + .await + .unwrap(); + + match TcpConnection::open_connection( + ConnectionId::from(0usize), + Keypair::generate(), + stream, + AddressType::Socket(address), + None, + Default::default(), + 5, + 2, + Duration::from_secs(10), + Duration::from_secs(10), + ) + .await + { + Ok(_) => panic!("connection was supposed to fail"), + Err(Error::NegotiationError(NegotiationError::MultistreamSelectError( + crate::multistream_select::NegotiationError::Failed, + ))) => {} + Err(error) => panic!("{error:?}"), + } + } + + #[tokio::test] + async fn yamux_timeout_dialer() { + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let listener = TcpListener::bind("[::1]:0").await.unwrap(); + let address = listener.local_addr().unwrap(); + + let (Ok(dialer), Ok((listener, dialer_address))) = + tokio::join!(TcpStream::connect(address.clone()), listener.accept()) + else { + panic!("failed to establish connection"); + }; + + tokio::spawn(async move { + let dialer = TokioAsyncReadCompatExt::compat(dialer).into_inner(); + let dialer = TokioAsyncWriteCompatExt::compat_write(dialer); + + // negotiate noise + let (_protocol, stream) = + dialer_select_proto(dialer, vec!["/noise"], Version::V1).await.unwrap(); + + // do a noise handshake + let keypair = Keypair::generate(); + let (stream, _peer) = + noise::handshake(stream.inner(), &keypair, Role::Dialer, 5, 2).await.unwrap(); + let _stream: NoiseSocket> = stream; + + tokio::time::sleep(std::time::Duration::from_secs(60)).await; + }); + + match TcpConnection::accept_connection( + listener, + ConnectionId::from(0usize), + Keypair::generate(), + dialer_address, + Default::default(), + 5, + 2, + Duration::from_secs(10), + Duration::from_secs(10), + ) + .await + { + Ok(_) => panic!("connection was supposed to fail"), + Err(Error::Timeout) => {} + Err(error) => panic!("invalid error: {error:?}"), + } + } + + #[tokio::test] + async fn yamux_timeout_listener() { + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let listener = TcpListener::bind("[::1]:0").await.unwrap(); + let address = listener.local_addr().unwrap(); + + tokio::spawn(async move { + let (stream, _) = listener.accept().await.unwrap(); + let stream = TokioAsyncReadCompatExt::compat(stream).into_inner(); + let stream = TokioAsyncWriteCompatExt::compat_write(stream); + + // negotiate noise + let (_protocol, stream) = listener_select_proto(stream, vec!["/noise"]).await.unwrap(); + + // do a noise handshake + let keypair = Keypair::generate(); + let (stream, _peer) = + noise::handshake(stream.inner(), &keypair, Role::Listener, 5, 2).await.unwrap(); + let _stream: NoiseSocket> = stream; + + tokio::time::sleep(std::time::Duration::from_secs(60)).await; + }); + + let (_, stream) = TcpTransport::dial_peer( + Multiaddr::empty() + .with(Protocol::from(address.ip())) + .with(Protocol::Tcp(address.port())), + Default::default(), + Duration::from_secs(10), + ) + .await + .unwrap(); + + match TcpConnection::open_connection( + ConnectionId::from(0usize), + Keypair::generate(), + stream, + AddressType::Socket(address), + None, + Default::default(), + 5, + 2, + Duration::from_secs(10), + Duration::from_secs(10), + ) + .await + { + Ok(_) => panic!("connection was supposed to fail"), + Err(Error::Timeout) => {} + Err(error) => panic!("invalid error: {error:?}"), + } + } } diff --git a/src/transport/tcp/listener.rs b/src/transport/tcp/listener.rs index 929811f4..2a1ddb72 100644 --- a/src/transport/tcp/listener.rs +++ b/src/transport/tcp/listener.rs @@ -30,11 +30,11 @@ use socket2::{Domain, Socket, Type}; use tokio::net::{TcpListener as TokioTcpListener, TcpStream}; use std::{ - io, - net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, - pin::Pin, - sync::Arc, - task::{Context, Poll}, + io, + net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, + pin::Pin, + sync::Arc, + task::{Context, Poll}, }; /// Logging target for the file. @@ -43,381 +43,391 @@ const LOG_TARGET: &str = "litep2p::tcp::listener"; /// Address type. #[derive(Debug)] pub(super) enum AddressType { - /// Socket address. - Socket(SocketAddr), + /// Socket address. + Socket(SocketAddr), - /// DNS address. - Dns(String, u16), + /// DNS address. + Dns(String, u16), } /// TCP listener listening to zero or more addresses. pub struct TcpListener { - /// Listeners. - listeners: Vec, + /// Listeners. + listeners: Vec, } /// Local addresses to use for outbound connections. #[derive(Clone)] pub enum DialAddresses { - /// Reuse port from listen addresses. - Reuse { listen_addresses: Arc> }, - /// Do not reuse port. - NoReuse, + /// Reuse port from listen addresses. + Reuse { + listen_addresses: Arc>, + }, + /// Do not reuse port. + NoReuse, } impl Default for DialAddresses { - fn default() -> Self { - DialAddresses::NoReuse - } + fn default() -> Self { + DialAddresses::NoReuse + } } impl DialAddresses { - /// Get local dial address for an outbound connection. - pub(super) fn local_dial_address( - &self, - remote_address: &IpAddr, - ) -> Result, ()> { - match self { - DialAddresses::Reuse { listen_addresses } => { - for address in listen_addresses.iter() { - if remote_address.is_ipv4() == address.is_ipv4() && - remote_address.is_loopback() == address.ip().is_loopback() - { - if remote_address.is_ipv4() { - return Ok(Some(SocketAddr::new( - IpAddr::V4(Ipv4Addr::UNSPECIFIED), - address.port(), - ))); - } else { - return Ok(Some(SocketAddr::new( - IpAddr::V6(Ipv6Addr::UNSPECIFIED), - address.port(), - ))); - } - } - } - - Err(()) - }, - DialAddresses::NoReuse => Ok(None), - } - } + /// Get local dial address for an outbound connection. + pub(super) fn local_dial_address( + &self, + remote_address: &IpAddr, + ) -> Result, ()> { + match self { + DialAddresses::Reuse { listen_addresses } => { + for address in listen_addresses.iter() { + if remote_address.is_ipv4() == address.is_ipv4() + && remote_address.is_loopback() == address.ip().is_loopback() + { + if remote_address.is_ipv4() { + return Ok(Some(SocketAddr::new( + IpAddr::V4(Ipv4Addr::UNSPECIFIED), + address.port(), + ))); + } else { + return Ok(Some(SocketAddr::new( + IpAddr::V6(Ipv6Addr::UNSPECIFIED), + address.port(), + ))); + } + } + } + + Err(()) + } + DialAddresses::NoReuse => Ok(None), + } + } } impl TcpListener { - /// Create new [`TcpListener`] - pub fn new( - addresses: Vec, - reuse_port: bool, - ) -> (Self, Vec, DialAddresses) { - let (listeners, listen_addresses): (_, Vec>) = addresses - .into_iter() - .filter_map(|address| { - let (socket, address) = match Self::get_socket_address(&address).ok()?.0 { - AddressType::Dns(_, _) => return None, - AddressType::Socket(address) => match address.is_ipv4() { - false => { - let socket = Socket::new( - Domain::IPV6, - Type::STREAM, - Some(socket2::Protocol::TCP), - ) - .ok()?; - socket.set_only_v6(true).ok()?; - (socket, address) - }, - true => ( - Socket::new(Domain::IPV4, Type::STREAM, Some(socket2::Protocol::TCP)) - .ok()?, - address, - ), - }, - }; - - socket.set_nonblocking(true).ok()?; - socket.set_reuse_address(true).ok()?; - #[cfg(unix)] - if reuse_port { - socket.set_reuse_port(true).ok()?; - } - socket.bind(&address.into()).ok()?; - socket.listen(1024).ok()?; - - let socket: std::net::TcpListener = socket.into(); - let listener = TokioTcpListener::from_std(socket).ok()?; - let local_address = listener.local_addr().ok()?; - - let listen_addresses = match address.ip().is_unspecified() { - true => match NetworkInterface::show() { - Ok(ifaces) => ifaces - .into_iter() - .flat_map(|record| { - record.addr.into_iter().filter_map(|iface_address| { - match (iface_address, address.is_ipv4()) { - (Addr::V4(inner), true) => Some(SocketAddr::new( - IpAddr::V4(inner.ip), - local_address.port(), - )), - (Addr::V6(inner), false) => - match inner.ip.segments().get(0) { - Some(0xfe80) => None, - _ => Some(SocketAddr::new( - IpAddr::V6(inner.ip), - local_address.port(), - )), - }, - _ => None, - } - }) - }) - .collect(), - Err(error) => { - tracing::warn!( - target: LOG_TARGET, - ?error, - "failed to fetch network interfaces", - ); - - return None; - }, - }, - false => vec![local_address], - }; - - Some((listener, listen_addresses)) - }) - .unzip(); - - let listen_addresses = listen_addresses.into_iter().flatten().collect::>(); - let listen_multi_addresses = listen_addresses - .iter() - .cloned() - .map(|address| { - Multiaddr::empty() - .with(Protocol::from(address.ip())) - .with(Protocol::Tcp(address.port())) - }) - .collect(); - let dial_addresses = if reuse_port { - DialAddresses::Reuse { listen_addresses: Arc::new(listen_addresses) } - } else { - DialAddresses::NoReuse - }; - - (Self { listeners }, listen_multi_addresses, dial_addresses) - } - - /// Extract socket address and `PeerId`, if found, from `address`. - pub(super) fn get_socket_address( - address: &Multiaddr, - ) -> crate::Result<(AddressType, Option)> { - tracing::trace!(target: LOG_TARGET, ?address, "parse multi address"); - - let mut iter = address.iter(); - let socket_address = match iter.next() { - Some(Protocol::Ip6(address)) => match iter.next() { - Some(Protocol::Tcp(port)) => - AddressType::Socket(SocketAddr::new(IpAddr::V6(address), port)), - protocol => { - tracing::error!( - target: LOG_TARGET, - ?protocol, - "invalid transport protocol, expected `Tcp`", - ); - return Err(Error::AddressError(AddressError::InvalidProtocol)); - }, - }, - Some(Protocol::Ip4(address)) => match iter.next() { - Some(Protocol::Tcp(port)) => - AddressType::Socket(SocketAddr::new(IpAddr::V4(address), port)), - protocol => { - tracing::error!( - target: LOG_TARGET, - ?protocol, - "invalid transport protocol, expected `Tcp`", - ); - return Err(Error::AddressError(AddressError::InvalidProtocol)); - }, - }, - Some(Protocol::Dns(address)) | - Some(Protocol::Dns4(address)) | - Some(Protocol::Dns6(address)) => match iter.next() { - Some(Protocol::Tcp(port)) => AddressType::Dns(address.to_string(), port), - protocol => { - tracing::error!( - target: LOG_TARGET, - ?protocol, - "invalid transport protocol, expected `Tcp`", - ); - return Err(Error::AddressError(AddressError::InvalidProtocol)); - }, - }, - protocol => { - tracing::error!(target: LOG_TARGET, ?protocol, "invalid transport protocol"); - return Err(Error::AddressError(AddressError::InvalidProtocol)); - }, - }; - - let maybe_peer = match iter.next() { - Some(Protocol::P2p(multihash)) => Some(PeerId::from_multihash(multihash)?), - None => None, - protocol => { - tracing::error!( - target: LOG_TARGET, - ?protocol, - "invalid protocol, expected `P2p` or `None`" - ); - return Err(Error::AddressError(AddressError::InvalidProtocol)); - }, - }; - - Ok((socket_address, maybe_peer)) - } + /// Create new [`TcpListener`] + pub fn new( + addresses: Vec, + reuse_port: bool, + ) -> (Self, Vec, DialAddresses) { + let (listeners, listen_addresses): (_, Vec>) = addresses + .into_iter() + .filter_map(|address| { + let (socket, address) = match Self::get_socket_address(&address).ok()?.0 { + AddressType::Dns(_, _) => return None, + AddressType::Socket(address) => match address.is_ipv4() { + false => { + let socket = Socket::new( + Domain::IPV6, + Type::STREAM, + Some(socket2::Protocol::TCP), + ) + .ok()?; + socket.set_only_v6(true).ok()?; + (socket, address) + } + true => ( + Socket::new(Domain::IPV4, Type::STREAM, Some(socket2::Protocol::TCP)) + .ok()?, + address, + ), + }, + }; + + socket.set_nonblocking(true).ok()?; + socket.set_reuse_address(true).ok()?; + #[cfg(unix)] + if reuse_port { + socket.set_reuse_port(true).ok()?; + } + socket.bind(&address.into()).ok()?; + socket.listen(1024).ok()?; + + let socket: std::net::TcpListener = socket.into(); + let listener = TokioTcpListener::from_std(socket).ok()?; + let local_address = listener.local_addr().ok()?; + + let listen_addresses = match address.ip().is_unspecified() { + true => match NetworkInterface::show() { + Ok(ifaces) => ifaces + .into_iter() + .flat_map(|record| { + record.addr.into_iter().filter_map(|iface_address| { + match (iface_address, address.is_ipv4()) { + (Addr::V4(inner), true) => Some(SocketAddr::new( + IpAddr::V4(inner.ip), + local_address.port(), + )), + (Addr::V6(inner), false) => + match inner.ip.segments().get(0) { + Some(0xfe80) => None, + _ => Some(SocketAddr::new( + IpAddr::V6(inner.ip), + local_address.port(), + )), + }, + _ => None, + } + }) + }) + .collect(), + Err(error) => { + tracing::warn!( + target: LOG_TARGET, + ?error, + "failed to fetch network interfaces", + ); + + return None; + } + }, + false => vec![local_address], + }; + + Some((listener, listen_addresses)) + }) + .unzip(); + + let listen_addresses = listen_addresses.into_iter().flatten().collect::>(); + let listen_multi_addresses = listen_addresses + .iter() + .cloned() + .map(|address| { + Multiaddr::empty() + .with(Protocol::from(address.ip())) + .with(Protocol::Tcp(address.port())) + }) + .collect(); + let dial_addresses = if reuse_port { + DialAddresses::Reuse { + listen_addresses: Arc::new(listen_addresses), + } + } else { + DialAddresses::NoReuse + }; + + (Self { listeners }, listen_multi_addresses, dial_addresses) + } + + /// Extract socket address and `PeerId`, if found, from `address`. + pub(super) fn get_socket_address( + address: &Multiaddr, + ) -> crate::Result<(AddressType, Option)> { + tracing::trace!(target: LOG_TARGET, ?address, "parse multi address"); + + let mut iter = address.iter(); + let socket_address = match iter.next() { + Some(Protocol::Ip6(address)) => match iter.next() { + Some(Protocol::Tcp(port)) => + AddressType::Socket(SocketAddr::new(IpAddr::V6(address), port)), + protocol => { + tracing::error!( + target: LOG_TARGET, + ?protocol, + "invalid transport protocol, expected `Tcp`", + ); + return Err(Error::AddressError(AddressError::InvalidProtocol)); + } + }, + Some(Protocol::Ip4(address)) => match iter.next() { + Some(Protocol::Tcp(port)) => + AddressType::Socket(SocketAddr::new(IpAddr::V4(address), port)), + protocol => { + tracing::error!( + target: LOG_TARGET, + ?protocol, + "invalid transport protocol, expected `Tcp`", + ); + return Err(Error::AddressError(AddressError::InvalidProtocol)); + } + }, + Some(Protocol::Dns(address)) + | Some(Protocol::Dns4(address)) + | Some(Protocol::Dns6(address)) => match iter.next() { + Some(Protocol::Tcp(port)) => AddressType::Dns(address.to_string(), port), + protocol => { + tracing::error!( + target: LOG_TARGET, + ?protocol, + "invalid transport protocol, expected `Tcp`", + ); + return Err(Error::AddressError(AddressError::InvalidProtocol)); + } + }, + protocol => { + tracing::error!(target: LOG_TARGET, ?protocol, "invalid transport protocol"); + return Err(Error::AddressError(AddressError::InvalidProtocol)); + } + }; + + let maybe_peer = match iter.next() { + Some(Protocol::P2p(multihash)) => Some(PeerId::from_multihash(multihash)?), + None => None, + protocol => { + tracing::error!( + target: LOG_TARGET, + ?protocol, + "invalid protocol, expected `P2p` or `None`" + ); + return Err(Error::AddressError(AddressError::InvalidProtocol)); + } + }; + + Ok((socket_address, maybe_peer)) + } } impl Stream for TcpListener { - type Item = io::Result<(TcpStream, SocketAddr)>; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - if self.listeners.is_empty() { - return Poll::Pending; - } - - // TODO: make this more fair - for listener in self.listeners.iter_mut() { - match listener.poll_accept(cx) { - Poll::Pending => {}, - Poll::Ready(Err(error)) => return Poll::Ready(Some(Err(error))), - Poll::Ready(Ok((stream, address))) => - return Poll::Ready(Some(Ok((stream, address)))), - } - } - - Poll::Pending - } + type Item = io::Result<(TcpStream, SocketAddr)>; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + if self.listeners.is_empty() { + return Poll::Pending; + } + + // TODO: make this more fair + for listener in self.listeners.iter_mut() { + match listener.poll_accept(cx) { + Poll::Pending => {} + Poll::Ready(Err(error)) => return Poll::Ready(Some(Err(error))), + Poll::Ready(Ok((stream, address))) => + return Poll::Ready(Some(Ok((stream, address)))), + } + } + + Poll::Pending + } } #[cfg(test)] mod tests { - use super::*; - use futures::StreamExt; - - #[test] - fn parse_multiaddresses() { - assert!(TcpListener::get_socket_address( - &"/ip6/::1/tcp/8888".parse().expect("valid multiaddress") - ) - .is_ok()); - assert!(TcpListener::get_socket_address( - &"/ip4/127.0.0.1/tcp/8888".parse().expect("valid multiaddress") - ) - .is_ok()); - assert!(TcpListener::get_socket_address( - &"/ip6/::1/tcp/8888/p2p/12D3KooWT2ouvz5uMmCvHJGzAGRHiqDts5hzXR7NdoQ27pGdzp9Q" - .parse() - .expect("valid multiaddress") - ) - .is_ok()); - assert!(TcpListener::get_socket_address( - &"/ip4/127.0.0.1/tcp/8888/p2p/12D3KooWT2ouvz5uMmCvHJGzAGRHiqDts5hzXR7NdoQ27pGdzp9Q" - .parse() - .expect("valid multiaddress") - ) - .is_ok()); - assert!(TcpListener::get_socket_address( - &"/ip6/::1/udp/8888/p2p/12D3KooWT2ouvz5uMmCvHJGzAGRHiqDts5hzXR7NdoQ27pGdzp9Q" - .parse() - .expect("valid multiaddress") - ) - .is_err()); - assert!(TcpListener::get_socket_address( - &"/ip4/127.0.0.1/udp/8888/p2p/12D3KooWT2ouvz5uMmCvHJGzAGRHiqDts5hzXR7NdoQ27pGdzp9Q" - .parse() - .expect("valid multiaddress") - ) - .is_err()); - } - - #[tokio::test] - async fn no_listeners() { - let (mut listener, _, _) = TcpListener::new(Vec::new(), true); - - futures::future::poll_fn(|cx| match listener.poll_next_unpin(cx) { - Poll::Pending => Poll::Ready(()), - event => panic!("unexpected event: {event:?}"), - }) - .await; - } - - #[tokio::test] - async fn one_listener() { - let address: Multiaddr = "/ip6/::1/tcp/0".parse().unwrap(); - let (mut listener, listen_addresses, _) = TcpListener::new(vec![address.clone()], true); - let Some(Protocol::Tcp(port)) = - listen_addresses.iter().next().unwrap().clone().iter().skip(1).next() - else { - panic!("invalid address"); - }; - - let (res1, res2) = - tokio::join!(listener.next(), TcpStream::connect(format!("[::1]:{port}"))); - - assert!(res1.unwrap().is_ok() && res2.is_ok()); - } - - #[tokio::test] - async fn two_listeners() { - let address1: Multiaddr = "/ip6/::1/tcp/0".parse().unwrap(); - let address2: Multiaddr = "/ip4/127.0.0.1/tcp/0".parse().unwrap(); - let (mut listener, listen_addresses, _) = TcpListener::new(vec![address1, address2], true); - let Some(Protocol::Tcp(port1)) = - listen_addresses.iter().next().unwrap().clone().iter().skip(1).next() - else { - panic!("invalid address"); - }; - - let Some(Protocol::Tcp(port2)) = - listen_addresses.iter().skip(1).next().unwrap().clone().iter().skip(1).next() - else { - panic!("invalid address"); - }; - - tokio::spawn(async move { while let Some(_) = listener.next().await {} }); - - let (res1, res2) = tokio::join!( - TcpStream::connect(format!("[::1]:{port1}")), - TcpStream::connect(format!("127.0.0.1:{port2}")) - ); - - assert!(res1.is_ok() && res2.is_ok()); - } - - #[tokio::test] - async fn local_dial_address() { - let dial_addresses = DialAddresses::Reuse { - listen_addresses: Arc::new(vec![ - "[2001:7d0:84aa:3900:2a5d:9e85::]:8888".parse().unwrap(), - "92.168.127.1:9999".parse().unwrap(), - ]), - }; - - assert_eq!( - dial_addresses.local_dial_address(&IpAddr::V4(Ipv4Addr::new(192, 168, 0, 1))), - Ok(Some(SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 9999))), - ); - - assert_eq!( - dial_addresses.local_dial_address(&IpAddr::V6(Ipv6Addr::new(0, 1, 2, 3, 4, 5, 6, 7))), - Ok(Some(SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 8888))), - ); - } - - #[tokio::test] - async fn show_all_addresses() { - let address1: Multiaddr = "/ip6/::/tcp/0".parse().unwrap(); - let address2: Multiaddr = "/ip4/0.0.0.0/tcp/0".parse().unwrap(); - let (_, listen_addresses, _) = TcpListener::new(vec![address1, address2], true); - - println!("{listen_addresses:#?}"); - } + use super::*; + use futures::StreamExt; + + #[test] + fn parse_multiaddresses() { + assert!(TcpListener::get_socket_address( + &"/ip6/::1/tcp/8888".parse().expect("valid multiaddress") + ) + .is_ok()); + assert!(TcpListener::get_socket_address( + &"/ip4/127.0.0.1/tcp/8888".parse().expect("valid multiaddress") + ) + .is_ok()); + assert!(TcpListener::get_socket_address( + &"/ip6/::1/tcp/8888/p2p/12D3KooWT2ouvz5uMmCvHJGzAGRHiqDts5hzXR7NdoQ27pGdzp9Q" + .parse() + .expect("valid multiaddress") + ) + .is_ok()); + assert!(TcpListener::get_socket_address( + &"/ip4/127.0.0.1/tcp/8888/p2p/12D3KooWT2ouvz5uMmCvHJGzAGRHiqDts5hzXR7NdoQ27pGdzp9Q" + .parse() + .expect("valid multiaddress") + ) + .is_ok()); + assert!(TcpListener::get_socket_address( + &"/ip6/::1/udp/8888/p2p/12D3KooWT2ouvz5uMmCvHJGzAGRHiqDts5hzXR7NdoQ27pGdzp9Q" + .parse() + .expect("valid multiaddress") + ) + .is_err()); + assert!(TcpListener::get_socket_address( + &"/ip4/127.0.0.1/udp/8888/p2p/12D3KooWT2ouvz5uMmCvHJGzAGRHiqDts5hzXR7NdoQ27pGdzp9Q" + .parse() + .expect("valid multiaddress") + ) + .is_err()); + } + + #[tokio::test] + async fn no_listeners() { + let (mut listener, _, _) = TcpListener::new(Vec::new(), true); + + futures::future::poll_fn(|cx| match listener.poll_next_unpin(cx) { + Poll::Pending => Poll::Ready(()), + event => panic!("unexpected event: {event:?}"), + }) + .await; + } + + #[tokio::test] + async fn one_listener() { + let address: Multiaddr = "/ip6/::1/tcp/0".parse().unwrap(); + let (mut listener, listen_addresses, _) = TcpListener::new(vec![address.clone()], true); + let Some(Protocol::Tcp(port)) = + listen_addresses.iter().next().unwrap().clone().iter().skip(1).next() + else { + panic!("invalid address"); + }; + + let (res1, res2) = + tokio::join!(listener.next(), TcpStream::connect(format!("[::1]:{port}"))); + + assert!(res1.unwrap().is_ok() && res2.is_ok()); + } + + #[tokio::test] + async fn two_listeners() { + let address1: Multiaddr = "/ip6/::1/tcp/0".parse().unwrap(); + let address2: Multiaddr = "/ip4/127.0.0.1/tcp/0".parse().unwrap(); + let (mut listener, listen_addresses, _) = TcpListener::new(vec![address1, address2], true); + let Some(Protocol::Tcp(port1)) = + listen_addresses.iter().next().unwrap().clone().iter().skip(1).next() + else { + panic!("invalid address"); + }; + + let Some(Protocol::Tcp(port2)) = + listen_addresses.iter().skip(1).next().unwrap().clone().iter().skip(1).next() + else { + panic!("invalid address"); + }; + + tokio::spawn(async move { while let Some(_) = listener.next().await {} }); + + let (res1, res2) = tokio::join!( + TcpStream::connect(format!("[::1]:{port1}")), + TcpStream::connect(format!("127.0.0.1:{port2}")) + ); + + assert!(res1.is_ok() && res2.is_ok()); + } + + #[tokio::test] + async fn local_dial_address() { + let dial_addresses = DialAddresses::Reuse { + listen_addresses: Arc::new(vec![ + "[2001:7d0:84aa:3900:2a5d:9e85::]:8888".parse().unwrap(), + "92.168.127.1:9999".parse().unwrap(), + ]), + }; + + assert_eq!( + dial_addresses.local_dial_address(&IpAddr::V4(Ipv4Addr::new(192, 168, 0, 1))), + Ok(Some(SocketAddr::new( + IpAddr::V4(Ipv4Addr::UNSPECIFIED), + 9999 + ))), + ); + + assert_eq!( + dial_addresses.local_dial_address(&IpAddr::V6(Ipv6Addr::new(0, 1, 2, 3, 4, 5, 6, 7))), + Ok(Some(SocketAddr::new( + IpAddr::V6(Ipv6Addr::UNSPECIFIED), + 8888 + ))), + ); + } + + #[tokio::test] + async fn show_all_addresses() { + let address1: Multiaddr = "/ip6/::/tcp/0".parse().unwrap(); + let address2: Multiaddr = "/ip4/0.0.0.0/tcp/0".parse().unwrap(); + let (_, listen_addresses, _) = TcpListener::new(vec![address1, address2], true); + + println!("{listen_addresses:#?}"); + } } diff --git a/src/transport/tcp/mod.rs b/src/transport/tcp/mod.rs index 6eda8737..0936a946 100644 --- a/src/transport/tcp/mod.rs +++ b/src/transport/tcp/mod.rs @@ -22,38 +22,38 @@ //! TCP transport. use crate::{ - config::Role, - error::Error, - transport::{ - manager::TransportHandle, - tcp::{ - config::Config, - connection::{NegotiatedConnection, TcpConnection}, - listener::{AddressType, DialAddresses, TcpListener}, - }, - Transport, TransportBuilder, TransportEvent, - }, - types::ConnectionId, + config::Role, + error::Error, + transport::{ + manager::TransportHandle, + tcp::{ + config::Config, + connection::{NegotiatedConnection, TcpConnection}, + listener::{AddressType, DialAddresses, TcpListener}, + }, + Transport, TransportBuilder, TransportEvent, + }, + types::ConnectionId, }; use futures::{ - future::BoxFuture, - stream::{FuturesUnordered, Stream, StreamExt}, + future::BoxFuture, + stream::{FuturesUnordered, Stream, StreamExt}, }; use multiaddr::{Multiaddr, Protocol}; use socket2::{Domain, Socket, Type}; use tokio::net::TcpStream; use trust_dns_resolver::{ - config::{ResolverConfig, ResolverOpts}, - TokioAsyncResolver, + config::{ResolverConfig, ResolverOpts}, + TokioAsyncResolver, }; use std::{ - collections::{HashMap, HashSet}, - net::SocketAddr, - pin::Pin, - task::{Context, Poll}, - time::Duration, + collections::{HashMap, HashSet}, + net::SocketAddr, + pin::Pin, + task::{Context, Poll}, + time::Duration, }; pub(crate) use substream::Substream; @@ -69,680 +69,698 @@ const LOG_TARGET: &str = "litep2p::tcp"; /// TCP transport. pub(crate) struct TcpTransport { - /// Transport context. - context: TransportHandle, + /// Transport context. + context: TransportHandle, - /// Transport configuration. - config: Config, + /// Transport configuration. + config: Config, - /// TCP listener. - listener: TcpListener, + /// TCP listener. + listener: TcpListener, - /// Pending dials. - pending_dials: HashMap, + /// Pending dials. + pending_dials: HashMap, - /// Dial addresses. - dial_addresses: DialAddresses, + /// Dial addresses. + dial_addresses: DialAddresses, - /// Pending opening connections. - pending_connections: - FuturesUnordered>>, + /// Pending opening connections. + pending_connections: + FuturesUnordered>>, - /// Pending raw, unnegotiated connections. - pending_raw_connections: FuturesUnordered< - BoxFuture<'static, Result<(ConnectionId, Multiaddr, TcpStream), ConnectionId>>, - >, + /// Pending raw, unnegotiated connections. + pending_raw_connections: FuturesUnordered< + BoxFuture<'static, Result<(ConnectionId, Multiaddr, TcpStream), ConnectionId>>, + >, - /// Opened raw connection, waiting for approval/rejection from `TransportManager`. - opened_raw: HashMap, + /// Opened raw connection, waiting for approval/rejection from `TransportManager`. + opened_raw: HashMap, - /// Canceled raw connections. - canceled: HashSet, + /// Canceled raw connections. + canceled: HashSet, - /// Connections which have been opened and negotiated but are being validated by the - /// `TransportManager`. - pending_open: HashMap, + /// Connections which have been opened and negotiated but are being validated by the + /// `TransportManager`. + pending_open: HashMap, } impl TcpTransport { - /// Handle inbound TCP connection. - fn on_inbound_connection(&mut self, connection: TcpStream, address: SocketAddr) { - let connection_id = self.context.next_connection_id(); - let yamux_config = self.config.yamux_config.clone(); - let max_read_ahead_factor = self.config.noise_read_ahead_frame_count; - let max_write_buffer_size = self.config.noise_write_buffer_size; - let connection_open_timeout = self.config.connection_open_timeout; - let substream_open_timeout = self.config.substream_open_timeout; - let keypair = self.context.keypair.clone(); - - self.pending_connections.push(Box::pin(async move { - TcpConnection::accept_connection( - connection, - connection_id, - keypair, - address, - yamux_config, - max_read_ahead_factor, - max_write_buffer_size, - connection_open_timeout, - substream_open_timeout, - ) - .await - .map_err(|error| (connection_id, error)) - })); - } - - /// Dial remote peer - async fn dial_peer( - address: Multiaddr, - dial_addresses: DialAddresses, - connection_open_timeout: Duration, - ) -> crate::Result<(Multiaddr, TcpStream)> { - let (socket_address, _) = TcpListener::get_socket_address(&address)?; - let remote_address = match socket_address { - AddressType::Socket(address) => address, - AddressType::Dns(url, port) => { - let address = address.clone(); - let future = async move { - match TokioAsyncResolver::tokio( - ResolverConfig::default(), - ResolverOpts::default(), - ) - .lookup_ip(url.clone()) - .await - { - // TODO: ugly - Ok(lookup) => { - let mut iter = lookup.iter(); - while let Some(ip) = iter.next() { - match ( - address.iter().next().expect("protocol to exist"), - ip.is_ipv4(), - ) { - (Protocol::Dns(_), true) | - (Protocol::Dns4(_), true) | - (Protocol::Dns6(_), false) => { - tracing::trace!( - target: LOG_TARGET, - ?address, - ?ip, - "address resolved", - ); - - return Ok(SocketAddr::new(ip, port)); - }, - _ => {}, - } - } - - Err(Error::Unknown) - }, - Err(_) => Err(Error::Unknown), - } - }; - - match tokio::time::timeout(connection_open_timeout, future).await { - Err(_) => return Err(Error::Timeout), - Ok(Err(error)) => return Err(error), - Ok(Ok(address)) => address, - } - }, - }; - - let domain = match remote_address.is_ipv4() { - true => Domain::IPV4, - false => Domain::IPV6, - }; - let socket = Socket::new(domain, Type::STREAM, Some(socket2::Protocol::TCP))?; - if remote_address.is_ipv6() { - socket.set_only_v6(true)?; - } - socket.set_nonblocking(true)?; - - match dial_addresses.local_dial_address(&remote_address.ip()) { - Ok(Some(dial_address)) => { - socket.set_reuse_address(true)?; - #[cfg(unix)] - socket.set_reuse_port(true)?; - socket.bind(&dial_address.into())?; - }, - Ok(None) => {}, - Err(()) => { - tracing::debug!( - target: LOG_TARGET, - ?remote_address, - "tcp listener not enabled for remote address, using ephemeral port", - ); - }, - } - - let future = async move { - match socket.connect(&remote_address.into()) { - Ok(()) => {}, - Err(err) if err.raw_os_error() == Some(libc::EINPROGRESS) => {}, - Err(err) if err.kind() == std::io::ErrorKind::WouldBlock => {}, - Err(err) => return Err(err.into()), - } - - let stream = TcpStream::try_from(Into::::into(socket))?; - stream.writable().await?; - - if let Some(e) = stream.take_error()? { - return Err(e); - } - - Ok((address, stream)) - }; - - match tokio::time::timeout(connection_open_timeout, future).await { - Err(_) => Err(Error::Timeout), - Ok(Err(error)) => Err(error.into()), - Ok(Ok((address, stream))) => Ok((address, stream)), - } - } + /// Handle inbound TCP connection. + fn on_inbound_connection(&mut self, connection: TcpStream, address: SocketAddr) { + let connection_id = self.context.next_connection_id(); + let yamux_config = self.config.yamux_config.clone(); + let max_read_ahead_factor = self.config.noise_read_ahead_frame_count; + let max_write_buffer_size = self.config.noise_write_buffer_size; + let connection_open_timeout = self.config.connection_open_timeout; + let substream_open_timeout = self.config.substream_open_timeout; + let keypair = self.context.keypair.clone(); + + self.pending_connections.push(Box::pin(async move { + TcpConnection::accept_connection( + connection, + connection_id, + keypair, + address, + yamux_config, + max_read_ahead_factor, + max_write_buffer_size, + connection_open_timeout, + substream_open_timeout, + ) + .await + .map_err(|error| (connection_id, error)) + })); + } + + /// Dial remote peer + async fn dial_peer( + address: Multiaddr, + dial_addresses: DialAddresses, + connection_open_timeout: Duration, + ) -> crate::Result<(Multiaddr, TcpStream)> { + let (socket_address, _) = TcpListener::get_socket_address(&address)?; + let remote_address = match socket_address { + AddressType::Socket(address) => address, + AddressType::Dns(url, port) => { + let address = address.clone(); + let future = async move { + match TokioAsyncResolver::tokio( + ResolverConfig::default(), + ResolverOpts::default(), + ) + .lookup_ip(url.clone()) + .await + { + // TODO: ugly + Ok(lookup) => { + let mut iter = lookup.iter(); + while let Some(ip) = iter.next() { + match ( + address.iter().next().expect("protocol to exist"), + ip.is_ipv4(), + ) { + (Protocol::Dns(_), true) + | (Protocol::Dns4(_), true) + | (Protocol::Dns6(_), false) => { + tracing::trace!( + target: LOG_TARGET, + ?address, + ?ip, + "address resolved", + ); + + return Ok(SocketAddr::new(ip, port)); + } + _ => {} + } + } + + Err(Error::Unknown) + } + Err(_) => Err(Error::Unknown), + } + }; + + match tokio::time::timeout(connection_open_timeout, future).await { + Err(_) => return Err(Error::Timeout), + Ok(Err(error)) => return Err(error), + Ok(Ok(address)) => address, + } + } + }; + + let domain = match remote_address.is_ipv4() { + true => Domain::IPV4, + false => Domain::IPV6, + }; + let socket = Socket::new(domain, Type::STREAM, Some(socket2::Protocol::TCP))?; + if remote_address.is_ipv6() { + socket.set_only_v6(true)?; + } + socket.set_nonblocking(true)?; + + match dial_addresses.local_dial_address(&remote_address.ip()) { + Ok(Some(dial_address)) => { + socket.set_reuse_address(true)?; + #[cfg(unix)] + socket.set_reuse_port(true)?; + socket.bind(&dial_address.into())?; + } + Ok(None) => {} + Err(()) => { + tracing::debug!( + target: LOG_TARGET, + ?remote_address, + "tcp listener not enabled for remote address, using ephemeral port", + ); + } + } + + let future = async move { + match socket.connect(&remote_address.into()) { + Ok(()) => {} + Err(err) if err.raw_os_error() == Some(libc::EINPROGRESS) => {} + Err(err) if err.kind() == std::io::ErrorKind::WouldBlock => {} + Err(err) => return Err(err.into()), + } + + let stream = TcpStream::try_from(Into::::into(socket))?; + stream.writable().await?; + + if let Some(e) = stream.take_error()? { + return Err(e); + } + + Ok((address, stream)) + }; + + match tokio::time::timeout(connection_open_timeout, future).await { + Err(_) => Err(Error::Timeout), + Ok(Err(error)) => Err(error.into()), + Ok(Ok((address, stream))) => Ok((address, stream)), + } + } } impl TransportBuilder for TcpTransport { - type Config = Config; - type Transport = TcpTransport; - - /// Create new [`TcpTransport`]. - fn new( - context: TransportHandle, - mut config: Self::Config, - ) -> crate::Result<(Self, Vec)> { - tracing::debug!( - target: LOG_TARGET, - listen_addresses = ?config.listen_addresses, - "start tcp transport", - ); - - // start tcp listeners for all listen addresses - let (listener, listen_addresses, dial_addresses) = TcpListener::new( - std::mem::replace(&mut config.listen_addresses, Vec::new()), - config.reuse_port, - ); - - Ok(( - Self { - listener, - config, - context, - dial_addresses, - canceled: HashSet::new(), - opened_raw: HashMap::new(), - pending_open: HashMap::new(), - pending_dials: HashMap::new(), - pending_connections: FuturesUnordered::new(), - pending_raw_connections: FuturesUnordered::new(), - }, - listen_addresses, - )) - } + type Config = Config; + type Transport = TcpTransport; + + /// Create new [`TcpTransport`]. + fn new( + context: TransportHandle, + mut config: Self::Config, + ) -> crate::Result<(Self, Vec)> { + tracing::debug!( + target: LOG_TARGET, + listen_addresses = ?config.listen_addresses, + "start tcp transport", + ); + + // start tcp listeners for all listen addresses + let (listener, listen_addresses, dial_addresses) = TcpListener::new( + std::mem::replace(&mut config.listen_addresses, Vec::new()), + config.reuse_port, + ); + + Ok(( + Self { + listener, + config, + context, + dial_addresses, + canceled: HashSet::new(), + opened_raw: HashMap::new(), + pending_open: HashMap::new(), + pending_dials: HashMap::new(), + pending_connections: FuturesUnordered::new(), + pending_raw_connections: FuturesUnordered::new(), + }, + listen_addresses, + )) + } } impl Transport for TcpTransport { - fn dial(&mut self, connection_id: ConnectionId, address: Multiaddr) -> crate::Result<()> { - tracing::debug!(target: LOG_TARGET, ?connection_id, ?address, "open connection"); - - let (socket_address, peer) = listener::TcpListener::get_socket_address(&address)?; - let yamux_config = self.config.yamux_config.clone(); - let max_read_ahead_factor = self.config.noise_read_ahead_frame_count; - let max_write_buffer_size = self.config.noise_write_buffer_size; - let connection_open_timeout = self.config.connection_open_timeout; - let substream_open_timeout = self.config.substream_open_timeout; - let dial_addresses = self.dial_addresses.clone(); - let keypair = self.context.keypair.clone(); - - self.pending_dials.insert(connection_id, address.clone()); - self.pending_connections.push(Box::pin(async move { - let (_, stream) = - TcpTransport::dial_peer(address, dial_addresses, connection_open_timeout) - .await - .map_err(|error| (connection_id, error))?; - - TcpConnection::open_connection( - connection_id, - keypair, - stream, - socket_address, - peer, - yamux_config, - max_read_ahead_factor, - max_write_buffer_size, - connection_open_timeout, - substream_open_timeout, - ) - .await - .map_err(|error| (connection_id, error)) - })); - - Ok(()) - } - - fn accept(&mut self, connection_id: ConnectionId) -> crate::Result<()> { - let context = self - .pending_open - .remove(&connection_id) - .ok_or(Error::ConnectionDoesntExist(connection_id))?; - let protocol_set = self.context.protocol_set(connection_id); - let bandwidth_sink = self.context.bandwidth_sink.clone(); - let next_substream_id = self.context.next_substream_id.clone(); - - tracing::trace!( - target: LOG_TARGET, - ?connection_id, - "start connection", - ); - - self.context.executor.run(Box::pin(async move { - if let Err(error) = - TcpConnection::new(context, protocol_set, bandwidth_sink, next_substream_id) - .start() - .await - { - tracing::debug!( - target: LOG_TARGET, - ?connection_id, - ?error, - "connection exited with error", - ); - } - })); - - Ok(()) - } - - fn reject(&mut self, connection_id: ConnectionId) -> crate::Result<()> { - self.canceled.insert(connection_id); - self.pending_open - .remove(&connection_id) - .map_or(Err(Error::ConnectionDoesntExist(connection_id)), |_| Ok(())) - } - - fn open( - &mut self, - connection_id: ConnectionId, - addresses: Vec, - ) -> crate::Result<()> { - let mut futures: FuturesUnordered<_> = addresses - .into_iter() - .map(|address| { - let dial_addresses = self.dial_addresses.clone(); - let connection_open_timeout = self.config.connection_open_timeout; - - async move { - TcpTransport::dial_peer(address, dial_addresses, connection_open_timeout).await - } - }) - .collect(); - - self.pending_raw_connections.push(Box::pin(async move { - while let Some(result) = futures.next().await { - match result { - Ok((address, stream)) => return Ok((connection_id, address, stream)), - Err(error) => tracing::debug!( - target: LOG_TARGET, - ?connection_id, - ?error, - "failed to open connection", - ), - } - } - - Err(connection_id) - })); - - Ok(()) - } - - fn negotiate(&mut self, connection_id: ConnectionId) -> crate::Result<()> { - let (stream, address) = self - .opened_raw - .remove(&connection_id) - .ok_or(Error::ConnectionDoesntExist(connection_id))?; - - let (socket_address, peer) = listener::TcpListener::get_socket_address(&address)?; - let yamux_config = self.config.yamux_config.clone(); - let max_read_ahead_factor = self.config.noise_read_ahead_frame_count; - let max_write_buffer_size = self.config.noise_write_buffer_size; - let connection_open_timeout = self.config.connection_open_timeout; - let substream_open_timeout = self.config.substream_open_timeout; - let keypair = self.context.keypair.clone(); - - tracing::trace!( - target: LOG_TARGET, - ?peer, - ?connection_id, - ?address, - "negotiate connection", - ); - - self.pending_dials.insert(connection_id, address); - self.pending_connections.push(Box::pin(async move { - match tokio::time::timeout(connection_open_timeout, async move { - TcpConnection::negotiate_connection( - stream, - peer, - connection_id, - keypair, - Role::Dialer, - socket_address, - yamux_config, - max_read_ahead_factor, - max_write_buffer_size, - substream_open_timeout, - ) - .await - .map_err(|error| (connection_id, error)) - }) - .await - { - Err(_) => Err((connection_id, Error::Timeout)), - Ok(Err(error)) => Err(error), - Ok(Ok(connection)) => Ok(connection), - } - })); - - Ok(()) - } - - fn cancel(&mut self, connection_id: ConnectionId) { - self.canceled.insert(connection_id); - } + fn dial(&mut self, connection_id: ConnectionId, address: Multiaddr) -> crate::Result<()> { + tracing::debug!(target: LOG_TARGET, ?connection_id, ?address, "open connection"); + + let (socket_address, peer) = listener::TcpListener::get_socket_address(&address)?; + let yamux_config = self.config.yamux_config.clone(); + let max_read_ahead_factor = self.config.noise_read_ahead_frame_count; + let max_write_buffer_size = self.config.noise_write_buffer_size; + let connection_open_timeout = self.config.connection_open_timeout; + let substream_open_timeout = self.config.substream_open_timeout; + let dial_addresses = self.dial_addresses.clone(); + let keypair = self.context.keypair.clone(); + + self.pending_dials.insert(connection_id, address.clone()); + self.pending_connections.push(Box::pin(async move { + let (_, stream) = + TcpTransport::dial_peer(address, dial_addresses, connection_open_timeout) + .await + .map_err(|error| (connection_id, error))?; + + TcpConnection::open_connection( + connection_id, + keypair, + stream, + socket_address, + peer, + yamux_config, + max_read_ahead_factor, + max_write_buffer_size, + connection_open_timeout, + substream_open_timeout, + ) + .await + .map_err(|error| (connection_id, error)) + })); + + Ok(()) + } + + fn accept(&mut self, connection_id: ConnectionId) -> crate::Result<()> { + let context = self + .pending_open + .remove(&connection_id) + .ok_or(Error::ConnectionDoesntExist(connection_id))?; + let protocol_set = self.context.protocol_set(connection_id); + let bandwidth_sink = self.context.bandwidth_sink.clone(); + let next_substream_id = self.context.next_substream_id.clone(); + + tracing::trace!( + target: LOG_TARGET, + ?connection_id, + "start connection", + ); + + self.context.executor.run(Box::pin(async move { + if let Err(error) = + TcpConnection::new(context, protocol_set, bandwidth_sink, next_substream_id) + .start() + .await + { + tracing::debug!( + target: LOG_TARGET, + ?connection_id, + ?error, + "connection exited with error", + ); + } + })); + + Ok(()) + } + + fn reject(&mut self, connection_id: ConnectionId) -> crate::Result<()> { + self.canceled.insert(connection_id); + self.pending_open + .remove(&connection_id) + .map_or(Err(Error::ConnectionDoesntExist(connection_id)), |_| Ok(())) + } + + fn open( + &mut self, + connection_id: ConnectionId, + addresses: Vec, + ) -> crate::Result<()> { + let mut futures: FuturesUnordered<_> = addresses + .into_iter() + .map(|address| { + let dial_addresses = self.dial_addresses.clone(); + let connection_open_timeout = self.config.connection_open_timeout; + + async move { + TcpTransport::dial_peer(address, dial_addresses, connection_open_timeout).await + } + }) + .collect(); + + self.pending_raw_connections.push(Box::pin(async move { + while let Some(result) = futures.next().await { + match result { + Ok((address, stream)) => return Ok((connection_id, address, stream)), + Err(error) => tracing::debug!( + target: LOG_TARGET, + ?connection_id, + ?error, + "failed to open connection", + ), + } + } + + Err(connection_id) + })); + + Ok(()) + } + + fn negotiate(&mut self, connection_id: ConnectionId) -> crate::Result<()> { + let (stream, address) = self + .opened_raw + .remove(&connection_id) + .ok_or(Error::ConnectionDoesntExist(connection_id))?; + + let (socket_address, peer) = listener::TcpListener::get_socket_address(&address)?; + let yamux_config = self.config.yamux_config.clone(); + let max_read_ahead_factor = self.config.noise_read_ahead_frame_count; + let max_write_buffer_size = self.config.noise_write_buffer_size; + let connection_open_timeout = self.config.connection_open_timeout; + let substream_open_timeout = self.config.substream_open_timeout; + let keypair = self.context.keypair.clone(); + + tracing::trace!( + target: LOG_TARGET, + ?peer, + ?connection_id, + ?address, + "negotiate connection", + ); + + self.pending_dials.insert(connection_id, address); + self.pending_connections.push(Box::pin(async move { + match tokio::time::timeout(connection_open_timeout, async move { + TcpConnection::negotiate_connection( + stream, + peer, + connection_id, + keypair, + Role::Dialer, + socket_address, + yamux_config, + max_read_ahead_factor, + max_write_buffer_size, + substream_open_timeout, + ) + .await + .map_err(|error| (connection_id, error)) + }) + .await + { + Err(_) => Err((connection_id, Error::Timeout)), + Ok(Err(error)) => Err(error), + Ok(Ok(connection)) => Ok(connection), + } + })); + + Ok(()) + } + + fn cancel(&mut self, connection_id: ConnectionId) { + self.canceled.insert(connection_id); + } } impl Stream for TcpTransport { - type Item = TransportEvent; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - while let Poll::Ready(event) = self.listener.poll_next_unpin(cx) { - match event { - None | Some(Err(_)) => return Poll::Ready(None), - Some(Ok((connection, address))) => { - self.on_inbound_connection(connection, address); - }, - } - } - - while let Poll::Ready(Some(result)) = self.pending_raw_connections.poll_next_unpin(cx) { - match result { - Ok((connection_id, address, stream)) => { - tracing::trace!( - target: LOG_TARGET, - ?connection_id, - ?address, - canceled = self.canceled.contains(&connection_id), - "connection opened", - ); - - if !self.canceled.remove(&connection_id) { - self.opened_raw.insert(connection_id, (stream, address.clone())); - - return Poll::Ready(Some(TransportEvent::ConnectionOpened { - connection_id, - address, - })); - } - }, - Err(connection_id) => - if !self.canceled.remove(&connection_id) { - return Poll::Ready(Some(TransportEvent::OpenFailure { connection_id })); - }, - } - } - - while let Poll::Ready(Some(connection)) = self.pending_connections.poll_next_unpin(cx) { - match connection { - Ok(connection) => { - let peer = connection.peer(); - let endpoint = connection.endpoint(); - self.pending_open.insert(connection.connection_id(), connection); - - return Poll::Ready(Some(TransportEvent::ConnectionEstablished { - peer, - endpoint, - })); - }, - Err((connection_id, error)) => { - if let Some(address) = self.pending_dials.remove(&connection_id) { - return Poll::Ready(Some(TransportEvent::DialFailure { - connection_id, - address, - error, - })); - } - }, - } - } - - Poll::Pending - } + type Item = TransportEvent; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + while let Poll::Ready(event) = self.listener.poll_next_unpin(cx) { + match event { + None | Some(Err(_)) => return Poll::Ready(None), + Some(Ok((connection, address))) => { + self.on_inbound_connection(connection, address); + } + } + } + + while let Poll::Ready(Some(result)) = self.pending_raw_connections.poll_next_unpin(cx) { + match result { + Ok((connection_id, address, stream)) => { + tracing::trace!( + target: LOG_TARGET, + ?connection_id, + ?address, + canceled = self.canceled.contains(&connection_id), + "connection opened", + ); + + if !self.canceled.remove(&connection_id) { + self.opened_raw.insert(connection_id, (stream, address.clone())); + + return Poll::Ready(Some(TransportEvent::ConnectionOpened { + connection_id, + address, + })); + } + } + Err(connection_id) => + if !self.canceled.remove(&connection_id) { + return Poll::Ready(Some(TransportEvent::OpenFailure { connection_id })); + }, + } + } + + while let Poll::Ready(Some(connection)) = self.pending_connections.poll_next_unpin(cx) { + match connection { + Ok(connection) => { + let peer = connection.peer(); + let endpoint = connection.endpoint(); + self.pending_open.insert(connection.connection_id(), connection); + + return Poll::Ready(Some(TransportEvent::ConnectionEstablished { + peer, + endpoint, + })); + } + Err((connection_id, error)) => { + if let Some(address) = self.pending_dials.remove(&connection_id) { + return Poll::Ready(Some(TransportEvent::DialFailure { + connection_id, + address, + error, + })); + } + } + } + } + + Poll::Pending + } } #[cfg(test)] mod tests { - use super::*; - use crate::{ - codec::ProtocolCodec, - crypto::ed25519::Keypair, - executor::DefaultExecutor, - transport::manager::{ProtocolContext, SupportedTransport, TransportManager}, - types::protocol::ProtocolName, - BandwidthSink, PeerId, - }; - use multiaddr::Protocol; - use multihash::Multihash; - use std::{collections::HashSet, sync::Arc}; - use tokio::sync::mpsc::channel; - - #[tokio::test] - async fn connect_and_accept_works() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let keypair1 = Keypair::generate(); - let (tx1, _rx1) = channel(64); - let (event_tx1, _event_rx1) = channel(64); - let bandwidth_sink = BandwidthSink::new(); - - let handle1 = crate::transport::manager::TransportHandle { - executor: Arc::new(DefaultExecutor {}), - protocol_names: Vec::new(), - next_substream_id: Default::default(), - next_connection_id: Default::default(), - keypair: keypair1.clone(), - tx: event_tx1, - bandwidth_sink: bandwidth_sink.clone(), - - protocols: HashMap::from_iter([( - ProtocolName::from("/notif/1"), - ProtocolContext { - tx: tx1, - codec: ProtocolCodec::Identity(32), - fallback_names: Vec::new(), - }, - )]), - }; - let transport_config1 = Config { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }; - - let (mut transport1, listen_addresses) = - TcpTransport::new(handle1, transport_config1).unwrap(); - let listen_address = listen_addresses[0].clone(); - - let keypair2 = Keypair::generate(); - let (tx2, _rx2) = channel(64); - let (event_tx2, _event_rx2) = channel(64); - - let handle2 = crate::transport::manager::TransportHandle { - executor: Arc::new(DefaultExecutor {}), - protocol_names: Vec::new(), - next_substream_id: Default::default(), - next_connection_id: Default::default(), - keypair: keypair2.clone(), - tx: event_tx2, - bandwidth_sink: bandwidth_sink.clone(), - - protocols: HashMap::from_iter([( - ProtocolName::from("/notif/1"), - ProtocolContext { - tx: tx2, - codec: ProtocolCodec::Identity(32), - fallback_names: Vec::new(), - }, - )]), - }; - let transport_config2 = Config { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }; - - let (mut transport2, _) = TcpTransport::new(handle2, transport_config2).unwrap(); - transport2.dial(ConnectionId::new(), listen_address).unwrap(); - - let (res1, res2) = tokio::join!(transport1.next(), transport2.next()); - - assert!(std::matches!(res1, Some(TransportEvent::ConnectionEstablished { .. }))); - assert!(std::matches!(res2, Some(TransportEvent::ConnectionEstablished { .. }))); - } - - #[tokio::test] - async fn dial_failure() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let keypair1 = Keypair::generate(); - let (tx1, _rx1) = channel(64); - let (event_tx1, mut event_rx1) = channel(64); - let bandwidth_sink = BandwidthSink::new(); - - let handle1 = crate::transport::manager::TransportHandle { - executor: Arc::new(DefaultExecutor {}), - protocol_names: Vec::new(), - next_substream_id: Default::default(), - next_connection_id: Default::default(), - keypair: keypair1.clone(), - tx: event_tx1, - bandwidth_sink: bandwidth_sink.clone(), - - protocols: HashMap::from_iter([( - ProtocolName::from("/notif/1"), - ProtocolContext { - tx: tx1, - codec: ProtocolCodec::Identity(32), - fallback_names: Vec::new(), - }, - )]), - }; - let (mut transport1, _) = TcpTransport::new(handle1, Default::default()).unwrap(); - - tokio::spawn(async move { - while let Some(event) = transport1.next().await { - match event { - TransportEvent::ConnectionEstablished { .. } => {}, - TransportEvent::ConnectionClosed { .. } => {}, - TransportEvent::DialFailure { .. } => {}, - TransportEvent::ConnectionOpened { .. } => {}, - TransportEvent::OpenFailure { .. } => {}, - } - } - }); - - let keypair2 = Keypair::generate(); - let (tx2, _rx2) = channel(64); - let (event_tx2, _event_rx2) = channel(64); - - let handle2 = crate::transport::manager::TransportHandle { - executor: Arc::new(DefaultExecutor {}), - protocol_names: Vec::new(), - next_substream_id: Default::default(), - next_connection_id: Default::default(), - keypair: keypair2.clone(), - tx: event_tx2, - bandwidth_sink: bandwidth_sink.clone(), - - protocols: HashMap::from_iter([( - ProtocolName::from("/notif/1"), - ProtocolContext { - tx: tx2, - codec: ProtocolCodec::Identity(32), - fallback_names: Vec::new(), - }, - )]), - }; - - let (mut transport2, _) = TcpTransport::new(handle2, Default::default()).unwrap(); - - let peer1: PeerId = PeerId::from_public_key(&keypair1.public().into()); - let peer2: PeerId = PeerId::from_public_key(&keypair2.public().into()); - - tracing::info!(target: LOG_TARGET, "peer1 {peer1}, peer2 {peer2}"); - - let address = Multiaddr::empty() - .with(Protocol::Ip6(std::net::Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1))) - .with(Protocol::Tcp(8888)) - .with(Protocol::P2p(Multihash::from_bytes(&peer1.to_bytes()).unwrap())); - - transport2.dial(ConnectionId::new(), address).unwrap(); - - // spawn the other conection in the background as it won't return anything - tokio::spawn(async move { - loop { - let _ = event_rx1.recv().await; - } - }); - - assert!(std::matches!(transport2.next().await, Some(TransportEvent::DialFailure { .. }))); - } - - #[tokio::test] - async fn dial_error_reported_for_outbound_connections() { - let (mut manager, _handle) = TransportManager::new( - Keypair::generate(), - HashSet::new(), - BandwidthSink::new(), - 8usize, - ); - let handle = manager.transport_handle(Arc::new(DefaultExecutor {})); - manager.register_transport( - SupportedTransport::Tcp, - Box::new(crate::transport::dummy::DummyTransport::new()), - ); - let (mut transport, _) = TcpTransport::new( - handle, - Config { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0".parse().unwrap()], - ..Default::default() - }, - ) - .unwrap(); - - let keypair = Keypair::generate(); - let peer_id = PeerId::from_public_key(&keypair.public().into()); - let multiaddr = Multiaddr::empty() - .with(Protocol::Ip4(std::net::Ipv4Addr::new(255, 254, 253, 252))) - .with(Protocol::Tcp(8888)) - .with(Protocol::P2p(Multihash::from_bytes(&peer_id.to_bytes()).unwrap())); - manager.dial_address(multiaddr.clone()).await.unwrap(); - - assert!(transport.pending_dials.is_empty()); - - match transport.dial(ConnectionId::from(0usize), multiaddr) { - Ok(()) => {}, - _ => panic!("invalid result for `on_dial_peer()`"), - } - - assert!(!transport.pending_dials.is_empty()); - transport - .pending_connections - .push(Box::pin(async move { Err((ConnectionId::from(0usize), Error::Unknown)) })); - - assert!(std::matches!(transport.next().await, Some(TransportEvent::DialFailure { .. }))); - assert!(transport.pending_dials.is_empty()); - } + use super::*; + use crate::{ + codec::ProtocolCodec, + crypto::ed25519::Keypair, + executor::DefaultExecutor, + transport::manager::{ProtocolContext, SupportedTransport, TransportManager}, + types::protocol::ProtocolName, + BandwidthSink, PeerId, + }; + use multiaddr::Protocol; + use multihash::Multihash; + use std::{collections::HashSet, sync::Arc}; + use tokio::sync::mpsc::channel; + + #[tokio::test] + async fn connect_and_accept_works() { + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let keypair1 = Keypair::generate(); + let (tx1, _rx1) = channel(64); + let (event_tx1, _event_rx1) = channel(64); + let bandwidth_sink = BandwidthSink::new(); + + let handle1 = crate::transport::manager::TransportHandle { + executor: Arc::new(DefaultExecutor {}), + protocol_names: Vec::new(), + next_substream_id: Default::default(), + next_connection_id: Default::default(), + keypair: keypair1.clone(), + tx: event_tx1, + bandwidth_sink: bandwidth_sink.clone(), + + protocols: HashMap::from_iter([( + ProtocolName::from("/notif/1"), + ProtocolContext { + tx: tx1, + codec: ProtocolCodec::Identity(32), + fallback_names: Vec::new(), + }, + )]), + }; + let transport_config1 = Config { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }; + + let (mut transport1, listen_addresses) = + TcpTransport::new(handle1, transport_config1).unwrap(); + let listen_address = listen_addresses[0].clone(); + + let keypair2 = Keypair::generate(); + let (tx2, _rx2) = channel(64); + let (event_tx2, _event_rx2) = channel(64); + + let handle2 = crate::transport::manager::TransportHandle { + executor: Arc::new(DefaultExecutor {}), + protocol_names: Vec::new(), + next_substream_id: Default::default(), + next_connection_id: Default::default(), + keypair: keypair2.clone(), + tx: event_tx2, + bandwidth_sink: bandwidth_sink.clone(), + + protocols: HashMap::from_iter([( + ProtocolName::from("/notif/1"), + ProtocolContext { + tx: tx2, + codec: ProtocolCodec::Identity(32), + fallback_names: Vec::new(), + }, + )]), + }; + let transport_config2 = Config { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }; + + let (mut transport2, _) = TcpTransport::new(handle2, transport_config2).unwrap(); + transport2.dial(ConnectionId::new(), listen_address).unwrap(); + + let (res1, res2) = tokio::join!(transport1.next(), transport2.next()); + + assert!(std::matches!( + res1, + Some(TransportEvent::ConnectionEstablished { .. }) + )); + assert!(std::matches!( + res2, + Some(TransportEvent::ConnectionEstablished { .. }) + )); + } + + #[tokio::test] + async fn dial_failure() { + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let keypair1 = Keypair::generate(); + let (tx1, _rx1) = channel(64); + let (event_tx1, mut event_rx1) = channel(64); + let bandwidth_sink = BandwidthSink::new(); + + let handle1 = crate::transport::manager::TransportHandle { + executor: Arc::new(DefaultExecutor {}), + protocol_names: Vec::new(), + next_substream_id: Default::default(), + next_connection_id: Default::default(), + keypair: keypair1.clone(), + tx: event_tx1, + bandwidth_sink: bandwidth_sink.clone(), + + protocols: HashMap::from_iter([( + ProtocolName::from("/notif/1"), + ProtocolContext { + tx: tx1, + codec: ProtocolCodec::Identity(32), + fallback_names: Vec::new(), + }, + )]), + }; + let (mut transport1, _) = TcpTransport::new(handle1, Default::default()).unwrap(); + + tokio::spawn(async move { + while let Some(event) = transport1.next().await { + match event { + TransportEvent::ConnectionEstablished { .. } => {} + TransportEvent::ConnectionClosed { .. } => {} + TransportEvent::DialFailure { .. } => {} + TransportEvent::ConnectionOpened { .. } => {} + TransportEvent::OpenFailure { .. } => {} + } + } + }); + + let keypair2 = Keypair::generate(); + let (tx2, _rx2) = channel(64); + let (event_tx2, _event_rx2) = channel(64); + + let handle2 = crate::transport::manager::TransportHandle { + executor: Arc::new(DefaultExecutor {}), + protocol_names: Vec::new(), + next_substream_id: Default::default(), + next_connection_id: Default::default(), + keypair: keypair2.clone(), + tx: event_tx2, + bandwidth_sink: bandwidth_sink.clone(), + + protocols: HashMap::from_iter([( + ProtocolName::from("/notif/1"), + ProtocolContext { + tx: tx2, + codec: ProtocolCodec::Identity(32), + fallback_names: Vec::new(), + }, + )]), + }; + + let (mut transport2, _) = TcpTransport::new(handle2, Default::default()).unwrap(); + + let peer1: PeerId = PeerId::from_public_key(&keypair1.public().into()); + let peer2: PeerId = PeerId::from_public_key(&keypair2.public().into()); + + tracing::info!(target: LOG_TARGET, "peer1 {peer1}, peer2 {peer2}"); + + let address = Multiaddr::empty() + .with(Protocol::Ip6(std::net::Ipv6Addr::new( + 0, 0, 0, 0, 0, 0, 0, 1, + ))) + .with(Protocol::Tcp(8888)) + .with(Protocol::P2p( + Multihash::from_bytes(&peer1.to_bytes()).unwrap(), + )); + + transport2.dial(ConnectionId::new(), address).unwrap(); + + // spawn the other conection in the background as it won't return anything + tokio::spawn(async move { + loop { + let _ = event_rx1.recv().await; + } + }); + + assert!(std::matches!( + transport2.next().await, + Some(TransportEvent::DialFailure { .. }) + )); + } + + #[tokio::test] + async fn dial_error_reported_for_outbound_connections() { + let (mut manager, _handle) = TransportManager::new( + Keypair::generate(), + HashSet::new(), + BandwidthSink::new(), + 8usize, + ); + let handle = manager.transport_handle(Arc::new(DefaultExecutor {})); + manager.register_transport( + SupportedTransport::Tcp, + Box::new(crate::transport::dummy::DummyTransport::new()), + ); + let (mut transport, _) = TcpTransport::new( + handle, + Config { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0".parse().unwrap()], + ..Default::default() + }, + ) + .unwrap(); + + let keypair = Keypair::generate(); + let peer_id = PeerId::from_public_key(&keypair.public().into()); + let multiaddr = Multiaddr::empty() + .with(Protocol::Ip4(std::net::Ipv4Addr::new(255, 254, 253, 252))) + .with(Protocol::Tcp(8888)) + .with(Protocol::P2p( + Multihash::from_bytes(&peer_id.to_bytes()).unwrap(), + )); + manager.dial_address(multiaddr.clone()).await.unwrap(); + + assert!(transport.pending_dials.is_empty()); + + match transport.dial(ConnectionId::from(0usize), multiaddr) { + Ok(()) => {} + _ => panic!("invalid result for `on_dial_peer()`"), + } + + assert!(!transport.pending_dials.is_empty()); + transport.pending_connections.push(Box::pin(async move { + Err((ConnectionId::from(0usize), Error::Unknown)) + })); + + assert!(std::matches!( + transport.next().await, + Some(TransportEvent::DialFailure { .. }) + )); + assert!(transport.pending_dials.is_empty()); + } } diff --git a/src/transport/tcp/substream.rs b/src/transport/tcp/substream.rs index 575501e0..0ce8a779 100644 --- a/src/transport/tcp/substream.rs +++ b/src/transport/tcp/substream.rs @@ -24,9 +24,9 @@ use tokio::io::{AsyncRead, AsyncWrite}; use tokio_util::compat::Compat; use std::{ - io, - pin::Pin, - task::{Context, Poll}, + io, + pin::Pin, + task::{Context, Poll}, }; /// Substream that holds the inner substream provided by the transport @@ -35,84 +35,88 @@ use std::{ /// `BandwidthSink` is used to meter inbound/outbound bytes. #[derive(Debug)] pub struct Substream { - /// Underlying socket. - io: Compat, + /// Underlying socket. + io: Compat, - /// Bandwidth sink. - bandwidth_sink: BandwidthSink, + /// Bandwidth sink. + bandwidth_sink: BandwidthSink, - /// Connection permit. - _permit: Permit, + /// Connection permit. + _permit: Permit, } impl Substream { - /// Create new [`Substream`]. - pub fn new( - io: Compat, - bandwidth_sink: BandwidthSink, - _permit: Permit, - ) -> Self { - Self { io, bandwidth_sink, _permit } - } + /// Create new [`Substream`]. + pub fn new( + io: Compat, + bandwidth_sink: BandwidthSink, + _permit: Permit, + ) -> Self { + Self { + io, + bandwidth_sink, + _permit, + } + } } impl AsyncRead for Substream { - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut tokio::io::ReadBuf<'_>, - ) -> Poll> { - match futures::ready!(Pin::new(&mut self.io).poll_read(cx, buf)) { - Err(error) => Poll::Ready(Err(error)), - Ok(res) => { - self.bandwidth_sink.increase_inbound(buf.filled().len()); - Poll::Ready(Ok(res)) - }, - } - } + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll> { + match futures::ready!(Pin::new(&mut self.io).poll_read(cx, buf)) { + Err(error) => Poll::Ready(Err(error)), + Ok(res) => { + self.bandwidth_sink.increase_inbound(buf.filled().len()); + Poll::Ready(Ok(res)) + } + } + } } impl AsyncWrite for Substream { - fn poll_write( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - match futures::ready!(Pin::new(&mut self.io).poll_write(cx, buf)) { - Err(error) => Poll::Ready(Err(error)), - Ok(nwritten) => { - self.bandwidth_sink.increase_outbound(nwritten); - Poll::Ready(Ok(nwritten)) - }, - } - } + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + match futures::ready!(Pin::new(&mut self.io).poll_write(cx, buf)) { + Err(error) => Poll::Ready(Err(error)), + Ok(nwritten) => { + self.bandwidth_sink.increase_outbound(nwritten); + Poll::Ready(Ok(nwritten)) + } + } + } - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.io).poll_flush(cx) - } + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.io).poll_flush(cx) + } - fn poll_shutdown( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - Pin::new(&mut self.io).poll_shutdown(cx) - } + fn poll_shutdown( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + Pin::new(&mut self.io).poll_shutdown(cx) + } - fn poll_write_vectored( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - bufs: &[io::IoSlice<'_>], - ) -> Poll> { - match futures::ready!(Pin::new(&mut self.io).poll_write_vectored(cx, bufs)) { - Err(error) => Poll::Ready(Err(error)), - Ok(nwritten) => { - self.bandwidth_sink.increase_outbound(nwritten); - Poll::Ready(Ok(nwritten)) - }, - } - } + fn poll_write_vectored( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[io::IoSlice<'_>], + ) -> Poll> { + match futures::ready!(Pin::new(&mut self.io).poll_write_vectored(cx, bufs)) { + Err(error) => Poll::Ready(Err(error)), + Ok(nwritten) => { + self.bandwidth_sink.increase_outbound(nwritten); + Poll::Ready(Ok(nwritten)) + } + } + } - fn is_write_vectored(&self) -> bool { - self.io.is_write_vectored() - } + fn is_write_vectored(&self) -> bool { + self.io.is_write_vectored() + } } diff --git a/src/transport/webrtc/config.rs b/src/transport/webrtc/config.rs index 5de83274..526829d2 100644 --- a/src/transport/webrtc/config.rs +++ b/src/transport/webrtc/config.rs @@ -25,6 +25,6 @@ use multiaddr::Multiaddr; /// WebRTC transport configuration. #[derive(Debug)] pub struct Config { - /// WebRTC listening address. - pub listen_addresses: Vec, + /// WebRTC listening address. + pub listen_addresses: Vec, } diff --git a/src/transport/webrtc/connection.rs b/src/transport/webrtc/connection.rs index 7a2892b2..bf49fded 100644 --- a/src/transport/webrtc/connection.rs +++ b/src/transport/webrtc/connection.rs @@ -21,39 +21,39 @@ #![allow(unused)] use crate::{ - config::Role, - crypto::{ed25519::Keypair, noise::NoiseContext}, - error::Error, - multistream_select::{listener_negotiate, DialerState, HandshakeResult}, - protocol::{Direction, Permit, ProtocolCommand, ProtocolSet}, - substream::Substream, - transport::{ - webrtc::{ - substream::SubstreamBackend, - util::{SubstreamContext, WebRtcMessage}, - WebRtcEvent, - }, - Endpoint, - }, - types::{protocol::ProtocolName, ConnectionId, SubstreamId}, - PeerId, + config::Role, + crypto::{ed25519::Keypair, noise::NoiseContext}, + error::Error, + multistream_select::{listener_negotiate, DialerState, HandshakeResult}, + protocol::{Direction, Permit, ProtocolCommand, ProtocolSet}, + substream::Substream, + transport::{ + webrtc::{ + substream::SubstreamBackend, + util::{SubstreamContext, WebRtcMessage}, + WebRtcEvent, + }, + Endpoint, + }, + types::{protocol::ProtocolName, ConnectionId, SubstreamId}, + PeerId, }; use futures::StreamExt; use multiaddr::{multihash::Multihash, Multiaddr, Protocol}; use str0m::{ - change::Fingerprint, - channel::{ChannelConfig, ChannelData, ChannelId}, - net::Receive, - Event, IceConnectionState, Input, Output, Rtc, + change::Fingerprint, + channel::{ChannelConfig, ChannelData, ChannelId}, + net::Receive, + Event, IceConnectionState, Input, Output, Rtc, }; use tokio::{net::UdpSocket, sync::mpsc::Receiver}; use std::{ - collections::HashMap, - net::SocketAddr, - sync::Arc, - time::{Duration, Instant}, + collections::HashMap, + net::SocketAddr, + sync::Arc, + time::{Duration, Instant}, }; /// Logging target for the file. @@ -61,658 +61,663 @@ const LOG_TARGET: &str = "litep2p::webrtc::connection"; /// Create Noise prologue. fn noise_prologue_new(local_fingerprint: Vec, remote_fingerprint: Vec) -> Vec { - const PREFIX: &[u8] = b"libp2p-webrtc-noise:"; - let mut prologue = - Vec::with_capacity(PREFIX.len() + local_fingerprint.len() + remote_fingerprint.len()); - prologue.extend_from_slice(PREFIX); - prologue.extend_from_slice(&remote_fingerprint); - prologue.extend_from_slice(&local_fingerprint); - - prologue + const PREFIX: &[u8] = b"libp2p-webrtc-noise:"; + let mut prologue = + Vec::with_capacity(PREFIX.len() + local_fingerprint.len() + remote_fingerprint.len()); + prologue.extend_from_slice(PREFIX); + prologue.extend_from_slice(&remote_fingerprint); + prologue.extend_from_slice(&local_fingerprint); + + prologue } /// WebRTC connection state. #[derive(Debug)] enum State { - /// Connection state is poisoned. - Poisoned, - - /// Connection state is closed. - Closed, - - /// Connection state is opened. - Opened { - /// Noise handshaker. - handshaker: NoiseContext, - }, - - /// Handshake has been sent - HandshakeSent { - /// Noise handshaker. - handshaker: NoiseContext, - }, - - /// Connection is open. - Open { - /// Remote peer ID. - peer: PeerId, - }, + /// Connection state is poisoned. + Poisoned, + + /// Connection state is closed. + Closed, + + /// Connection state is opened. + Opened { + /// Noise handshaker. + handshaker: NoiseContext, + }, + + /// Handshake has been sent + HandshakeSent { + /// Noise handshaker. + handshaker: NoiseContext, + }, + + /// Connection is open. + Open { + /// Remote peer ID. + peer: PeerId, + }, } /// Substream state. #[derive(Debug)] enum SubstreamState { - /// Substream state is poisoned. - Poisoned, + /// Substream state is poisoned. + Poisoned, - /// Substream (outbound) is opening. - Opening { - /// Protocol. - protocol: ProtocolName, + /// Substream (outbound) is opening. + Opening { + /// Protocol. + protocol: ProtocolName, - /// Negotiated fallback. - fallback: Option, + /// Negotiated fallback. + fallback: Option, - /// `multistream-select` dialer state. - dialer_state: DialerState, + /// `multistream-select` dialer state. + dialer_state: DialerState, - /// Substream ID, - substream_id: SubstreamId, + /// Substream ID, + substream_id: SubstreamId, - /// Connection permit. - permit: Permit, - }, + /// Connection permit. + permit: Permit, + }, - /// Substream is open. - Open { - /// Substream ID. - substream_id: SubstreamId, + /// Substream is open. + Open { + /// Substream ID. + substream_id: SubstreamId, - /// Substream. - substream: SubstreamContext, + /// Substream. + substream: SubstreamContext, - /// Connection permit. - permit: Permit, - }, + /// Connection permit. + permit: Permit, + }, } /// WebRTC connection. // TODO: too much stuff, refactor? pub(super) struct WebRtcConnection { - /// Connection ID. - pub(super) connection_id: ConnectionId, + /// Connection ID. + pub(super) connection_id: ConnectionId, - /// `str0m` WebRTC object. - pub(super) rtc: Rtc, + /// `str0m` WebRTC object. + pub(super) rtc: Rtc, - /// Noise channel ID. - _noise_channel_id: ChannelId, + /// Noise channel ID. + _noise_channel_id: ChannelId, - /// Identity keypair. - id_keypair: Keypair, + /// Identity keypair. + id_keypair: Keypair, - /// Connection state. - state: State, + /// Connection state. + state: State, - /// Protocol set. - protocol_set: ProtocolSet, + /// Protocol set. + protocol_set: ProtocolSet, - /// Peer address - peer_address: SocketAddr, + /// Peer address + peer_address: SocketAddr, - /// Local address. - local_address: SocketAddr, + /// Local address. + local_address: SocketAddr, - /// Transport socket. - socket: Arc, + /// Transport socket. + socket: Arc, - /// RX channel for receiving datagrams from the transport. - dgram_rx: Receiver>, + /// RX channel for receiving datagrams from the transport. + dgram_rx: Receiver>, - /// Substream backend. - backend: SubstreamBackend, + /// Substream backend. + backend: SubstreamBackend, - /// Next substream ID. - substream_id: SubstreamId, + /// Next substream ID. + substream_id: SubstreamId, - /// Pending outbound substreams. - pending_outbound: HashMap, SubstreamId, Permit)>, + /// Pending outbound substreams. + pending_outbound: HashMap, SubstreamId, Permit)>, - /// Open substreams. - substreams: HashMap, + /// Open substreams. + substreams: HashMap, } impl WebRtcConnection { - pub(super) fn new( - rtc: Rtc, - connection_id: ConnectionId, - _noise_channel_id: ChannelId, - id_keypair: Keypair, - protocol_set: ProtocolSet, - peer_address: SocketAddr, - local_address: SocketAddr, - socket: Arc, - dgram_rx: Receiver>, - ) -> WebRtcConnection { - WebRtcConnection { - rtc, - socket, - dgram_rx, - protocol_set, - id_keypair, - peer_address, - local_address, - connection_id, - _noise_channel_id, - state: State::Closed, - substreams: HashMap::new(), - backend: SubstreamBackend::new(), - substream_id: SubstreamId::new(), - pending_outbound: HashMap::new(), - } - } - - pub(super) async fn poll_output(&mut self) -> crate::Result { - match self.rtc.poll_output() { - Ok(output) => self.handle_output(output).await, - Err(error) => { - tracing::debug!( - target: LOG_TARGET, - connection_id = ?self.connection_id, - ?error, - "`WebRtcConnection::poll_output()` failed", - ); - return Err(Error::WebRtc(error)); - }, - } - } - - /// Handle data received from peer. - pub(super) async fn on_input(&mut self, buffer: Vec) -> crate::Result<()> { - let message = Input::Receive( - Instant::now(), - Receive { - source: self.peer_address, - destination: self.local_address, - contents: buffer.as_slice().try_into().map_err(|_| Error::InvalidData)?, - }, - ); - - match self.rtc.accepts(&message) { - true => self.rtc.handle_input(message).map_err(|error| { - tracing::debug!(target: LOG_TARGET, source = ?self.peer_address, ?error, "failed to handle data"); - Error::InputRejected - }), - false => return Err(Error::InputRejected), - } - } - - async fn handle_output(&mut self, output: Output) -> crate::Result { - match output { - Output::Transmit(transmit) => { - self.socket - .send_to(&transmit.contents, transmit.destination) - .await - .expect("send to succeed"); - Ok(WebRtcEvent::Noop) - }, - Output::Timeout(t) => Ok(WebRtcEvent::Timeout(t)), - Output::Event(e) => match e { - Event::IceConnectionStateChange(v) => { - if v == IceConnectionState::Disconnected { - tracing::debug!(target: LOG_TARGET, "ice connection closed"); - return Err(Error::Disconnected); - } - Ok(WebRtcEvent::Noop) - }, - Event::ChannelOpen(cid, name) => { - // TODO: remove, report issue to smoldot - tokio::time::sleep(std::time::Duration::from_millis(500)).await; - self.on_channel_open(cid, name).map(|_| WebRtcEvent::Noop) - }, - Event::ChannelData(data) => self.on_channel_data(data).await, - Event::ChannelClose(channel_id) => { - // TODO: notify the protocol - tracing::debug!(target: LOG_TARGET, ?channel_id, "channel closed"); - Ok(WebRtcEvent::Noop) - }, - Event::Connected => { - match std::mem::replace(&mut self.state, State::Poisoned) { - State::Closed => { - let remote_fingerprint = self.remote_fingerprint(); - let local_fingerprint = self.local_fingerprint(); - - let handshaker = NoiseContext::with_prologue( - &self.id_keypair, - noise_prologue_new(local_fingerprint, remote_fingerprint), - ); - - self.state = State::Opened { handshaker }; - }, - state => { - tracing::debug!( - target: LOG_TARGET, - ?state, - "invalid state for connection" - ); - return Err(Error::InvalidState); - }, - } - Ok(WebRtcEvent::Noop) - }, - event => { - tracing::warn!(target: LOG_TARGET, ?event, "unhandled event"); - Ok(WebRtcEvent::Noop) - }, - }, - } - } - - /// Get remote fingerprint to bytes. - fn remote_fingerprint(&mut self) -> Vec { - let fingerprint = self - .rtc - .direct_api() - .remote_dtls_fingerprint() - .clone() - .expect("fingerprint to exist"); - Self::fingerprint_to_bytes(&fingerprint) - } - - /// Get local fingerprint as bytes. - fn local_fingerprint(&mut self) -> Vec { - Self::fingerprint_to_bytes(&self.rtc.direct_api().local_dtls_fingerprint()) - } - - /// Convert `Fingerprint` to bytes. - fn fingerprint_to_bytes(fingerprint: &Fingerprint) -> Vec { - const MULTIHASH_SHA256_CODE: u64 = 0x12; - Multihash::wrap(MULTIHASH_SHA256_CODE, &fingerprint.bytes) - .expect("fingerprint's len to be 32 bytes") - .to_bytes() - } - - fn on_noise_channel_open(&mut self) -> crate::Result<()> { - tracing::trace!(target: LOG_TARGET, "send initial noise handshake"); - - let State::Opened { mut handshaker } = std::mem::replace(&mut self.state, State::Poisoned) - else { - return Err(Error::InvalidState); - }; - - // create first noise handshake and send it to remote peer - let payload = WebRtcMessage::encode(handshaker.first_message(Role::Dialer), None); - - self.rtc - .channel(self._noise_channel_id) - .ok_or(Error::ChannelDoesntExist)? - .write(true, payload.as_slice()) - .map_err(|error| Error::WebRtc(error))?; - - self.state = State::HandshakeSent { handshaker }; - Ok(()) - } - - fn on_channel_open(&mut self, channel_id: ChannelId, name: String) -> crate::Result<()> { - tracing::debug!(target: LOG_TARGET, ?channel_id, channel_name = ?name, "channel opened"); - - if channel_id == self._noise_channel_id { - return self.on_noise_channel_open(); - } - - match self.pending_outbound.remove(&channel_id) { - None => { - tracing::trace!(target: LOG_TARGET, ?channel_id, "remote opened a substream"); - }, - Some((protocol, fallback_names, substream_id, permit)) => { - tracing::trace!(target: LOG_TARGET, ?channel_id, "dialer negotiate protocol"); - - let (dialer_state, message) = - DialerState::propose(protocol.clone(), fallback_names)?; - let message = WebRtcMessage::encode(message, None); - - self.rtc - .channel(channel_id) - .ok_or(Error::ChannelDoesntExist)? - .write(true, message.as_ref()) - .map_err(|error| Error::WebRtc(error))?; - - self.substreams.insert( - channel_id, - SubstreamState::Opening { - protocol, - fallback: None, - substream_id, - dialer_state, - permit, - }, - ); - }, - } - - Ok(()) - } - - async fn on_noise_channel_data(&mut self, data: Vec) -> crate::Result { - tracing::trace!(target: LOG_TARGET, "handle noise handshake reply"); - - let State::HandshakeSent { mut handshaker } = - std::mem::replace(&mut self.state, State::Poisoned) - else { - return Err(Error::InvalidState); - }; - - let message = WebRtcMessage::decode(&data)?.payload.ok_or(Error::InvalidData)?; - let public_key = handshaker.get_remote_public_key(&message)?; - let remote_peer_id = PeerId::from_public_key(&public_key); - - tracing::trace!( - target: LOG_TARGET, - ?remote_peer_id, - "remote reply parsed successfully" - ); - - // create second noise handshake message and send it to remote - let payload = WebRtcMessage::encode(handshaker.second_message(), None); - - let mut channel = - self.rtc.channel(self._noise_channel_id).ok_or(Error::ChannelDoesntExist)?; - - channel.write(true, payload.as_slice()).map_err(|error| Error::WebRtc(error))?; - - let remote_fingerprint = self - .rtc - .direct_api() - .remote_dtls_fingerprint() - .clone() - .expect("fingerprint to exist") - .bytes; - - const MULTIHASH_SHA256_CODE: u64 = 0x12; - let certificate = Multihash::wrap(MULTIHASH_SHA256_CODE, &remote_fingerprint) - .expect("fingerprint's len to be 32 bytes"); - - let address = Multiaddr::empty() - .with(Protocol::from(self.peer_address.ip())) - .with(Protocol::Udp(self.peer_address.port())) - .with(Protocol::WebRTC) - .with(Protocol::Certhash(certificate)) - .with(Protocol::P2p(PeerId::from(public_key).into())); - - self.protocol_set - .report_connection_established( - remote_peer_id, - Endpoint::listener(address, self.connection_id), - ) - .await?; - - self.state = State::Open { peer: remote_peer_id }; - - Ok(WebRtcEvent::Noop) - } - - /// Report open substream to the protocol. - async fn report_open_substream( - &mut self, - channel_id: ChannelId, - protocol: ProtocolName, - ) -> crate::Result { - // let substream_id = self.substream_id.next(); - // let (mut substream, tx) = self.backend.substream(channel_id); - // let substream: Box = { - // substream.apply_codec(self.protocol_set.protocol_codec(&protocol)); - // Box::new(substream) - // }; - // let permit = self.protocol_set.try_get_permit().ok_or(Error::ConnectionClosed)?; - - // self.substreams.insert( - // channel_id, - // SubstreamState::Open { - // substream_id, - // substream: SubstreamContext::new(channel_id, tx), - // permit, - // }, - // ); - // TODO: fix - - if let State::Open { peer, .. } = &mut self.state { - // let _ = self - // .protocol_set - // .report_substream_open(*peer, protocol.clone(), Direction::Inbound, substream) - // .await; - todo!(); - } - - Ok(WebRtcEvent::Noop) - } - - /// Negotiate protocol for the channel - async fn listener_negotiate_protocol(&mut self, d: ChannelData) -> crate::Result { - tracing::trace!(target: LOG_TARGET, channel_id = ?d.id, "negotiate protocol for the channel"); - - let payload = WebRtcMessage::decode(&d.data)?.payload.ok_or(Error::InvalidData)?; - - let (protocol, response) = - listener_negotiate(&mut self.protocol_set.protocols().iter(), payload.into())?; - - let message = WebRtcMessage::encode(response.to_vec(), None); - - self.rtc - .channel(d.id) - .ok_or(Error::ChannelDoesntExist)? - .write(true, message.as_ref()) - .map_err(|error| Error::WebRtc(error))?; - - self.report_open_substream(d.id, protocol).await - - // let substream_id = self.substream_id.next(); - // let (mut substream, tx) = self.backend.substream(d.id); - // let substream: Box = { - // substream.apply_codec(self.protocol_set.protocol_codec(&protocol)); - // Box::new(substream) - // }; - // let permit = self.protocol_set.try_get_permit().ok_or(Error::ConnectionClosed)?; - - // self.substreams.insert( - // d.id, - // SubstreamState::Open { - // substream_id, - // substream: SubstreamContext::new(d.id, tx), - // permit, - // }, - // ); - - // if let State::Open { peer, .. } = &mut self.state { - // let _ = self - // .protocol_set - // .report_substream_open(*peer, protocol.clone(), Direction::Inbound, substream) - // .await; - // } - // Ok(WebRtcEvent::Noop) - } - - async fn on_channel_data(&mut self, d: ChannelData) -> crate::Result { - match &self.state { - State::HandshakeSent { .. } => self.on_noise_channel_data(d.data).await, - State::Open { .. } => { - match self.substreams.get_mut(&d.id) { - None => match self.listener_negotiate_protocol(d).await { - Ok(_) => { - tracing::debug!(target: LOG_TARGET, "protocol negotiated for the channel"); - - Ok(WebRtcEvent::Noop) - }, - Err(error) => { - tracing::debug!(target: LOG_TARGET, ?error, "failed to negotiate protocol"); - - // TODO: close channel - Ok(WebRtcEvent::Noop) - }, - }, - Some(SubstreamState::Poisoned) => return Err(Error::ConnectionClosed), - Some(SubstreamState::Opening { ref mut dialer_state, .. }) => { - tracing::info!(target: LOG_TARGET, "try to decode message"); - let message = - WebRtcMessage::decode(&d.data)?.payload.ok_or(Error::InvalidData)?; - tracing::info!(target: LOG_TARGET, "decoded successfully"); - - match dialer_state.register_response(message) { - Ok(HandshakeResult::NotReady) => {}, - Ok(HandshakeResult::Succeeded(protocol)) => { - tracing::warn!(target: LOG_TARGET, ?protocol, "protocol negotiated, inform protocol handler"); - - return self.report_open_substream(d.id, protocol).await; - }, - Err(error) => { - tracing::error!(target: LOG_TARGET, ?error, "failed to negotiate protocol"); - // TODO: close channel - }, - } - - Ok(WebRtcEvent::Noop) - }, - Some(SubstreamState::Open { substream, .. }) => { - // TODO: might be empty message with flags - // TODO: if decoding fails, close the substream - let message = - WebRtcMessage::decode(&d.data)?.payload.ok_or(Error::InvalidData)?; - let _ = substream.tx.send(message).await; - - Ok(WebRtcEvent::Noop) - }, - } - }, - _ => Err(Error::InvalidState), - } - } - - /// Open outbound substream. - fn open_substream( - &mut self, - protocol: ProtocolName, - fallback_names: Vec, - substream_id: SubstreamId, - permit: Permit, - ) { - let channel_id = self.rtc.direct_api().create_data_channel(ChannelConfig { - label: protocol.to_string(), - ordered: false, - reliability: Default::default(), - negotiated: None, - protocol: protocol.to_string(), - }); - - tracing::trace!( - target: LOG_TARGET, - ?channel_id, - ?substream_id, - ?protocol, - ?fallback_names, - "open data channel" - ); - - self.pending_outbound - .insert(channel_id, (protocol, fallback_names, substream_id, permit)); - } - - /// Run the event loop of a negotiated WebRTC connection. - pub(super) async fn run(mut self) -> crate::Result<()> { - loop { - if !self.rtc.is_alive() { - tracing::debug!( - target: LOG_TARGET, - "`Rtc` is not alive, closing `WebRtcConnection`" - ); - return Ok(()); - } - - let duration = match self.poll_output().await { - Ok(WebRtcEvent::Timeout(timeout)) => { - let timeout = - std::cmp::min(timeout, Instant::now() + Duration::from_millis(100)); - (timeout - Instant::now()).max(Duration::from_millis(1)) - }, - Ok(WebRtcEvent::Noop) => continue, - Err(error) => { - tracing::debug!( - target: LOG_TARGET, - ?error, - "error occurred, closing connection" - ); - self.rtc.disconnect(); - return Ok(()); - }, - }; - - tokio::select! { - message = self.dgram_rx.recv() => match message { - Some(message) => match self.on_input(message).await { - Ok(_) | Err(Error::InputRejected) => {}, - Err(error) => { - tracing::debug!(target: LOG_TARGET, ?error, "failed to handle input"); - return Err(error) - } - } - None => { - tracing::debug!( - target: LOG_TARGET, - source = ?self.peer_address, - "transport shut down, shutting down connection", - ); - return Ok(()); - } - }, - event = self.backend.next_event() => { - let (channel_id, message) = event.ok_or(Error::EssentialTaskClosed)?; - - match self.substreams.get_mut(&channel_id) { - None => { - tracing::debug!(target: LOG_TARGET, "protocol tried to send message over substream that doesn't exist"); - } - Some(SubstreamState::Poisoned) => {}, - Some(SubstreamState::Opening { .. }) => { - tracing::debug!(target: LOG_TARGET, "protocol tried to send message over substream that isn't open"); - } - Some(SubstreamState::Open { .. }) => { - tracing::trace!(target: LOG_TARGET, ?channel_id, ?message, "send message to remote peer"); - - self.rtc - .channel(channel_id) - .ok_or(Error::ChannelDoesntExist)? - .write(true, message.as_ref()) - .map_err(|error| Error::WebRtc(error))?; - } - } - } - event = self.protocol_set.next() => match event { - Some(event) => match event { - ProtocolCommand::OpenSubstream { protocol, fallback_names, substream_id, permit } => { - self.open_substream(protocol, fallback_names, substream_id, permit); - } - ProtocolCommand::ForceClose => { - tracing::debug!(target: LOG_TARGET, "force closing connection"); - return Ok(()); - } - } - None => { - tracing::debug!(target: LOG_TARGET, "handle to protocol closed, closing connection"); - return Ok(()); - } - }, - _ = tokio::time::sleep(duration) => {} - } - - // drive time forward in the client - if let Err(error) = self.rtc.handle_input(Input::Timeout(Instant::now())) { - tracing::debug!( - target: LOG_TARGET, - ?error, - "failed to handle timeout for `Rtc`" - ); - - self.rtc.disconnect(); - return Err(Error::Disconnected); - } - } - } + pub(super) fn new( + rtc: Rtc, + connection_id: ConnectionId, + _noise_channel_id: ChannelId, + id_keypair: Keypair, + protocol_set: ProtocolSet, + peer_address: SocketAddr, + local_address: SocketAddr, + socket: Arc, + dgram_rx: Receiver>, + ) -> WebRtcConnection { + WebRtcConnection { + rtc, + socket, + dgram_rx, + protocol_set, + id_keypair, + peer_address, + local_address, + connection_id, + _noise_channel_id, + state: State::Closed, + substreams: HashMap::new(), + backend: SubstreamBackend::new(), + substream_id: SubstreamId::new(), + pending_outbound: HashMap::new(), + } + } + + pub(super) async fn poll_output(&mut self) -> crate::Result { + match self.rtc.poll_output() { + Ok(output) => self.handle_output(output).await, + Err(error) => { + tracing::debug!( + target: LOG_TARGET, + connection_id = ?self.connection_id, + ?error, + "`WebRtcConnection::poll_output()` failed", + ); + return Err(Error::WebRtc(error)); + } + } + } + + /// Handle data received from peer. + pub(super) async fn on_input(&mut self, buffer: Vec) -> crate::Result<()> { + let message = Input::Receive( + Instant::now(), + Receive { + source: self.peer_address, + destination: self.local_address, + contents: buffer.as_slice().try_into().map_err(|_| Error::InvalidData)?, + }, + ); + + match self.rtc.accepts(&message) { + true => self.rtc.handle_input(message).map_err(|error| { + tracing::debug!(target: LOG_TARGET, source = ?self.peer_address, ?error, "failed to handle data"); + Error::InputRejected + }), + false => return Err(Error::InputRejected), + } + } + + async fn handle_output(&mut self, output: Output) -> crate::Result { + match output { + Output::Transmit(transmit) => { + self.socket + .send_to(&transmit.contents, transmit.destination) + .await + .expect("send to succeed"); + Ok(WebRtcEvent::Noop) + } + Output::Timeout(t) => Ok(WebRtcEvent::Timeout(t)), + Output::Event(e) => match e { + Event::IceConnectionStateChange(v) => { + if v == IceConnectionState::Disconnected { + tracing::debug!(target: LOG_TARGET, "ice connection closed"); + return Err(Error::Disconnected); + } + Ok(WebRtcEvent::Noop) + } + Event::ChannelOpen(cid, name) => { + // TODO: remove, report issue to smoldot + tokio::time::sleep(std::time::Duration::from_millis(500)).await; + self.on_channel_open(cid, name).map(|_| WebRtcEvent::Noop) + } + Event::ChannelData(data) => self.on_channel_data(data).await, + Event::ChannelClose(channel_id) => { + // TODO: notify the protocol + tracing::debug!(target: LOG_TARGET, ?channel_id, "channel closed"); + Ok(WebRtcEvent::Noop) + } + Event::Connected => { + match std::mem::replace(&mut self.state, State::Poisoned) { + State::Closed => { + let remote_fingerprint = self.remote_fingerprint(); + let local_fingerprint = self.local_fingerprint(); + + let handshaker = NoiseContext::with_prologue( + &self.id_keypair, + noise_prologue_new(local_fingerprint, remote_fingerprint), + ); + + self.state = State::Opened { handshaker }; + } + state => { + tracing::debug!( + target: LOG_TARGET, + ?state, + "invalid state for connection" + ); + return Err(Error::InvalidState); + } + } + Ok(WebRtcEvent::Noop) + } + event => { + tracing::warn!(target: LOG_TARGET, ?event, "unhandled event"); + Ok(WebRtcEvent::Noop) + } + }, + } + } + + /// Get remote fingerprint to bytes. + fn remote_fingerprint(&mut self) -> Vec { + let fingerprint = self + .rtc + .direct_api() + .remote_dtls_fingerprint() + .clone() + .expect("fingerprint to exist"); + Self::fingerprint_to_bytes(&fingerprint) + } + + /// Get local fingerprint as bytes. + fn local_fingerprint(&mut self) -> Vec { + Self::fingerprint_to_bytes(&self.rtc.direct_api().local_dtls_fingerprint()) + } + + /// Convert `Fingerprint` to bytes. + fn fingerprint_to_bytes(fingerprint: &Fingerprint) -> Vec { + const MULTIHASH_SHA256_CODE: u64 = 0x12; + Multihash::wrap(MULTIHASH_SHA256_CODE, &fingerprint.bytes) + .expect("fingerprint's len to be 32 bytes") + .to_bytes() + } + + fn on_noise_channel_open(&mut self) -> crate::Result<()> { + tracing::trace!(target: LOG_TARGET, "send initial noise handshake"); + + let State::Opened { mut handshaker } = std::mem::replace(&mut self.state, State::Poisoned) + else { + return Err(Error::InvalidState); + }; + + // create first noise handshake and send it to remote peer + let payload = WebRtcMessage::encode(handshaker.first_message(Role::Dialer), None); + + self.rtc + .channel(self._noise_channel_id) + .ok_or(Error::ChannelDoesntExist)? + .write(true, payload.as_slice()) + .map_err(|error| Error::WebRtc(error))?; + + self.state = State::HandshakeSent { handshaker }; + Ok(()) + } + + fn on_channel_open(&mut self, channel_id: ChannelId, name: String) -> crate::Result<()> { + tracing::debug!(target: LOG_TARGET, ?channel_id, channel_name = ?name, "channel opened"); + + if channel_id == self._noise_channel_id { + return self.on_noise_channel_open(); + } + + match self.pending_outbound.remove(&channel_id) { + None => { + tracing::trace!(target: LOG_TARGET, ?channel_id, "remote opened a substream"); + } + Some((protocol, fallback_names, substream_id, permit)) => { + tracing::trace!(target: LOG_TARGET, ?channel_id, "dialer negotiate protocol"); + + let (dialer_state, message) = + DialerState::propose(protocol.clone(), fallback_names)?; + let message = WebRtcMessage::encode(message, None); + + self.rtc + .channel(channel_id) + .ok_or(Error::ChannelDoesntExist)? + .write(true, message.as_ref()) + .map_err(|error| Error::WebRtc(error))?; + + self.substreams.insert( + channel_id, + SubstreamState::Opening { + protocol, + fallback: None, + substream_id, + dialer_state, + permit, + }, + ); + } + } + + Ok(()) + } + + async fn on_noise_channel_data(&mut self, data: Vec) -> crate::Result { + tracing::trace!(target: LOG_TARGET, "handle noise handshake reply"); + + let State::HandshakeSent { mut handshaker } = + std::mem::replace(&mut self.state, State::Poisoned) + else { + return Err(Error::InvalidState); + }; + + let message = WebRtcMessage::decode(&data)?.payload.ok_or(Error::InvalidData)?; + let public_key = handshaker.get_remote_public_key(&message)?; + let remote_peer_id = PeerId::from_public_key(&public_key); + + tracing::trace!( + target: LOG_TARGET, + ?remote_peer_id, + "remote reply parsed successfully" + ); + + // create second noise handshake message and send it to remote + let payload = WebRtcMessage::encode(handshaker.second_message(), None); + + let mut channel = + self.rtc.channel(self._noise_channel_id).ok_or(Error::ChannelDoesntExist)?; + + channel.write(true, payload.as_slice()).map_err(|error| Error::WebRtc(error))?; + + let remote_fingerprint = self + .rtc + .direct_api() + .remote_dtls_fingerprint() + .clone() + .expect("fingerprint to exist") + .bytes; + + const MULTIHASH_SHA256_CODE: u64 = 0x12; + let certificate = Multihash::wrap(MULTIHASH_SHA256_CODE, &remote_fingerprint) + .expect("fingerprint's len to be 32 bytes"); + + let address = Multiaddr::empty() + .with(Protocol::from(self.peer_address.ip())) + .with(Protocol::Udp(self.peer_address.port())) + .with(Protocol::WebRTC) + .with(Protocol::Certhash(certificate)) + .with(Protocol::P2p(PeerId::from(public_key).into())); + + self.protocol_set + .report_connection_established( + remote_peer_id, + Endpoint::listener(address, self.connection_id), + ) + .await?; + + self.state = State::Open { + peer: remote_peer_id, + }; + + Ok(WebRtcEvent::Noop) + } + + /// Report open substream to the protocol. + async fn report_open_substream( + &mut self, + channel_id: ChannelId, + protocol: ProtocolName, + ) -> crate::Result { + // let substream_id = self.substream_id.next(); + // let (mut substream, tx) = self.backend.substream(channel_id); + // let substream: Box = { + // substream.apply_codec(self.protocol_set.protocol_codec(&protocol)); + // Box::new(substream) + // }; + // let permit = self.protocol_set.try_get_permit().ok_or(Error::ConnectionClosed)?; + + // self.substreams.insert( + // channel_id, + // SubstreamState::Open { + // substream_id, + // substream: SubstreamContext::new(channel_id, tx), + // permit, + // }, + // ); + // TODO: fix + + if let State::Open { peer, .. } = &mut self.state { + // let _ = self + // .protocol_set + // .report_substream_open(*peer, protocol.clone(), Direction::Inbound, substream) + // .await; + todo!(); + } + + Ok(WebRtcEvent::Noop) + } + + /// Negotiate protocol for the channel + async fn listener_negotiate_protocol(&mut self, d: ChannelData) -> crate::Result { + tracing::trace!(target: LOG_TARGET, channel_id = ?d.id, "negotiate protocol for the channel"); + + let payload = WebRtcMessage::decode(&d.data)?.payload.ok_or(Error::InvalidData)?; + + let (protocol, response) = + listener_negotiate(&mut self.protocol_set.protocols().iter(), payload.into())?; + + let message = WebRtcMessage::encode(response.to_vec(), None); + + self.rtc + .channel(d.id) + .ok_or(Error::ChannelDoesntExist)? + .write(true, message.as_ref()) + .map_err(|error| Error::WebRtc(error))?; + + self.report_open_substream(d.id, protocol).await + + // let substream_id = self.substream_id.next(); + // let (mut substream, tx) = self.backend.substream(d.id); + // let substream: Box = { + // substream.apply_codec(self.protocol_set.protocol_codec(&protocol)); + // Box::new(substream) + // }; + // let permit = self.protocol_set.try_get_permit().ok_or(Error::ConnectionClosed)?; + + // self.substreams.insert( + // d.id, + // SubstreamState::Open { + // substream_id, + // substream: SubstreamContext::new(d.id, tx), + // permit, + // }, + // ); + + // if let State::Open { peer, .. } = &mut self.state { + // let _ = self + // .protocol_set + // .report_substream_open(*peer, protocol.clone(), Direction::Inbound, substream) + // .await; + // } + // Ok(WebRtcEvent::Noop) + } + + async fn on_channel_data(&mut self, d: ChannelData) -> crate::Result { + match &self.state { + State::HandshakeSent { .. } => self.on_noise_channel_data(d.data).await, + State::Open { .. } => { + match self.substreams.get_mut(&d.id) { + None => match self.listener_negotiate_protocol(d).await { + Ok(_) => { + tracing::debug!(target: LOG_TARGET, "protocol negotiated for the channel"); + + Ok(WebRtcEvent::Noop) + } + Err(error) => { + tracing::debug!(target: LOG_TARGET, ?error, "failed to negotiate protocol"); + + // TODO: close channel + Ok(WebRtcEvent::Noop) + } + }, + Some(SubstreamState::Poisoned) => return Err(Error::ConnectionClosed), + Some(SubstreamState::Opening { + ref mut dialer_state, + .. + }) => { + tracing::info!(target: LOG_TARGET, "try to decode message"); + let message = + WebRtcMessage::decode(&d.data)?.payload.ok_or(Error::InvalidData)?; + tracing::info!(target: LOG_TARGET, "decoded successfully"); + + match dialer_state.register_response(message) { + Ok(HandshakeResult::NotReady) => {} + Ok(HandshakeResult::Succeeded(protocol)) => { + tracing::warn!(target: LOG_TARGET, ?protocol, "protocol negotiated, inform protocol handler"); + + return self.report_open_substream(d.id, protocol).await; + } + Err(error) => { + tracing::error!(target: LOG_TARGET, ?error, "failed to negotiate protocol"); + // TODO: close channel + } + } + + Ok(WebRtcEvent::Noop) + } + Some(SubstreamState::Open { substream, .. }) => { + // TODO: might be empty message with flags + // TODO: if decoding fails, close the substream + let message = + WebRtcMessage::decode(&d.data)?.payload.ok_or(Error::InvalidData)?; + let _ = substream.tx.send(message).await; + + Ok(WebRtcEvent::Noop) + } + } + } + _ => Err(Error::InvalidState), + } + } + + /// Open outbound substream. + fn open_substream( + &mut self, + protocol: ProtocolName, + fallback_names: Vec, + substream_id: SubstreamId, + permit: Permit, + ) { + let channel_id = self.rtc.direct_api().create_data_channel(ChannelConfig { + label: protocol.to_string(), + ordered: false, + reliability: Default::default(), + negotiated: None, + protocol: protocol.to_string(), + }); + + tracing::trace!( + target: LOG_TARGET, + ?channel_id, + ?substream_id, + ?protocol, + ?fallback_names, + "open data channel" + ); + + self.pending_outbound + .insert(channel_id, (protocol, fallback_names, substream_id, permit)); + } + + /// Run the event loop of a negotiated WebRTC connection. + pub(super) async fn run(mut self) -> crate::Result<()> { + loop { + if !self.rtc.is_alive() { + tracing::debug!( + target: LOG_TARGET, + "`Rtc` is not alive, closing `WebRtcConnection`" + ); + return Ok(()); + } + + let duration = match self.poll_output().await { + Ok(WebRtcEvent::Timeout(timeout)) => { + let timeout = + std::cmp::min(timeout, Instant::now() + Duration::from_millis(100)); + (timeout - Instant::now()).max(Duration::from_millis(1)) + } + Ok(WebRtcEvent::Noop) => continue, + Err(error) => { + tracing::debug!( + target: LOG_TARGET, + ?error, + "error occurred, closing connection" + ); + self.rtc.disconnect(); + return Ok(()); + } + }; + + tokio::select! { + message = self.dgram_rx.recv() => match message { + Some(message) => match self.on_input(message).await { + Ok(_) | Err(Error::InputRejected) => {}, + Err(error) => { + tracing::debug!(target: LOG_TARGET, ?error, "failed to handle input"); + return Err(error) + } + } + None => { + tracing::debug!( + target: LOG_TARGET, + source = ?self.peer_address, + "transport shut down, shutting down connection", + ); + return Ok(()); + } + }, + event = self.backend.next_event() => { + let (channel_id, message) = event.ok_or(Error::EssentialTaskClosed)?; + + match self.substreams.get_mut(&channel_id) { + None => { + tracing::debug!(target: LOG_TARGET, "protocol tried to send message over substream that doesn't exist"); + } + Some(SubstreamState::Poisoned) => {}, + Some(SubstreamState::Opening { .. }) => { + tracing::debug!(target: LOG_TARGET, "protocol tried to send message over substream that isn't open"); + } + Some(SubstreamState::Open { .. }) => { + tracing::trace!(target: LOG_TARGET, ?channel_id, ?message, "send message to remote peer"); + + self.rtc + .channel(channel_id) + .ok_or(Error::ChannelDoesntExist)? + .write(true, message.as_ref()) + .map_err(|error| Error::WebRtc(error))?; + } + } + } + event = self.protocol_set.next() => match event { + Some(event) => match event { + ProtocolCommand::OpenSubstream { protocol, fallback_names, substream_id, permit } => { + self.open_substream(protocol, fallback_names, substream_id, permit); + } + ProtocolCommand::ForceClose => { + tracing::debug!(target: LOG_TARGET, "force closing connection"); + return Ok(()); + } + } + None => { + tracing::debug!(target: LOG_TARGET, "handle to protocol closed, closing connection"); + return Ok(()); + } + }, + _ = tokio::time::sleep(duration) => {} + } + + // drive time forward in the client + if let Err(error) = self.rtc.handle_input(Input::Timeout(Instant::now())) { + tracing::debug!( + target: LOG_TARGET, + ?error, + "failed to handle timeout for `Rtc`" + ); + + self.rtc.disconnect(); + return Err(Error::Disconnected); + } + } + } } diff --git a/src/transport/webrtc/mod.rs b/src/transport/webrtc/mod.rs index e6f84fe9..5ee3616f 100644 --- a/src/transport/webrtc/mod.rs +++ b/src/transport/webrtc/mod.rs @@ -23,38 +23,38 @@ #![allow(unused)] use crate::{ - error::{AddressError, Error}, - transport::{ - manager::TransportHandle, - webrtc::{config::Config, connection::WebRtcConnection}, - Transport, TransportBuilder, TransportEvent, - }, - types::ConnectionId, - PeerId, + error::{AddressError, Error}, + transport::{ + manager::TransportHandle, + webrtc::{config::Config, connection::WebRtcConnection}, + Transport, TransportBuilder, TransportEvent, + }, + types::ConnectionId, + PeerId, }; use futures::{Stream, StreamExt}; use multiaddr::{multihash::Multihash, Multiaddr, Protocol}; use socket2::{Domain, Socket, Type}; use str0m::{ - change::{DtlsCert, IceCreds}, - channel::{ChannelConfig, ChannelId}, - net::{DatagramRecv, Receive}, - Candidate, Input, Rtc, + change::{DtlsCert, IceCreds}, + channel::{ChannelConfig, ChannelId}, + net::{DatagramRecv, Receive}, + Candidate, Input, Rtc, }; use tokio::{ - io::ReadBuf, - net::UdpSocket, - sync::mpsc::{channel, Sender}, + io::ReadBuf, + net::UdpSocket, + sync::mpsc::{channel, Sender}, }; use std::{ - collections::HashMap, - net::{IpAddr, SocketAddr}, - pin::Pin, - sync::Arc, - task::{Context, Poll}, - time::Instant, + collections::HashMap, + net::{IpAddr, SocketAddr}, + pin::Pin, + sync::Arc, + task::{Context, Poll}, + time::Instant, }; pub mod config; @@ -64,13 +64,13 @@ mod substream; mod util; mod schema { - pub(super) mod webrtc { - include!(concat!(env!("OUT_DIR"), "/webrtc.rs")); - } + pub(super) mod webrtc { + include!(concat!(env!("OUT_DIR"), "/webrtc.rs")); + } - pub(super) mod noise { - include!(concat!(env!("OUT_DIR"), "/noise.rs")); - } + pub(super) mod noise { + include!(concat!(env!("OUT_DIR"), "/noise.rs")); + } } /// Logging target for the file. @@ -82,340 +82,342 @@ const REMOTE_FINGERPRINT: &str = /// WebRTC transport. pub(crate) struct WebRtcTransport { - /// Transport context. - context: TransportHandle, + /// Transport context. + context: TransportHandle, - /// UDP socket. - socket: Arc, + /// UDP socket. + socket: Arc, - /// DTLS certificate. - dtls_cert: DtlsCert, + /// DTLS certificate. + dtls_cert: DtlsCert, - /// Assigned listen addresss. - listen_address: SocketAddr, + /// Assigned listen addresss. + listen_address: SocketAddr, - /// Connected peers. - peers: HashMap>>, + /// Connected peers. + peers: HashMap>>, } impl WebRtcTransport { - /// Extract socket address and `PeerId`, if found, from `address`. - fn get_socket_address(address: &Multiaddr) -> crate::Result<(SocketAddr, Option)> { - tracing::trace!(target: LOG_TARGET, ?address, "parse multi address"); - - let mut iter = address.iter(); - let socket_address = match iter.next() { - Some(Protocol::Ip6(address)) => match iter.next() { - Some(Protocol::Udp(port)) => SocketAddr::new(IpAddr::V6(address), port), - protocol => { - tracing::error!( - target: LOG_TARGET, - ?protocol, - "invalid transport protocol, expected `Upd`", - ); - return Err(Error::AddressError(AddressError::InvalidProtocol)); - }, - }, - Some(Protocol::Ip4(address)) => match iter.next() { - Some(Protocol::Udp(port)) => SocketAddr::new(IpAddr::V4(address), port), - protocol => { - tracing::error!( - target: LOG_TARGET, - ?protocol, - "invalid transport protocol, expected `Udp`", - ); - return Err(Error::AddressError(AddressError::InvalidProtocol)); - }, - }, - protocol => { - tracing::error!(target: LOG_TARGET, ?protocol, "invalid transport protocol"); - return Err(Error::AddressError(AddressError::InvalidProtocol)); - }, - }; - - match iter.next() { - Some(Protocol::WebRTC) => {}, - protocol => { - tracing::error!( - target: LOG_TARGET, - ?protocol, - "invalid protocol, expected `WebRTC`" - ); - return Err(Error::AddressError(AddressError::InvalidProtocol)); - }, - } - - let maybe_peer = match iter.next() { - Some(Protocol::P2p(multihash)) => Some(PeerId::from_multihash(multihash)?), - None => None, - protocol => { - tracing::error!( - target: LOG_TARGET, - ?protocol, - "invalid protocol, expected `P2p` or `None`" - ); - return Err(Error::AddressError(AddressError::InvalidProtocol)); - }, - }; - - Ok((socket_address, maybe_peer)) - } - - /// Create RTC client and open channel for Noise handshake. - fn make_rtc_client( - &self, - ufrag: &str, - pass: &str, - source: SocketAddr, - destination: SocketAddr, - ) -> (Rtc, ChannelId) { - let mut rtc = Rtc::builder() - .set_ice_lite(true) - .set_dtls_cert(self.dtls_cert.clone()) - .set_fingerprint_verification(false) - .build(); - rtc.add_local_candidate(Candidate::host(destination).unwrap()); - rtc.add_remote_candidate(Candidate::host(source).unwrap()); - rtc.direct_api() - .set_remote_fingerprint(REMOTE_FINGERPRINT.parse().expect("parse() to succeed")); - rtc.direct_api().set_remote_ice_credentials(IceCreds { - ufrag: ufrag.to_owned(), - pass: pass.to_owned(), - }); - rtc.direct_api() - .set_local_ice_credentials(IceCreds { ufrag: ufrag.to_owned(), pass: pass.to_owned() }); - rtc.direct_api().set_ice_controlling(false); - rtc.direct_api().start_dtls(false).unwrap(); - rtc.direct_api().start_sctp(false); - - let noise_channel_id = rtc.direct_api().create_data_channel(ChannelConfig { - label: "noise".to_string(), - ordered: false, - reliability: Default::default(), - negotiated: Some(0), - protocol: "".to_string(), - }); - - (rtc, noise_channel_id) - } - - /// Handle socket input. - fn on_socket_input(&mut self, source: SocketAddr, buffer: Vec) -> crate::Result<()> { - // if the `Rtc` object already exists for `souce`, pass the message directly to that - // connection. - if let Some(tx) = self.peers.get_mut(&source) { - // TODO: implement properly - match tx.try_send(buffer) { - Ok(()) => return Ok(()), - Err(error) => { - tracing::warn!(target: LOG_TARGET, ?error, "failed to send datagram to connection"); - return Ok(()); - }, - } - } - - // if the peer doesn't exist, decode the message and expect to receive `Stun` - // so that a new connection can be initialized - let contents: DatagramRecv = - buffer.as_slice().try_into().map_err(|_| Error::InvalidData)?; - - match contents { - DatagramRecv::Stun(message) => { - if let Some((ufrag, pass)) = message.split_username() { - tracing::debug!( - target: LOG_TARGET, - ?source, - ?ufrag, - ?pass, - "received stun message" - ); - - // create new `Rtc` object for the peer and give it the received STUN message - let (mut rtc, noise_channel_id) = self.make_rtc_client( - ufrag, - pass, - source, - self.socket.local_addr().unwrap(), - ); - - rtc.handle_input(Input::Receive( - Instant::now(), - Receive { - source, - destination: self.socket.local_addr().unwrap(), - contents: DatagramRecv::Stun(message.clone()), - }, - )) - .expect("client to handle input successfully"); - - let (tx, rx) = channel(64); - let connection_id = self.context.next_connection_id(); - - let connection = WebRtcConnection::new( - rtc, - connection_id, - noise_channel_id, - self.context.keypair.clone(), - self.context.protocol_set(connection_id), - source, - self.listen_address, - Arc::clone(&self.socket), - rx, - ); - - self.context.executor.run(Box::pin(async move { - let _ = connection.run().await; - })); - self.peers.insert(source, tx); - } - }, - message => { - tracing::error!( - target: LOG_TARGET, - ?source, - ?message, - "received unexpected message for a connection that doesn't eixst" - ); - }, - } - - Ok(()) - } + /// Extract socket address and `PeerId`, if found, from `address`. + fn get_socket_address(address: &Multiaddr) -> crate::Result<(SocketAddr, Option)> { + tracing::trace!(target: LOG_TARGET, ?address, "parse multi address"); + + let mut iter = address.iter(); + let socket_address = match iter.next() { + Some(Protocol::Ip6(address)) => match iter.next() { + Some(Protocol::Udp(port)) => SocketAddr::new(IpAddr::V6(address), port), + protocol => { + tracing::error!( + target: LOG_TARGET, + ?protocol, + "invalid transport protocol, expected `Upd`", + ); + return Err(Error::AddressError(AddressError::InvalidProtocol)); + } + }, + Some(Protocol::Ip4(address)) => match iter.next() { + Some(Protocol::Udp(port)) => SocketAddr::new(IpAddr::V4(address), port), + protocol => { + tracing::error!( + target: LOG_TARGET, + ?protocol, + "invalid transport protocol, expected `Udp`", + ); + return Err(Error::AddressError(AddressError::InvalidProtocol)); + } + }, + protocol => { + tracing::error!(target: LOG_TARGET, ?protocol, "invalid transport protocol"); + return Err(Error::AddressError(AddressError::InvalidProtocol)); + } + }; + + match iter.next() { + Some(Protocol::WebRTC) => {} + protocol => { + tracing::error!( + target: LOG_TARGET, + ?protocol, + "invalid protocol, expected `WebRTC`" + ); + return Err(Error::AddressError(AddressError::InvalidProtocol)); + } + } + + let maybe_peer = match iter.next() { + Some(Protocol::P2p(multihash)) => Some(PeerId::from_multihash(multihash)?), + None => None, + protocol => { + tracing::error!( + target: LOG_TARGET, + ?protocol, + "invalid protocol, expected `P2p` or `None`" + ); + return Err(Error::AddressError(AddressError::InvalidProtocol)); + } + }; + + Ok((socket_address, maybe_peer)) + } + + /// Create RTC client and open channel for Noise handshake. + fn make_rtc_client( + &self, + ufrag: &str, + pass: &str, + source: SocketAddr, + destination: SocketAddr, + ) -> (Rtc, ChannelId) { + let mut rtc = Rtc::builder() + .set_ice_lite(true) + .set_dtls_cert(self.dtls_cert.clone()) + .set_fingerprint_verification(false) + .build(); + rtc.add_local_candidate(Candidate::host(destination).unwrap()); + rtc.add_remote_candidate(Candidate::host(source).unwrap()); + rtc.direct_api() + .set_remote_fingerprint(REMOTE_FINGERPRINT.parse().expect("parse() to succeed")); + rtc.direct_api().set_remote_ice_credentials(IceCreds { + ufrag: ufrag.to_owned(), + pass: pass.to_owned(), + }); + rtc.direct_api().set_local_ice_credentials(IceCreds { + ufrag: ufrag.to_owned(), + pass: pass.to_owned(), + }); + rtc.direct_api().set_ice_controlling(false); + rtc.direct_api().start_dtls(false).unwrap(); + rtc.direct_api().start_sctp(false); + + let noise_channel_id = rtc.direct_api().create_data_channel(ChannelConfig { + label: "noise".to_string(), + ordered: false, + reliability: Default::default(), + negotiated: Some(0), + protocol: "".to_string(), + }); + + (rtc, noise_channel_id) + } + + /// Handle socket input. + fn on_socket_input(&mut self, source: SocketAddr, buffer: Vec) -> crate::Result<()> { + // if the `Rtc` object already exists for `souce`, pass the message directly to that + // connection. + if let Some(tx) = self.peers.get_mut(&source) { + // TODO: implement properly + match tx.try_send(buffer) { + Ok(()) => return Ok(()), + Err(error) => { + tracing::warn!(target: LOG_TARGET, ?error, "failed to send datagram to connection"); + return Ok(()); + } + } + } + + // if the peer doesn't exist, decode the message and expect to receive `Stun` + // so that a new connection can be initialized + let contents: DatagramRecv = + buffer.as_slice().try_into().map_err(|_| Error::InvalidData)?; + + match contents { + DatagramRecv::Stun(message) => { + if let Some((ufrag, pass)) = message.split_username() { + tracing::debug!( + target: LOG_TARGET, + ?source, + ?ufrag, + ?pass, + "received stun message" + ); + + // create new `Rtc` object for the peer and give it the received STUN message + let (mut rtc, noise_channel_id) = self.make_rtc_client( + ufrag, + pass, + source, + self.socket.local_addr().unwrap(), + ); + + rtc.handle_input(Input::Receive( + Instant::now(), + Receive { + source, + destination: self.socket.local_addr().unwrap(), + contents: DatagramRecv::Stun(message.clone()), + }, + )) + .expect("client to handle input successfully"); + + let (tx, rx) = channel(64); + let connection_id = self.context.next_connection_id(); + + let connection = WebRtcConnection::new( + rtc, + connection_id, + noise_channel_id, + self.context.keypair.clone(), + self.context.protocol_set(connection_id), + source, + self.listen_address, + Arc::clone(&self.socket), + rx, + ); + + self.context.executor.run(Box::pin(async move { + let _ = connection.run().await; + })); + self.peers.insert(source, tx); + } + } + message => { + tracing::error!( + target: LOG_TARGET, + ?source, + ?message, + "received unexpected message for a connection that doesn't eixst" + ); + } + } + + Ok(()) + } } impl TransportBuilder for WebRtcTransport { - type Config = Config; - type Transport = WebRtcTransport; - - /// Create new [`Transport`] object. - fn new(context: TransportHandle, config: Self::Config) -> crate::Result<(Self, Vec)> - where - Self: Sized, - { - tracing::info!( - target: LOG_TARGET, - listen_addresses = ?config.listen_addresses, - "start webrtc transport", - ); - - let (listen_address, _) = Self::get_socket_address(&config.listen_addresses[0])?; - let socket = match listen_address.is_ipv4() { - true => { - let socket = Socket::new(Domain::IPV6, Type::DGRAM, Some(socket2::Protocol::UDP))?; - socket.bind(&listen_address.into())?; - socket - }, - false => { - let socket = Socket::new(Domain::IPV4, Type::DGRAM, Some(socket2::Protocol::UDP))?; - socket.set_only_v6(true)?; - socket.bind(&listen_address.into())?; - socket - }, - }; - socket.listen(1024)?; - socket.set_reuse_address(true)?; - socket.set_nonblocking(true)?; - #[cfg(unix)] - socket.set_reuse_port(true)?; - - let socket = UdpSocket::from_std(socket.into())?; - let listen_address = socket.local_addr()?; - let dtls_cert = DtlsCert::new(); - - let listen_multi_addresses = { - let fingerprint = dtls_cert.fingerprint().bytes; - - const MULTIHASH_SHA256_CODE: u64 = 0x12; - let certificate = Multihash::wrap(MULTIHASH_SHA256_CODE, &fingerprint) - .expect("fingerprint's len to be 32 bytes"); - - vec![Multiaddr::empty() - .with(Protocol::from(listen_address.ip())) - .with(Protocol::Udp(listen_address.port())) - .with(Protocol::WebRTC) - .with(Protocol::Certhash(certificate))] - }; - - Ok(( - Self { - context, - dtls_cert, - listen_address, - peers: HashMap::new(), - socket: Arc::new(socket), - }, - listen_multi_addresses, - )) - } + type Config = Config; + type Transport = WebRtcTransport; + + /// Create new [`Transport`] object. + fn new(context: TransportHandle, config: Self::Config) -> crate::Result<(Self, Vec)> + where + Self: Sized, + { + tracing::info!( + target: LOG_TARGET, + listen_addresses = ?config.listen_addresses, + "start webrtc transport", + ); + + let (listen_address, _) = Self::get_socket_address(&config.listen_addresses[0])?; + let socket = match listen_address.is_ipv4() { + true => { + let socket = Socket::new(Domain::IPV6, Type::DGRAM, Some(socket2::Protocol::UDP))?; + socket.bind(&listen_address.into())?; + socket + } + false => { + let socket = Socket::new(Domain::IPV4, Type::DGRAM, Some(socket2::Protocol::UDP))?; + socket.set_only_v6(true)?; + socket.bind(&listen_address.into())?; + socket + } + }; + socket.listen(1024)?; + socket.set_reuse_address(true)?; + socket.set_nonblocking(true)?; + #[cfg(unix)] + socket.set_reuse_port(true)?; + + let socket = UdpSocket::from_std(socket.into())?; + let listen_address = socket.local_addr()?; + let dtls_cert = DtlsCert::new(); + + let listen_multi_addresses = { + let fingerprint = dtls_cert.fingerprint().bytes; + + const MULTIHASH_SHA256_CODE: u64 = 0x12; + let certificate = Multihash::wrap(MULTIHASH_SHA256_CODE, &fingerprint) + .expect("fingerprint's len to be 32 bytes"); + + vec![Multiaddr::empty() + .with(Protocol::from(listen_address.ip())) + .with(Protocol::Udp(listen_address.port())) + .with(Protocol::WebRTC) + .with(Protocol::Certhash(certificate))] + }; + + Ok(( + Self { + context, + dtls_cert, + listen_address, + peers: HashMap::new(), + socket: Arc::new(socket), + }, + listen_multi_addresses, + )) + } } impl Transport for WebRtcTransport { - fn dial(&mut self, connection_id: ConnectionId, address: Multiaddr) -> crate::Result<()> { - tracing::warn!( - target: LOG_TARGET, - ?connection_id, - ?address, - "webrtc cannot dial", - ); - - Err(Error::NotSupported(format!("webrtc cannot dial peers"))) - } - - fn accept(&mut self, _connection_id: ConnectionId) -> crate::Result<()> { - Ok(()) - } - - fn reject(&mut self, _connection_id: ConnectionId) -> crate::Result<()> { - Ok(()) - } - - fn open( - &mut self, - _connection_id: ConnectionId, - _addresses: Vec, - ) -> crate::Result<()> { - Ok(()) - } - - fn negotiate(&mut self, _connection_id: ConnectionId) -> crate::Result<()> { - Ok(()) - } - - /// Cancel opening connections. - fn cancel(&mut self, _connection_id: ConnectionId) {} + fn dial(&mut self, connection_id: ConnectionId, address: Multiaddr) -> crate::Result<()> { + tracing::warn!( + target: LOG_TARGET, + ?connection_id, + ?address, + "webrtc cannot dial", + ); + + Err(Error::NotSupported(format!("webrtc cannot dial peers"))) + } + + fn accept(&mut self, _connection_id: ConnectionId) -> crate::Result<()> { + Ok(()) + } + + fn reject(&mut self, _connection_id: ConnectionId) -> crate::Result<()> { + Ok(()) + } + + fn open( + &mut self, + _connection_id: ConnectionId, + _addresses: Vec, + ) -> crate::Result<()> { + Ok(()) + } + + fn negotiate(&mut self, _connection_id: ConnectionId) -> crate::Result<()> { + Ok(()) + } + + /// Cancel opening connections. + fn cancel(&mut self, _connection_id: ConnectionId) {} } impl Stream for WebRtcTransport { - type Item = TransportEvent; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - // TODO: optimizations - let mut buf = vec![0u8; 16384]; - let mut read_buf = ReadBuf::new(&mut buf); - - match self.socket.poll_recv_from(cx, &mut read_buf) { - Poll::Pending => {}, - Poll::Ready(Ok(source)) => { - let nread = read_buf.filled().len(); - buf.truncate(nread); - - if let Err(error) = self.on_socket_input(source, buf) { - tracing::error!(target: LOG_TARGET, ?error, "failed to handle input"); - } - }, - Poll::Ready(Err(error)) => { - tracing::debug!( - target: LOG_TARGET, - ?error, - "failed to read from webrtc socket", - ); - - return Poll::Ready(None); - }, - } - - Poll::Pending - } + type Item = TransportEvent; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + // TODO: optimizations + let mut buf = vec![0u8; 16384]; + let mut read_buf = ReadBuf::new(&mut buf); + + match self.socket.poll_recv_from(cx, &mut read_buf) { + Poll::Pending => {} + Poll::Ready(Ok(source)) => { + let nread = read_buf.filled().len(); + buf.truncate(nread); + + if let Err(error) = self.on_socket_input(source, buf) { + tracing::error!(target: LOG_TARGET, ?error, "failed to handle input"); + } + } + Poll::Ready(Err(error)) => { + tracing::debug!( + target: LOG_TARGET, + ?error, + "failed to read from webrtc socket", + ); + + return Poll::Ready(None); + } + } + + Poll::Pending + } } // TODO: remove @@ -423,9 +425,9 @@ impl Stream for WebRtcTransport { #[allow(clippy::large_enum_variant)] #[derive(Debug)] enum WebRtcEvent { - /// When we have nothing to propagate. - Noop, + /// When we have nothing to propagate. + Noop, - /// Poll client has reached timeout. - Timeout(Instant), + /// Poll client has reached timeout. + Timeout(Instant), } diff --git a/src/transport/webrtc/substream.rs b/src/transport/webrtc/substream.rs index 83b697a7..dad62cd4 100644 --- a/src/transport/webrtc/substream.rs +++ b/src/transport/webrtc/substream.rs @@ -21,8 +21,8 @@ //! Channel-backed substream. use crate::{ - codec::{identity::Identity, unsigned_varint::UnsignedVarint, ProtocolCodec}, - error::Error, + codec::{identity::Identity, unsigned_varint::UnsignedVarint, ProtocolCodec}, + error::Error, }; use bytes::BytesMut; @@ -33,8 +33,8 @@ use tokio_stream::wrappers::ReceiverStream; use tokio_util::sync::PollSender; use std::{ - pin::Pin, - task::{Context, Poll}, + pin::Pin, + task::{Context, Poll}, }; // TODO: use substream id @@ -42,103 +42,108 @@ use std::{ /// Channel-backed substream. #[derive(Debug)] pub struct Substream { - /// Channel ID. - id: ChannelId, + /// Channel ID. + id: ChannelId, - /// TX channel for sending messages to transport. - tx: PollSender<(ChannelId, Vec)>, + /// TX channel for sending messages to transport. + tx: PollSender<(ChannelId, Vec)>, - /// RX channel for receiving messages from transport. - rx: ReceiverStream>, + /// RX channel for receiving messages from transport. + rx: ReceiverStream>, - /// Protocol codec. - codec: Option, + /// Protocol codec. + codec: Option, } impl Substream { - /// Create new [`Substream`]. - pub fn new(id: ChannelId, tx: Sender<(ChannelId, Vec)>) -> (Self, Sender>) { - let (to_protocol, rx) = channel(64); - - ( - Self { id, codec: None, tx: PollSender::new(tx), rx: ReceiverStream::new(rx) }, - to_protocol, - ) - } - - /// Apply codec for the substream. - pub fn apply_codec(&mut self, codec: ProtocolCodec) { - self.codec = Some(codec); - } + /// Create new [`Substream`]. + pub fn new(id: ChannelId, tx: Sender<(ChannelId, Vec)>) -> (Self, Sender>) { + let (to_protocol, rx) = channel(64); + + ( + Self { + id, + codec: None, + tx: PollSender::new(tx), + rx: ReceiverStream::new(rx), + }, + to_protocol, + ) + } + + /// Apply codec for the substream. + pub fn apply_codec(&mut self, codec: ProtocolCodec) { + self.codec = Some(codec); + } } impl Sink for Substream { - type Error = Error; - - fn poll_ready<'a>(mut self: Pin<&mut Self>, cx: &mut Context<'a>) -> Poll> { - let pinned = Pin::new(&mut self.tx); - pinned.poll_ready(cx).map_err(|_| Error::Unknown) - } - - fn start_send(mut self: Pin<&mut Self>, item: bytes::Bytes) -> Result<(), Error> { - let item: Vec = match self.codec.as_ref().expect("codec to exist") { - ProtocolCodec::Identity(_) => Identity::encode(item)?.into(), - ProtocolCodec::UnsignedVarint(_) => UnsignedVarint::encode(item)?.into(), - ProtocolCodec::Unspecified => unreachable!(), // TODO: may not be correct - }; - let id = self.id; - - Pin::new(&mut self.tx).start_send((id, item)).map_err(|_| Error::Unknown) - } - - fn poll_flush<'a>(mut self: Pin<&mut Self>, cx: &mut Context<'a>) -> Poll> { - Pin::new(&mut self.tx).poll_flush(cx).map_err(|_| Error::Unknown) - } - - fn poll_close<'a>(mut self: Pin<&mut Self>, cx: &mut Context<'a>) -> Poll> { - Pin::new(&mut self.tx).poll_close(cx).map_err(|_| Error::Unknown) - } + type Error = Error; + + fn poll_ready<'a>(mut self: Pin<&mut Self>, cx: &mut Context<'a>) -> Poll> { + let pinned = Pin::new(&mut self.tx); + pinned.poll_ready(cx).map_err(|_| Error::Unknown) + } + + fn start_send(mut self: Pin<&mut Self>, item: bytes::Bytes) -> Result<(), Error> { + let item: Vec = match self.codec.as_ref().expect("codec to exist") { + ProtocolCodec::Identity(_) => Identity::encode(item)?.into(), + ProtocolCodec::UnsignedVarint(_) => UnsignedVarint::encode(item)?.into(), + ProtocolCodec::Unspecified => unreachable!(), // TODO: may not be correct + }; + let id = self.id; + + Pin::new(&mut self.tx).start_send((id, item)).map_err(|_| Error::Unknown) + } + + fn poll_flush<'a>(mut self: Pin<&mut Self>, cx: &mut Context<'a>) -> Poll> { + Pin::new(&mut self.tx).poll_flush(cx).map_err(|_| Error::Unknown) + } + + fn poll_close<'a>(mut self: Pin<&mut Self>, cx: &mut Context<'a>) -> Poll> { + Pin::new(&mut self.tx).poll_close(cx).map_err(|_| Error::Unknown) + } } impl Stream for Substream { - type Item = crate::Result; - - fn poll_next<'a>( - mut self: Pin<&mut Self>, - cx: &mut Context<'a>, - ) -> Poll>> { - match Pin::new(&mut self.rx).poll_next(cx) { - Poll::Pending => Poll::Pending, - Poll::Ready(None) => Poll::Ready(None), - Poll::Ready(Some(value)) => Poll::Ready(Some(Ok(BytesMut::from(value.as_slice())))), - } - } + type Item = crate::Result; + + fn poll_next<'a>( + mut self: Pin<&mut Self>, + cx: &mut Context<'a>, + ) -> Poll>> { + match Pin::new(&mut self.rx).poll_next(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(None) => Poll::Ready(None), + Poll::Ready(Some(value)) => Poll::Ready(Some(Ok(BytesMut::from(value.as_slice())))), + } + } } // TODO: rename? pub struct SubstreamBackend { - /// TX channel for creating new [`Substream`] objects. - tx: Sender<(ChannelId, Vec)>, + /// TX channel for creating new [`Substream`] objects. + tx: Sender<(ChannelId, Vec)>, - /// RX channel for receiving messages from protocols. - rx: Receiver<(ChannelId, Vec)>, + /// RX channel for receiving messages from protocols. + rx: Receiver<(ChannelId, Vec)>, } impl SubstreamBackend { - /// Create new [`SubstreamBackend`]. - pub fn new() -> Self { - let (tx, rx) = channel(1024); - - Self { tx, rx } - } - - /// Create new substream. - pub fn substream(&mut self, id: ChannelId) -> (Substream, Sender>) { - Substream::new(id, self.tx.clone()) - } - - /// Poll next event. - pub async fn next_event(&mut self) -> Option<(ChannelId, Vec)> { - self.rx.recv().await - } + /// Create new [`SubstreamBackend`]. + pub fn new() -> Self { + let (tx, rx) = channel(1024); + + Self { tx, rx } + } + + /// Create new substream. + pub fn substream(&mut self, id: ChannelId) -> (Substream, Sender>) { + Substream::new(id, self.tx.clone()) + } + + /// Poll next event. + pub async fn next_event(&mut self) -> Option<(ChannelId, Vec)> { + self.rx.recv().await + } } diff --git a/src/transport/webrtc/util.rs b/src/transport/webrtc/util.rs index f30d87d7..e985f4ae 100644 --- a/src/transport/webrtc/util.rs +++ b/src/transport/webrtc/util.rs @@ -28,90 +28,95 @@ use tokio_util::codec::{Decoder, Encoder}; /// Substream context. #[derive(Debug)] pub struct SubstreamContext { - /// `str0m` channel id. - pub channel_id: ChannelId, + /// `str0m` channel id. + pub channel_id: ChannelId, - /// TX channel for sending messages to the protocol. - pub tx: Sender>, + /// TX channel for sending messages to the protocol. + pub tx: Sender>, } impl SubstreamContext { - /// Create new [`SubstreamContext`]. - pub fn new(channel_id: ChannelId, tx: Sender>) -> Self { - Self { channel_id, tx } - } + /// Create new [`SubstreamContext`]. + pub fn new(channel_id: ChannelId, tx: Sender>) -> Self { + Self { channel_id, tx } + } } /// WebRTC mesage. #[derive(Debug)] pub struct WebRtcMessage { - /// Payload. - pub payload: Option>, + /// Payload. + pub payload: Option>, - // Flags. - pub flags: Option, + // Flags. + pub flags: Option, } impl WebRtcMessage { - /// Encode WebRTC message. - pub fn encode(payload: Vec, flag: Option) -> Vec { - let protobuf_payload = - schema::webrtc::Message { message: (!payload.is_empty()).then_some(payload), flag }; - let mut payload = Vec::with_capacity(protobuf_payload.encoded_len()); - protobuf_payload - .encode(&mut payload) - .expect("Vec to provide needed capacity"); - - let mut out_buf = bytes::BytesMut::with_capacity(payload.len() + 4); - // TODO: set correct size - let mut codec = UnsignedVarint::new(None); - let _result = codec.encode(payload.into(), &mut out_buf); - - out_buf.into() - } - - /// Decode payload into [`WebRtcMessage`]. - pub fn decode(payload: &[u8]) -> crate::Result { - // TODO: set correct size - let mut codec = UnsignedVarint::new(None); - let mut data = bytes::BytesMut::from(payload); - let result = codec.decode(&mut data)?.ok_or(Error::InvalidData)?; - - match schema::webrtc::Message::decode(result) { - Ok(message) => Ok(Self { payload: message.message, flags: message.flag }), - Err(_) => return Err(Error::InvalidData), - } - } + /// Encode WebRTC message. + pub fn encode(payload: Vec, flag: Option) -> Vec { + let protobuf_payload = schema::webrtc::Message { + message: (!payload.is_empty()).then_some(payload), + flag, + }; + let mut payload = Vec::with_capacity(protobuf_payload.encoded_len()); + protobuf_payload + .encode(&mut payload) + .expect("Vec to provide needed capacity"); + + let mut out_buf = bytes::BytesMut::with_capacity(payload.len() + 4); + // TODO: set correct size + let mut codec = UnsignedVarint::new(None); + let _result = codec.encode(payload.into(), &mut out_buf); + + out_buf.into() + } + + /// Decode payload into [`WebRtcMessage`]. + pub fn decode(payload: &[u8]) -> crate::Result { + // TODO: set correct size + let mut codec = UnsignedVarint::new(None); + let mut data = bytes::BytesMut::from(payload); + let result = codec.decode(&mut data)?.ok_or(Error::InvalidData)?; + + match schema::webrtc::Message::decode(result) { + Ok(message) => Ok(Self { + payload: message.message, + flags: message.flag, + }), + Err(_) => return Err(Error::InvalidData), + } + } } #[cfg(test)] mod tests { - use super::*; - - #[test] - fn with_payload_no_flags() { - let message = WebRtcMessage::encode("Hello, world!".as_bytes().to_vec(), None); - let decoded = WebRtcMessage::decode(&message).unwrap(); - - assert_eq!(decoded.payload, Some("Hello, world!".as_bytes().to_vec())); - assert_eq!(decoded.flags, None); - } - - #[test] - fn with_payload_and_flags() { - let message = WebRtcMessage::encode("Hello, world!".as_bytes().to_vec(), Some(1i32)); - let decoded = WebRtcMessage::decode(&message).unwrap(); - - assert_eq!(decoded.payload, Some("Hello, world!".as_bytes().to_vec())); - assert_eq!(decoded.flags, Some(1i32)); - } - - #[test] - fn no_payload_with_flags() { - let message = WebRtcMessage::encode(vec![], Some(2i32)); - let decoded = WebRtcMessage::decode(&message).unwrap(); - - assert_eq!(decoded.payload, None); - assert_eq!(decoded.flags, Some(2i32)); - } + use super::*; + + #[test] + fn with_payload_no_flags() { + let message = WebRtcMessage::encode("Hello, world!".as_bytes().to_vec(), None); + let decoded = WebRtcMessage::decode(&message).unwrap(); + + assert_eq!(decoded.payload, Some("Hello, world!".as_bytes().to_vec())); + assert_eq!(decoded.flags, None); + } + + #[test] + fn with_payload_and_flags() { + let message = WebRtcMessage::encode("Hello, world!".as_bytes().to_vec(), Some(1i32)); + let decoded = WebRtcMessage::decode(&message).unwrap(); + + assert_eq!(decoded.payload, Some("Hello, world!".as_bytes().to_vec())); + assert_eq!(decoded.flags, Some(1i32)); + } + + #[test] + fn no_payload_with_flags() { + let message = WebRtcMessage::encode(vec![], Some(2i32)); + let decoded = WebRtcMessage::decode(&message).unwrap(); + + assert_eq!(decoded.payload, None); + assert_eq!(decoded.flags, Some(2i32)); + } } diff --git a/src/transport/websocket/config.rs b/src/transport/websocket/config.rs index d67a6c52..1ec113a6 100644 --- a/src/transport/websocket/config.rs +++ b/src/transport/websocket/config.rs @@ -21,75 +21,75 @@ //! WebSocket transport configuration. use crate::{ - crypto::noise::{MAX_READ_AHEAD_FACTOR, MAX_WRITE_BUFFER_SIZE}, - transport::{CONNECTION_OPEN_TIMEOUT, SUBSTREAM_OPEN_TIMEOUT}, + crypto::noise::{MAX_READ_AHEAD_FACTOR, MAX_WRITE_BUFFER_SIZE}, + transport::{CONNECTION_OPEN_TIMEOUT, SUBSTREAM_OPEN_TIMEOUT}, }; /// WebSocket transport configuration. #[derive(Debug)] pub struct Config { - /// Listen address address for the transport. - /// - /// Default listen addreses are ["/ip4/0.0.0.0/tcp/0/ws", "/ip6/::/tcp/0/ws"]. - pub listen_addresses: Vec, + /// Listen address address for the transport. + /// + /// Default listen addreses are ["/ip4/0.0.0.0/tcp/0/ws", "/ip6/::/tcp/0/ws"]. + pub listen_addresses: Vec, - /// Whether to set `SO_REUSEPORT` and bind a socket to the listen address port for outbound - /// connections. - /// - /// Note that `SO_REUSEADDR` is always set on listening sockets. - /// - /// Defaults to `true`. - pub reuse_port: bool, + /// Whether to set `SO_REUSEPORT` and bind a socket to the listen address port for outbound + /// connections. + /// + /// Note that `SO_REUSEADDR` is always set on listening sockets. + /// + /// Defaults to `true`. + pub reuse_port: bool, - /// Yamux configuration. - pub yamux_config: crate::yamux::Config, + /// Yamux configuration. + pub yamux_config: crate::yamux::Config, - /// Noise read-ahead frame count. - /// - /// Specifies how many Noise frames are read per call to the underlying socket. - /// - /// By default this is configured to `5` so each call to the underlying socket can read up - /// to `5` Noise frame per call. Fewer frames may be read if there isn't enough data in the - /// socket. Each Noise frame is `65 KB` so the default setting allocates `65 KB * 5 = 325 KB` - /// per connection. - pub noise_read_ahead_frame_count: usize, + /// Noise read-ahead frame count. + /// + /// Specifies how many Noise frames are read per call to the underlying socket. + /// + /// By default this is configured to `5` so each call to the underlying socket can read up + /// to `5` Noise frame per call. Fewer frames may be read if there isn't enough data in the + /// socket. Each Noise frame is `65 KB` so the default setting allocates `65 KB * 5 = 325 KB` + /// per connection. + pub noise_read_ahead_frame_count: usize, - /// Noise write buffer size. - /// - /// Specifes how many Noise frames are tried to be coalesced into a single system call. - /// By default the value is set to `2` which means that the `NoiseSocket` will allocate - /// `130 KB` for each outgoing connection. - /// - /// The write buffer size is separate from the read-ahead frame count so by default - /// the Noise code will allocate `2 * 65 KB + 5 * 65 KB = 455 KB` per connection. - pub noise_write_buffer_size: usize, + /// Noise write buffer size. + /// + /// Specifes how many Noise frames are tried to be coalesced into a single system call. + /// By default the value is set to `2` which means that the `NoiseSocket` will allocate + /// `130 KB` for each outgoing connection. + /// + /// The write buffer size is separate from the read-ahead frame count so by default + /// the Noise code will allocate `2 * 65 KB + 5 * 65 KB = 455 KB` per connection. + pub noise_write_buffer_size: usize, - /// Connection open timeout. - /// - /// How long should litep2p wait for a connection to be opend before the host - /// is deemed unreachable. - pub connection_open_timeout: std::time::Duration, + /// Connection open timeout. + /// + /// How long should litep2p wait for a connection to be opend before the host + /// is deemed unreachable. + pub connection_open_timeout: std::time::Duration, - /// Substream open timeout. - /// - /// How long should litep2p wait for a substream to be opened before considering - /// the substream rejected. - pub substream_open_timeout: std::time::Duration, + /// Substream open timeout. + /// + /// How long should litep2p wait for a substream to be opened before considering + /// the substream rejected. + pub substream_open_timeout: std::time::Duration, } impl Default for Config { - fn default() -> Self { - Self { - listen_addresses: vec![ - "/ip4/0.0.0.0/tcp/0/ws".parse().expect("valid address"), - "/ip6/::/tcp/0/ws".parse().expect("valid address"), - ], - reuse_port: true, - yamux_config: Default::default(), - noise_read_ahead_frame_count: MAX_READ_AHEAD_FACTOR, - noise_write_buffer_size: MAX_WRITE_BUFFER_SIZE, - connection_open_timeout: CONNECTION_OPEN_TIMEOUT, - substream_open_timeout: SUBSTREAM_OPEN_TIMEOUT, - } - } + fn default() -> Self { + Self { + listen_addresses: vec![ + "/ip4/0.0.0.0/tcp/0/ws".parse().expect("valid address"), + "/ip6/::/tcp/0/ws".parse().expect("valid address"), + ], + reuse_port: true, + yamux_config: Default::default(), + noise_read_ahead_frame_count: MAX_READ_AHEAD_FACTOR, + noise_write_buffer_size: MAX_WRITE_BUFFER_SIZE, + connection_open_timeout: CONNECTION_OPEN_TIMEOUT, + substream_open_timeout: SUBSTREAM_OPEN_TIMEOUT, + } + } } diff --git a/src/transport/websocket/connection.rs b/src/transport/websocket/connection.rs index a9ed5fc2..19a3e14f 100644 --- a/src/transport/websocket/connection.rs +++ b/src/transport/websocket/connection.rs @@ -19,21 +19,21 @@ // DEALINGS IN THE SOFTWARE. use crate::{ - config::Role, - crypto::{ - ed25519::Keypair, - noise::{self, NoiseSocket}, - }, - error::Error, - multistream_select::{dialer_select_proto, listener_select_proto, Negotiated, Version}, - protocol::{Direction, Permit, ProtocolCommand, ProtocolSet}, - substream, - transport::{ - websocket::{stream::BufferedStream, substream::Substream}, - Endpoint, - }, - types::{protocol::ProtocolName, ConnectionId, SubstreamId}, - BandwidthSink, PeerId, + config::Role, + crypto::{ + ed25519::Keypair, + noise::{self, NoiseSocket}, + }, + error::Error, + multistream_select::{dialer_select_proto, listener_select_proto, Negotiated, Version}, + protocol::{Direction, Permit, ProtocolCommand, ProtocolSet}, + substream, + transport::{ + websocket::{stream::BufferedStream, substream::Substream}, + Endpoint, + }, + types::{protocol::ProtocolName, ConnectionId, SubstreamId}, + BandwidthSink, PeerId, }; use futures::{future::BoxFuture, stream::FuturesUnordered, AsyncRead, AsyncWrite, StreamExt}; @@ -46,9 +46,9 @@ use url::Url; use std::time::Duration; mod schema { - pub(super) mod noise { - include!(concat!(env!("OUT_DIR"), "/noise.rs")); - } + pub(super) mod noise { + include!(concat!(env!("OUT_DIR"), "/noise.rs")); + } } /// Logging target for the file. @@ -56,513 +56,518 @@ const LOG_TARGET: &str = "litep2p::websocket::connection"; /// Negotiated substream and its context. pub struct NegotiatedSubstream { - /// Substream direction. - direction: Direction, + /// Substream direction. + direction: Direction, - /// Substream ID. - substream_id: SubstreamId, + /// Substream ID. + substream_id: SubstreamId, - /// Protocol name. - protocol: ProtocolName, + /// Protocol name. + protocol: ProtocolName, - /// Yamux substream. - io: crate::yamux::Stream, + /// Yamux substream. + io: crate::yamux::Stream, - /// Permit. - permit: Permit, + /// Permit. + permit: Permit, } /// WebSocket connection error. #[derive(Debug)] enum ConnectionError { - /// Timeout - Timeout { - /// Protocol. - protocol: Option, - - /// Substream ID. - substream_id: Option, - }, - - /// Failed to negotiate connection/substream. - FailedToNegotiate { - /// Protocol. - protocol: Option, - - /// Substream ID. - substream_id: Option, - - /// Error. - error: Error, - }, + /// Timeout + Timeout { + /// Protocol. + protocol: Option, + + /// Substream ID. + substream_id: Option, + }, + + /// Failed to negotiate connection/substream. + FailedToNegotiate { + /// Protocol. + protocol: Option, + + /// Substream ID. + substream_id: Option, + + /// Error. + error: Error, + }, } /// Negotiated connection. pub(super) struct NegotiatedConnection { - /// Remote peer ID. - peer: PeerId, + /// Remote peer ID. + peer: PeerId, - /// Endpoint. - endpoint: Endpoint, + /// Endpoint. + endpoint: Endpoint, - /// Yamux connection. - connection: - crate::yamux::ControlledConnection>>>, + /// Yamux connection. + connection: + crate::yamux::ControlledConnection>>>, - /// Yamux control. - control: crate::yamux::Control, + /// Yamux control. + control: crate::yamux::Control, } impl NegotiatedConnection { - /// Get `ConnectionId` of the negotiated connection. - pub fn connection_id(&self) -> ConnectionId { - self.endpoint.connection_id() - } - - /// Get `PeerId` of the negotiated connection. - pub fn peer(&self) -> PeerId { - self.peer - } - - /// Get `Endpoint` of the negotiated connection. - pub fn endpoint(&self) -> Endpoint { - self.endpoint.clone() - } + /// Get `ConnectionId` of the negotiated connection. + pub fn connection_id(&self) -> ConnectionId { + self.endpoint.connection_id() + } + + /// Get `PeerId` of the negotiated connection. + pub fn peer(&self) -> PeerId { + self.peer + } + + /// Get `Endpoint` of the negotiated connection. + pub fn endpoint(&self) -> Endpoint { + self.endpoint.clone() + } } /// WebSocket connection. pub(crate) struct WebSocketConnection { - /// Protocol context. - protocol_set: ProtocolSet, + /// Protocol context. + protocol_set: ProtocolSet, - /// Yamux connection. - connection: - crate::yamux::ControlledConnection>>>, + /// Yamux connection. + connection: + crate::yamux::ControlledConnection>>>, - /// Yamux control. - control: crate::yamux::Control, + /// Yamux control. + control: crate::yamux::Control, - /// Remote peer ID. - peer: PeerId, + /// Remote peer ID. + peer: PeerId, - /// Endpoint. - endpoint: Endpoint, + /// Endpoint. + endpoint: Endpoint, - /// Substream open timeout. - substream_open_timeout: Duration, + /// Substream open timeout. + substream_open_timeout: Duration, - /// Connection ID. - connection_id: ConnectionId, + /// Connection ID. + connection_id: ConnectionId, - /// Bandwidth sink. - bandwidth_sink: BandwidthSink, + /// Bandwidth sink. + bandwidth_sink: BandwidthSink, - /// Pending substreams. - pending_substreams: - FuturesUnordered>>, + /// Pending substreams. + pending_substreams: + FuturesUnordered>>, } impl WebSocketConnection { - /// Create new [`WebSocketConnection`]. - pub(super) fn new( - connection: NegotiatedConnection, - protocol_set: ProtocolSet, - bandwidth_sink: BandwidthSink, - substream_open_timeout: Duration, - ) -> Self { - let NegotiatedConnection { peer, endpoint, connection, control } = connection; - - Self { - connection_id: endpoint.connection_id(), - protocol_set, - connection, - control, - peer, - endpoint, - bandwidth_sink, - substream_open_timeout, - pending_substreams: FuturesUnordered::new(), - } - } - - /// Negotiate protocol. - async fn negotiate_protocol( - stream: S, - role: &Role, - protocols: Vec<&str>, - ) -> crate::Result<(Negotiated, ProtocolName)> { - tracing::trace!(target: LOG_TARGET, ?protocols, "negotiating protocols"); - - let (protocol, socket) = match role { - Role::Dialer => dialer_select_proto(stream, protocols, Version::V1).await?, - Role::Listener => listener_select_proto(stream, protocols).await?, - }; - - tracing::trace!(target: LOG_TARGET, ?protocol, "protocol negotiated"); - - Ok((socket, ProtocolName::from(protocol.to_string()))) - } - - /// Open WebSocket connection. - pub(super) async fn open_connection( - connection_id: ConnectionId, - keypair: Keypair, - stream: WebSocketStream>, - address: Multiaddr, - dialed_peer: PeerId, - ws_address: Url, - yamux_config: crate::yamux::Config, - max_read_ahead_factor: usize, - max_write_buffer_size: usize, - ) -> crate::Result { - tracing::trace!( - target: LOG_TARGET, - ?address, - ?ws_address, - ?connection_id, - "open connection to remote peer", - ); - - Self::negotiate_connection( - stream, - Some(dialed_peer), - Role::Dialer, - address, - connection_id, - keypair, - yamux_config, - max_read_ahead_factor, - max_write_buffer_size, - ) - .await - } - - /// Accept WebSocket connection. - pub(super) async fn accept_connection( - stream: TcpStream, - connection_id: ConnectionId, - keypair: Keypair, - address: Multiaddr, - yamux_config: crate::yamux::Config, - max_read_ahead_factor: usize, - max_write_buffer_size: usize, - ) -> crate::Result { - let stream = MaybeTlsStream::Plain(stream); - - Self::negotiate_connection( - tokio_tungstenite::accept_async(stream).await?, - None, - Role::Listener, - address, - connection_id, - keypair, - yamux_config, - max_read_ahead_factor, - max_write_buffer_size, - ) - .await - } - - /// Negotiate WebSocket connection. - pub(super) async fn negotiate_connection( - stream: WebSocketStream>, - dialed_peer: Option, - role: Role, - address: Multiaddr, - connection_id: ConnectionId, - keypair: Keypair, - yamux_config: crate::yamux::Config, - max_read_ahead_factor: usize, - max_write_buffer_size: usize, - ) -> crate::Result { - tracing::trace!( - target: LOG_TARGET, - ?connection_id, - ?address, - ?role, - ?dialed_peer, - "negotiate connection" - ); - let stream = BufferedStream::new(stream); - - // negotiate `noise` - let (stream, _) = Self::negotiate_protocol(stream, &role, vec!["/noise"]).await?; - - tracing::trace!( - target: LOG_TARGET, - "`multistream-select` and `noise` negotiated" - ); - - // perform noise handshake - let (stream, peer) = noise::handshake( - stream.inner(), - &keypair, - role, - max_read_ahead_factor, - max_write_buffer_size, - ) - .await?; - - if let Some(dialed_peer) = dialed_peer { - if peer != dialed_peer { - return Err(Error::PeerIdMismatch(dialed_peer, peer)); - } - } - - let stream: NoiseSocket> = stream; - - tracing::trace!(target: LOG_TARGET, "noise handshake done"); - - // negotiate `yamux` - let (stream, _) = Self::negotiate_protocol(stream, &role, vec!["/yamux/1.0.0"]).await?; - tracing::trace!(target: LOG_TARGET, "`yamux` negotiated"); - - let connection = crate::yamux::Connection::new(stream.inner(), yamux_config, role.into()); - let (control, connection) = crate::yamux::Control::new(connection); - - let address = match role { - Role::Dialer => address, - Role::Listener => address.with(Protocol::P2p(Multihash::from(peer))), - }; - - Ok(NegotiatedConnection { - peer, - control, - connection, - endpoint: match role { - Role::Dialer => Endpoint::dialer(address, connection_id), - Role::Listener => Endpoint::listener(address, connection_id), - }, - }) - } - - /// Accept substream. - pub async fn accept_substream( - stream: crate::yamux::Stream, - permit: Permit, - substream_id: SubstreamId, - protocols: Vec, - ) -> crate::Result { - tracing::trace!( - target: LOG_TARGET, - ?substream_id, - "accept inbound substream" - ); - - let protocols = protocols.iter().map(|protocol| &**protocol).collect::>(); - let (io, protocol) = Self::negotiate_protocol(stream, &Role::Listener, protocols).await?; - - tracing::trace!( - target: LOG_TARGET, - ?substream_id, - "substream accepted and negotiated" - ); - - Ok(NegotiatedSubstream { - io: io.inner(), - direction: Direction::Inbound, - substream_id, - protocol, - permit, - }) - } - - /// Open substream for `protocol`. - pub async fn open_substream( - mut control: crate::yamux::Control, - permit: Permit, - substream_id: SubstreamId, - protocol: ProtocolName, - fallback_names: Vec, - ) -> crate::Result { - tracing::debug!(target: LOG_TARGET, ?protocol, ?substream_id, "open substream"); - - let stream = match control.open_stream().await { - Ok(stream) => { - tracing::trace!(target: LOG_TARGET, ?substream_id, "substream opened"); - stream - }, - Err(error) => { - tracing::debug!( - target: LOG_TARGET, - ?substream_id, - ?error, - "failed to open substream" - ); - return Err(Error::YamuxError(Direction::Outbound(substream_id), error)); - }, - }; - - // TODO: protocols don't change after they've been initialized so this should be done only - // once - let protocols = std::iter::once(&*protocol) - .chain(fallback_names.iter().map(|protocol| &**protocol)) - .collect(); - - let (io, protocol) = Self::negotiate_protocol(stream, &Role::Dialer, protocols).await?; - - Ok(NegotiatedSubstream { - io: io.inner(), - substream_id, - direction: Direction::Outbound(substream_id), - protocol, - permit, - }) - } - - /// Start connection event loop. - pub(crate) async fn start(mut self) -> crate::Result<()> { - self.protocol_set - .report_connection_established(self.peer, self.endpoint) - .await?; - - loop { - tokio::select! { - substream = self.connection.next() => match substream { - Some(Ok(stream)) => { - let substream = self.protocol_set.next_substream_id(); - let protocols = self.protocol_set.protocols(); - let permit = self.protocol_set.try_get_permit().ok_or(Error::ConnectionClosed)?; - let substream_open_timeout = self.substream_open_timeout; - - self.pending_substreams.push(Box::pin(async move { - match tokio::time::timeout( - substream_open_timeout, - Self::accept_substream(stream, permit, substream, protocols), - ) - .await - { - Ok(Ok(substream)) => Ok(substream), - Ok(Err(error)) => Err(ConnectionError::FailedToNegotiate { - protocol: None, - substream_id: None, - error, - }), - Err(_) => Err(ConnectionError::Timeout { - protocol: None, - substream_id: None - }), - } - })); - }, - Some(Err(error)) => { - tracing::debug!( - target: LOG_TARGET, - peer = ?self.peer, - ?error, - "connection closed with error" - ); - self.protocol_set.report_connection_closed(self.peer, self.connection_id).await?; - - return Ok(()) - } - None => { - tracing::debug!(target: LOG_TARGET, peer = ?self.peer, "connection closed"); - self.protocol_set.report_connection_closed(self.peer, self.connection_id).await?; - - return Ok(()) - } - }, - // TODO: move this to a function - substream = self.pending_substreams.select_next_some(), if !self.pending_substreams.is_empty() => { - match substream { - // TODO: return error to protocol - Err(error) => { - tracing::debug!( - target: LOG_TARGET, - ?error, - "failed to accept/open substream", - ); - - let (protocol, substream_id, error) = match error { - ConnectionError::Timeout { protocol, substream_id } => { - (protocol, substream_id, Error::Timeout) - } - ConnectionError::FailedToNegotiate { protocol, substream_id, error } => { - (protocol, substream_id, error) - } - }; - - if let (Some(protocol), Some(substream_id)) = (protocol, substream_id) { - self.protocol_set - .report_substream_open_failure(protocol, substream_id, error) - .await?; - } - } - Ok(substream) => { - let protocol = substream.protocol.clone(); - let direction = substream.direction; - let substream_id = substream.substream_id; - let socket = FuturesAsyncReadCompatExt::compat(substream.io); - let bandwidth_sink = self.bandwidth_sink.clone(); - - let substream = substream::Substream::new_websocket( - self.peer, - substream_id, - Substream::new(socket, bandwidth_sink, substream.permit), - self.protocol_set.protocol_codec(&protocol) - ); - - self.protocol_set - .report_substream_open(self.peer, protocol, direction, substream) - .await?; - } - } - } - protocol = self.protocol_set.next() => match protocol { - Some(ProtocolCommand::OpenSubstream { protocol, fallback_names, substream_id, permit }) => { - let control = self.control.clone(); - let substream_open_timeout = self.substream_open_timeout; - - tracing::trace!( - target: LOG_TARGET, - ?protocol, - ?substream_id, - "open substream" - ); - - self.pending_substreams.push(Box::pin(async move { - match tokio::time::timeout( - substream_open_timeout, - Self::open_substream( - control, - permit, - substream_id, - protocol.clone(), - fallback_names - ), - ) - .await - { - Ok(Ok(substream)) => Ok(substream), - Ok(Err(error)) => Err(ConnectionError::FailedToNegotiate { - protocol: Some(protocol), - substream_id: Some(substream_id), - error, - }), - Err(_) => Err(ConnectionError::Timeout { - protocol: Some(protocol), - substream_id: Some(substream_id) - }), - } - })); - } - Some(ProtocolCommand::ForceClose) => { - tracing::debug!( - target: LOG_TARGET, - peer = ?self.peer, - connection_id = ?self.connection_id, - "force closing connection", - ); - - return self.protocol_set.report_connection_closed(self.peer, self.connection_id).await - } - None => { - tracing::debug!(target: LOG_TARGET, "protocols have exited, shutting down connection"); - return self.protocol_set.report_connection_closed(self.peer, self.connection_id).await - } - } - } - } - } + /// Create new [`WebSocketConnection`]. + pub(super) fn new( + connection: NegotiatedConnection, + protocol_set: ProtocolSet, + bandwidth_sink: BandwidthSink, + substream_open_timeout: Duration, + ) -> Self { + let NegotiatedConnection { + peer, + endpoint, + connection, + control, + } = connection; + + Self { + connection_id: endpoint.connection_id(), + protocol_set, + connection, + control, + peer, + endpoint, + bandwidth_sink, + substream_open_timeout, + pending_substreams: FuturesUnordered::new(), + } + } + + /// Negotiate protocol. + async fn negotiate_protocol( + stream: S, + role: &Role, + protocols: Vec<&str>, + ) -> crate::Result<(Negotiated, ProtocolName)> { + tracing::trace!(target: LOG_TARGET, ?protocols, "negotiating protocols"); + + let (protocol, socket) = match role { + Role::Dialer => dialer_select_proto(stream, protocols, Version::V1).await?, + Role::Listener => listener_select_proto(stream, protocols).await?, + }; + + tracing::trace!(target: LOG_TARGET, ?protocol, "protocol negotiated"); + + Ok((socket, ProtocolName::from(protocol.to_string()))) + } + + /// Open WebSocket connection. + pub(super) async fn open_connection( + connection_id: ConnectionId, + keypair: Keypair, + stream: WebSocketStream>, + address: Multiaddr, + dialed_peer: PeerId, + ws_address: Url, + yamux_config: crate::yamux::Config, + max_read_ahead_factor: usize, + max_write_buffer_size: usize, + ) -> crate::Result { + tracing::trace!( + target: LOG_TARGET, + ?address, + ?ws_address, + ?connection_id, + "open connection to remote peer", + ); + + Self::negotiate_connection( + stream, + Some(dialed_peer), + Role::Dialer, + address, + connection_id, + keypair, + yamux_config, + max_read_ahead_factor, + max_write_buffer_size, + ) + .await + } + + /// Accept WebSocket connection. + pub(super) async fn accept_connection( + stream: TcpStream, + connection_id: ConnectionId, + keypair: Keypair, + address: Multiaddr, + yamux_config: crate::yamux::Config, + max_read_ahead_factor: usize, + max_write_buffer_size: usize, + ) -> crate::Result { + let stream = MaybeTlsStream::Plain(stream); + + Self::negotiate_connection( + tokio_tungstenite::accept_async(stream).await?, + None, + Role::Listener, + address, + connection_id, + keypair, + yamux_config, + max_read_ahead_factor, + max_write_buffer_size, + ) + .await + } + + /// Negotiate WebSocket connection. + pub(super) async fn negotiate_connection( + stream: WebSocketStream>, + dialed_peer: Option, + role: Role, + address: Multiaddr, + connection_id: ConnectionId, + keypair: Keypair, + yamux_config: crate::yamux::Config, + max_read_ahead_factor: usize, + max_write_buffer_size: usize, + ) -> crate::Result { + tracing::trace!( + target: LOG_TARGET, + ?connection_id, + ?address, + ?role, + ?dialed_peer, + "negotiate connection" + ); + let stream = BufferedStream::new(stream); + + // negotiate `noise` + let (stream, _) = Self::negotiate_protocol(stream, &role, vec!["/noise"]).await?; + + tracing::trace!( + target: LOG_TARGET, + "`multistream-select` and `noise` negotiated" + ); + + // perform noise handshake + let (stream, peer) = noise::handshake( + stream.inner(), + &keypair, + role, + max_read_ahead_factor, + max_write_buffer_size, + ) + .await?; + + if let Some(dialed_peer) = dialed_peer { + if peer != dialed_peer { + return Err(Error::PeerIdMismatch(dialed_peer, peer)); + } + } + + let stream: NoiseSocket> = stream; + + tracing::trace!(target: LOG_TARGET, "noise handshake done"); + + // negotiate `yamux` + let (stream, _) = Self::negotiate_protocol(stream, &role, vec!["/yamux/1.0.0"]).await?; + tracing::trace!(target: LOG_TARGET, "`yamux` negotiated"); + + let connection = crate::yamux::Connection::new(stream.inner(), yamux_config, role.into()); + let (control, connection) = crate::yamux::Control::new(connection); + + let address = match role { + Role::Dialer => address, + Role::Listener => address.with(Protocol::P2p(Multihash::from(peer))), + }; + + Ok(NegotiatedConnection { + peer, + control, + connection, + endpoint: match role { + Role::Dialer => Endpoint::dialer(address, connection_id), + Role::Listener => Endpoint::listener(address, connection_id), + }, + }) + } + + /// Accept substream. + pub async fn accept_substream( + stream: crate::yamux::Stream, + permit: Permit, + substream_id: SubstreamId, + protocols: Vec, + ) -> crate::Result { + tracing::trace!( + target: LOG_TARGET, + ?substream_id, + "accept inbound substream" + ); + + let protocols = protocols.iter().map(|protocol| &**protocol).collect::>(); + let (io, protocol) = Self::negotiate_protocol(stream, &Role::Listener, protocols).await?; + + tracing::trace!( + target: LOG_TARGET, + ?substream_id, + "substream accepted and negotiated" + ); + + Ok(NegotiatedSubstream { + io: io.inner(), + direction: Direction::Inbound, + substream_id, + protocol, + permit, + }) + } + + /// Open substream for `protocol`. + pub async fn open_substream( + mut control: crate::yamux::Control, + permit: Permit, + substream_id: SubstreamId, + protocol: ProtocolName, + fallback_names: Vec, + ) -> crate::Result { + tracing::debug!(target: LOG_TARGET, ?protocol, ?substream_id, "open substream"); + + let stream = match control.open_stream().await { + Ok(stream) => { + tracing::trace!(target: LOG_TARGET, ?substream_id, "substream opened"); + stream + } + Err(error) => { + tracing::debug!( + target: LOG_TARGET, + ?substream_id, + ?error, + "failed to open substream" + ); + return Err(Error::YamuxError(Direction::Outbound(substream_id), error)); + } + }; + + // TODO: protocols don't change after they've been initialized so this should be done only + // once + let protocols = std::iter::once(&*protocol) + .chain(fallback_names.iter().map(|protocol| &**protocol)) + .collect(); + + let (io, protocol) = Self::negotiate_protocol(stream, &Role::Dialer, protocols).await?; + + Ok(NegotiatedSubstream { + io: io.inner(), + substream_id, + direction: Direction::Outbound(substream_id), + protocol, + permit, + }) + } + + /// Start connection event loop. + pub(crate) async fn start(mut self) -> crate::Result<()> { + self.protocol_set + .report_connection_established(self.peer, self.endpoint) + .await?; + + loop { + tokio::select! { + substream = self.connection.next() => match substream { + Some(Ok(stream)) => { + let substream = self.protocol_set.next_substream_id(); + let protocols = self.protocol_set.protocols(); + let permit = self.protocol_set.try_get_permit().ok_or(Error::ConnectionClosed)?; + let substream_open_timeout = self.substream_open_timeout; + + self.pending_substreams.push(Box::pin(async move { + match tokio::time::timeout( + substream_open_timeout, + Self::accept_substream(stream, permit, substream, protocols), + ) + .await + { + Ok(Ok(substream)) => Ok(substream), + Ok(Err(error)) => Err(ConnectionError::FailedToNegotiate { + protocol: None, + substream_id: None, + error, + }), + Err(_) => Err(ConnectionError::Timeout { + protocol: None, + substream_id: None + }), + } + })); + }, + Some(Err(error)) => { + tracing::debug!( + target: LOG_TARGET, + peer = ?self.peer, + ?error, + "connection closed with error" + ); + self.protocol_set.report_connection_closed(self.peer, self.connection_id).await?; + + return Ok(()) + } + None => { + tracing::debug!(target: LOG_TARGET, peer = ?self.peer, "connection closed"); + self.protocol_set.report_connection_closed(self.peer, self.connection_id).await?; + + return Ok(()) + } + }, + // TODO: move this to a function + substream = self.pending_substreams.select_next_some(), if !self.pending_substreams.is_empty() => { + match substream { + // TODO: return error to protocol + Err(error) => { + tracing::debug!( + target: LOG_TARGET, + ?error, + "failed to accept/open substream", + ); + + let (protocol, substream_id, error) = match error { + ConnectionError::Timeout { protocol, substream_id } => { + (protocol, substream_id, Error::Timeout) + } + ConnectionError::FailedToNegotiate { protocol, substream_id, error } => { + (protocol, substream_id, error) + } + }; + + if let (Some(protocol), Some(substream_id)) = (protocol, substream_id) { + self.protocol_set + .report_substream_open_failure(protocol, substream_id, error) + .await?; + } + } + Ok(substream) => { + let protocol = substream.protocol.clone(); + let direction = substream.direction; + let substream_id = substream.substream_id; + let socket = FuturesAsyncReadCompatExt::compat(substream.io); + let bandwidth_sink = self.bandwidth_sink.clone(); + + let substream = substream::Substream::new_websocket( + self.peer, + substream_id, + Substream::new(socket, bandwidth_sink, substream.permit), + self.protocol_set.protocol_codec(&protocol) + ); + + self.protocol_set + .report_substream_open(self.peer, protocol, direction, substream) + .await?; + } + } + } + protocol = self.protocol_set.next() => match protocol { + Some(ProtocolCommand::OpenSubstream { protocol, fallback_names, substream_id, permit }) => { + let control = self.control.clone(); + let substream_open_timeout = self.substream_open_timeout; + + tracing::trace!( + target: LOG_TARGET, + ?protocol, + ?substream_id, + "open substream" + ); + + self.pending_substreams.push(Box::pin(async move { + match tokio::time::timeout( + substream_open_timeout, + Self::open_substream( + control, + permit, + substream_id, + protocol.clone(), + fallback_names + ), + ) + .await + { + Ok(Ok(substream)) => Ok(substream), + Ok(Err(error)) => Err(ConnectionError::FailedToNegotiate { + protocol: Some(protocol), + substream_id: Some(substream_id), + error, + }), + Err(_) => Err(ConnectionError::Timeout { + protocol: Some(protocol), + substream_id: Some(substream_id) + }), + } + })); + } + Some(ProtocolCommand::ForceClose) => { + tracing::debug!( + target: LOG_TARGET, + peer = ?self.peer, + connection_id = ?self.connection_id, + "force closing connection", + ); + + return self.protocol_set.report_connection_closed(self.peer, self.connection_id).await + } + None => { + tracing::debug!(target: LOG_TARGET, "protocols have exited, shutting down connection"); + return self.protocol_set.report_connection_closed(self.peer, self.connection_id).await + } + } + } + } + } } diff --git a/src/transport/websocket/listener.rs b/src/transport/websocket/listener.rs index cfb36a8e..0a9ea37b 100644 --- a/src/transport/websocket/listener.rs +++ b/src/transport/websocket/listener.rs @@ -29,11 +29,11 @@ use socket2::{Domain, Socket, Type}; use tokio::net::{TcpListener as TokioTcpListener, TcpStream}; use std::{ - io, - net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, - pin::Pin, - sync::Arc, - task::{Context, Poll}, + io, + net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, + pin::Pin, + sync::Arc, + task::{Context, Poll}, }; /// Logging target for the file. @@ -42,435 +42,445 @@ const LOG_TARGET: &str = "litep2p::websocket::listener"; /// Address type. #[derive(Debug)] pub(super) enum AddressType { - /// Socket address. - Socket(SocketAddr), + /// Socket address. + Socket(SocketAddr), - /// DNS address. - Dns(String, u16), + /// DNS address. + Dns(String, u16), } /// WebSocket listener listening to zero or more addresses. pub struct WebSocketListener { - /// Listeners. - listeners: Vec, + /// Listeners. + listeners: Vec, } /// Local addresses to use for outbound connections. #[derive(Clone)] pub enum DialAddresses { - /// Reuse port from listen addresses. - Reuse { listen_addresses: Arc> }, - /// Do not reuse port. - NoReuse, + /// Reuse port from listen addresses. + Reuse { + listen_addresses: Arc>, + }, + /// Do not reuse port. + NoReuse, } impl Default for DialAddresses { - fn default() -> Self { - DialAddresses::NoReuse - } + fn default() -> Self { + DialAddresses::NoReuse + } } impl DialAddresses { - /// Get local dial address for an outbound connection. - pub(super) fn local_dial_address( - &self, - remote_address: &IpAddr, - ) -> Result, ()> { - match self { - DialAddresses::Reuse { listen_addresses } => { - for address in listen_addresses.iter() { - if remote_address.is_ipv4() == address.is_ipv4() && - remote_address.is_loopback() == address.ip().is_loopback() - { - if remote_address.is_ipv4() { - return Ok(Some(SocketAddr::new( - IpAddr::V4(Ipv4Addr::UNSPECIFIED), - address.port(), - ))); - } else { - return Ok(Some(SocketAddr::new( - IpAddr::V6(Ipv6Addr::UNSPECIFIED), - address.port(), - ))); - } - } - } - - Err(()) - }, - DialAddresses::NoReuse => Ok(None), - } - } + /// Get local dial address for an outbound connection. + pub(super) fn local_dial_address( + &self, + remote_address: &IpAddr, + ) -> Result, ()> { + match self { + DialAddresses::Reuse { listen_addresses } => { + for address in listen_addresses.iter() { + if remote_address.is_ipv4() == address.is_ipv4() + && remote_address.is_loopback() == address.ip().is_loopback() + { + if remote_address.is_ipv4() { + return Ok(Some(SocketAddr::new( + IpAddr::V4(Ipv4Addr::UNSPECIFIED), + address.port(), + ))); + } else { + return Ok(Some(SocketAddr::new( + IpAddr::V6(Ipv6Addr::UNSPECIFIED), + address.port(), + ))); + } + } + } + + Err(()) + } + DialAddresses::NoReuse => Ok(None), + } + } } impl WebSocketListener { - /// Create new [`WebSocketListener`] - pub fn new( - addresses: Vec, - reuse_port: bool, - ) -> (Self, Vec, DialAddresses) { - let (listeners, listen_addresses): (_, Vec>) = addresses - .into_iter() - .filter_map(|address| { - let address = match Self::get_socket_address(&address).ok()?.0 { - AddressType::Socket(address) => address, - AddressType::Dns(address, port) => { - tracing::debug!( - target: LOG_TARGET, - ?address, - ?port, - "dns not supported as bind address" - ); - - return None; - }, - }; - let socket = match address.is_ipv4() { - false => { - let socket = - Socket::new(Domain::IPV6, Type::STREAM, Some(socket2::Protocol::TCP)) - .ok()?; - socket.set_only_v6(true).ok()?; - - socket - }, - true => Socket::new(Domain::IPV4, Type::STREAM, Some(socket2::Protocol::TCP)) - .ok()?, - }; - - socket.set_nonblocking(true).ok()?; - socket.set_reuse_address(true).ok()?; - #[cfg(unix)] - if reuse_port { - socket.set_reuse_port(true).ok()?; - } - socket.bind(&address.into()).ok()?; - socket.listen(1024).ok()?; - - let socket: std::net::TcpListener = socket.into(); - let listener = TokioTcpListener::from_std(socket).ok()?; - let local_address = listener.local_addr().ok()?; - - let listen_addresses = match address.ip().is_unspecified() { - true => match NetworkInterface::show() { - Ok(ifaces) => ifaces - .into_iter() - .flat_map(|record| { - record.addr.into_iter().filter_map(|iface_address| { - match (iface_address, address.is_ipv4()) { - (Addr::V4(inner), true) => Some(SocketAddr::new( - IpAddr::V4(inner.ip), - local_address.port(), - )), - (Addr::V6(inner), false) => - match inner.ip.segments().get(0) { - Some(0xfe80) => None, - _ => Some(SocketAddr::new( - IpAddr::V6(inner.ip), - local_address.port(), - )), - }, - _ => None, - } - }) - }) - .collect(), - Err(error) => { - tracing::warn!( - target: LOG_TARGET, - ?error, - "failed to fetch network interfaces", - ); - - return None; - }, - }, - false => vec![local_address], - }; - - Some((listener, listen_addresses)) - }) - .unzip(); - - let listen_addresses = listen_addresses.into_iter().flatten().collect::>(); - let listen_multi_addresses = listen_addresses - .iter() - .cloned() - .map(|address| { - Multiaddr::empty() - .with(Protocol::from(address.ip())) - .with(Protocol::Tcp(address.port())) - .with(Protocol::Ws(std::borrow::Cow::Owned("/".to_string()))) - }) - .collect(); - let dial_addresses = if reuse_port { - DialAddresses::Reuse { listen_addresses: Arc::new(listen_addresses) } - } else { - DialAddresses::NoReuse - }; - - (Self { listeners }, listen_multi_addresses, dial_addresses) - } - - /// Extract socket address and `PeerId`, if found, from `address`. - pub(super) fn get_socket_address( - address: &Multiaddr, - ) -> crate::Result<(AddressType, Option)> { - tracing::trace!(target: LOG_TARGET, ?address, "parse multi address"); - - let mut iter = address.iter(); - let socket_address = match iter.next() { - Some(Protocol::Ip6(address)) => match iter.next() { - Some(Protocol::Tcp(port)) => - AddressType::Socket(SocketAddr::new(IpAddr::V6(address), port)), - protocol => { - tracing::error!( - target: LOG_TARGET, - ?protocol, - "invalid transport protocol, expected `Tcp`", - ); - return Err(Error::AddressError(AddressError::InvalidProtocol)); - }, - }, - Some(Protocol::Ip4(address)) => match iter.next() { - Some(Protocol::Tcp(port)) => - AddressType::Socket(SocketAddr::new(IpAddr::V4(address), port)), - protocol => { - tracing::error!( - target: LOG_TARGET, - ?protocol, - "invalid transport protocol, expected `Tcp`", - ); - return Err(Error::AddressError(AddressError::InvalidProtocol)); - }, - }, - Some(Protocol::Dns(address)) | - Some(Protocol::Dns4(address)) | - Some(Protocol::Dns6(address)) => match iter.next() { - Some(Protocol::Tcp(port)) => AddressType::Dns(address.to_string(), port), - protocol => { - tracing::error!( - target: LOG_TARGET, - ?protocol, - "invalid transport protocol, expected `Tcp`", - ); - return Err(Error::AddressError(AddressError::InvalidProtocol)); - }, - }, - protocol => { - tracing::error!(target: LOG_TARGET, ?protocol, "invalid transport protocol"); - return Err(Error::AddressError(AddressError::InvalidProtocol)); - }, - }; - - // verify that `/ws`/`/wss` is part of the multi address - match iter.next() { - Some(Protocol::Ws(_address)) => {}, - Some(Protocol::Wss(_address)) => {}, - protocol => { - tracing::error!( - target: LOG_TARGET, - ?protocol, - "invalid protocol, expected `Ws` or `Wss`" - ); - return Err(Error::AddressError(AddressError::InvalidProtocol)); - }, - }; - - let maybe_peer = match iter.next() { - Some(Protocol::P2p(multihash)) => Some(PeerId::from_multihash(multihash)?), - None => None, - protocol => { - tracing::error!( - target: LOG_TARGET, - ?protocol, - "invalid protocol, expected `P2p` or `None`" - ); - return Err(Error::AddressError(AddressError::InvalidProtocol)); - }, - }; - - Ok((socket_address, maybe_peer)) - } + /// Create new [`WebSocketListener`] + pub fn new( + addresses: Vec, + reuse_port: bool, + ) -> (Self, Vec, DialAddresses) { + let (listeners, listen_addresses): (_, Vec>) = addresses + .into_iter() + .filter_map(|address| { + let address = match Self::get_socket_address(&address).ok()?.0 { + AddressType::Socket(address) => address, + AddressType::Dns(address, port) => { + tracing::debug!( + target: LOG_TARGET, + ?address, + ?port, + "dns not supported as bind address" + ); + + return None; + } + }; + let socket = match address.is_ipv4() { + false => { + let socket = + Socket::new(Domain::IPV6, Type::STREAM, Some(socket2::Protocol::TCP)) + .ok()?; + socket.set_only_v6(true).ok()?; + + socket + } + true => Socket::new(Domain::IPV4, Type::STREAM, Some(socket2::Protocol::TCP)) + .ok()?, + }; + + socket.set_nonblocking(true).ok()?; + socket.set_reuse_address(true).ok()?; + #[cfg(unix)] + if reuse_port { + socket.set_reuse_port(true).ok()?; + } + socket.bind(&address.into()).ok()?; + socket.listen(1024).ok()?; + + let socket: std::net::TcpListener = socket.into(); + let listener = TokioTcpListener::from_std(socket).ok()?; + let local_address = listener.local_addr().ok()?; + + let listen_addresses = match address.ip().is_unspecified() { + true => match NetworkInterface::show() { + Ok(ifaces) => ifaces + .into_iter() + .flat_map(|record| { + record.addr.into_iter().filter_map(|iface_address| { + match (iface_address, address.is_ipv4()) { + (Addr::V4(inner), true) => Some(SocketAddr::new( + IpAddr::V4(inner.ip), + local_address.port(), + )), + (Addr::V6(inner), false) => + match inner.ip.segments().get(0) { + Some(0xfe80) => None, + _ => Some(SocketAddr::new( + IpAddr::V6(inner.ip), + local_address.port(), + )), + }, + _ => None, + } + }) + }) + .collect(), + Err(error) => { + tracing::warn!( + target: LOG_TARGET, + ?error, + "failed to fetch network interfaces", + ); + + return None; + } + }, + false => vec![local_address], + }; + + Some((listener, listen_addresses)) + }) + .unzip(); + + let listen_addresses = listen_addresses.into_iter().flatten().collect::>(); + let listen_multi_addresses = listen_addresses + .iter() + .cloned() + .map(|address| { + Multiaddr::empty() + .with(Protocol::from(address.ip())) + .with(Protocol::Tcp(address.port())) + .with(Protocol::Ws(std::borrow::Cow::Owned("/".to_string()))) + }) + .collect(); + let dial_addresses = if reuse_port { + DialAddresses::Reuse { + listen_addresses: Arc::new(listen_addresses), + } + } else { + DialAddresses::NoReuse + }; + + (Self { listeners }, listen_multi_addresses, dial_addresses) + } + + /// Extract socket address and `PeerId`, if found, from `address`. + pub(super) fn get_socket_address( + address: &Multiaddr, + ) -> crate::Result<(AddressType, Option)> { + tracing::trace!(target: LOG_TARGET, ?address, "parse multi address"); + + let mut iter = address.iter(); + let socket_address = match iter.next() { + Some(Protocol::Ip6(address)) => match iter.next() { + Some(Protocol::Tcp(port)) => + AddressType::Socket(SocketAddr::new(IpAddr::V6(address), port)), + protocol => { + tracing::error!( + target: LOG_TARGET, + ?protocol, + "invalid transport protocol, expected `Tcp`", + ); + return Err(Error::AddressError(AddressError::InvalidProtocol)); + } + }, + Some(Protocol::Ip4(address)) => match iter.next() { + Some(Protocol::Tcp(port)) => + AddressType::Socket(SocketAddr::new(IpAddr::V4(address), port)), + protocol => { + tracing::error!( + target: LOG_TARGET, + ?protocol, + "invalid transport protocol, expected `Tcp`", + ); + return Err(Error::AddressError(AddressError::InvalidProtocol)); + } + }, + Some(Protocol::Dns(address)) + | Some(Protocol::Dns4(address)) + | Some(Protocol::Dns6(address)) => match iter.next() { + Some(Protocol::Tcp(port)) => AddressType::Dns(address.to_string(), port), + protocol => { + tracing::error!( + target: LOG_TARGET, + ?protocol, + "invalid transport protocol, expected `Tcp`", + ); + return Err(Error::AddressError(AddressError::InvalidProtocol)); + } + }, + protocol => { + tracing::error!(target: LOG_TARGET, ?protocol, "invalid transport protocol"); + return Err(Error::AddressError(AddressError::InvalidProtocol)); + } + }; + + // verify that `/ws`/`/wss` is part of the multi address + match iter.next() { + Some(Protocol::Ws(_address)) => {} + Some(Protocol::Wss(_address)) => {} + protocol => { + tracing::error!( + target: LOG_TARGET, + ?protocol, + "invalid protocol, expected `Ws` or `Wss`" + ); + return Err(Error::AddressError(AddressError::InvalidProtocol)); + } + }; + + let maybe_peer = match iter.next() { + Some(Protocol::P2p(multihash)) => Some(PeerId::from_multihash(multihash)?), + None => None, + protocol => { + tracing::error!( + target: LOG_TARGET, + ?protocol, + "invalid protocol, expected `P2p` or `None`" + ); + return Err(Error::AddressError(AddressError::InvalidProtocol)); + } + }; + + Ok((socket_address, maybe_peer)) + } } impl Stream for WebSocketListener { - type Item = io::Result<(TcpStream, SocketAddr)>; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - if self.listeners.is_empty() { - return Poll::Pending; - } - - // TODO: make this more fair - for listener in self.listeners.iter_mut() { - match listener.poll_accept(cx) { - Poll::Pending => {}, - Poll::Ready(Err(error)) => return Poll::Ready(Some(Err(error))), - Poll::Ready(Ok((stream, address))) => - return Poll::Ready(Some(Ok((stream, address)))), - } - } - - Poll::Pending - } + type Item = io::Result<(TcpStream, SocketAddr)>; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + if self.listeners.is_empty() { + return Poll::Pending; + } + + // TODO: make this more fair + for listener in self.listeners.iter_mut() { + match listener.poll_accept(cx) { + Poll::Pending => {} + Poll::Ready(Err(error)) => return Poll::Ready(Some(Err(error))), + Poll::Ready(Ok((stream, address))) => + return Poll::Ready(Some(Ok((stream, address)))), + } + } + + Poll::Pending + } } #[cfg(test)] mod tests { - use super::*; - use futures::StreamExt; - - #[test] - fn parse_multiaddresses() { - assert!(WebSocketListener::get_socket_address( - &"/ip6/::1/tcp/8888/ws".parse().expect("valid multiaddress") - ) - .is_ok()); - assert!(WebSocketListener::get_socket_address( - &"/ip4/127.0.0.1/tcp/8888/ws".parse().expect("valid multiaddress") - ) - .is_ok()); - assert!(WebSocketListener::get_socket_address( - &"/ip6/::1/tcp/8888/ws/p2p/12D3KooWT2ouvz5uMmCvHJGzAGRHiqDts5hzXR7NdoQ27pGdzp9Q" - .parse() - .expect("valid multiaddress") - ) - .is_ok()); - assert!(WebSocketListener::get_socket_address( - &"/ip4/127.0.0.1/tcp/8888/ws/p2p/12D3KooWT2ouvz5uMmCvHJGzAGRHiqDts5hzXR7NdoQ27pGdzp9Q" - .parse() - .expect("valid multiaddress") - ) - .is_ok()); - assert!(WebSocketListener::get_socket_address( - &"/ip6/::1/udp/8888/p2p/12D3KooWT2ouvz5uMmCvHJGzAGRHiqDts5hzXR7NdoQ27pGdzp9Q" - .parse() - .expect("valid multiaddress") - ) - .is_err()); - assert!(WebSocketListener::get_socket_address( - &"/ip4/127.0.0.1/udp/8888/p2p/12D3KooWT2ouvz5uMmCvHJGzAGRHiqDts5hzXR7NdoQ27pGdzp9Q" - .parse() - .expect("valid multiaddress") - ) - .is_err()); - assert!(WebSocketListener::get_socket_address( - &"/ip4/127.0.0.1/tcp/8888/ws/utp".parse().expect("valid multiaddress") - ) - .is_err()); - assert!(WebSocketListener::get_socket_address( - &"/ip6/::1/tcp/8888/p2p/12D3KooWT2ouvz5uMmCvHJGzAGRHiqDts5hzXR7NdoQ27pGdzp9Q" - .parse() - .expect("valid multiaddress") - ) - .is_err()); - assert!(WebSocketListener::get_socket_address( - &"/p2p/12D3KooWT2ouvz5uMmCvHJGzAGRHiqDts5hzXR7NdoQ27pGdzp9Q" - .parse() - .expect("valid multiaddress") - ) - .is_err()); - assert!(WebSocketListener::get_socket_address( - &"/dns/hello.world/tcp/8888/p2p/12D3KooWT2ouvz5uMmCvHJGzAGRHiqDts5hzXR7NdoQ27pGdzp9Q" - .parse() - .expect("valid multiaddress") - ) - .is_err()); - assert!(WebSocketListener::get_socket_address( + use super::*; + use futures::StreamExt; + + #[test] + fn parse_multiaddresses() { + assert!(WebSocketListener::get_socket_address( + &"/ip6/::1/tcp/8888/ws".parse().expect("valid multiaddress") + ) + .is_ok()); + assert!(WebSocketListener::get_socket_address( + &"/ip4/127.0.0.1/tcp/8888/ws".parse().expect("valid multiaddress") + ) + .is_ok()); + assert!(WebSocketListener::get_socket_address( + &"/ip6/::1/tcp/8888/ws/p2p/12D3KooWT2ouvz5uMmCvHJGzAGRHiqDts5hzXR7NdoQ27pGdzp9Q" + .parse() + .expect("valid multiaddress") + ) + .is_ok()); + assert!(WebSocketListener::get_socket_address( + &"/ip4/127.0.0.1/tcp/8888/ws/p2p/12D3KooWT2ouvz5uMmCvHJGzAGRHiqDts5hzXR7NdoQ27pGdzp9Q" + .parse() + .expect("valid multiaddress") + ) + .is_ok()); + assert!(WebSocketListener::get_socket_address( + &"/ip6/::1/udp/8888/p2p/12D3KooWT2ouvz5uMmCvHJGzAGRHiqDts5hzXR7NdoQ27pGdzp9Q" + .parse() + .expect("valid multiaddress") + ) + .is_err()); + assert!(WebSocketListener::get_socket_address( + &"/ip4/127.0.0.1/udp/8888/p2p/12D3KooWT2ouvz5uMmCvHJGzAGRHiqDts5hzXR7NdoQ27pGdzp9Q" + .parse() + .expect("valid multiaddress") + ) + .is_err()); + assert!(WebSocketListener::get_socket_address( + &"/ip4/127.0.0.1/tcp/8888/ws/utp".parse().expect("valid multiaddress") + ) + .is_err()); + assert!(WebSocketListener::get_socket_address( + &"/ip6/::1/tcp/8888/p2p/12D3KooWT2ouvz5uMmCvHJGzAGRHiqDts5hzXR7NdoQ27pGdzp9Q" + .parse() + .expect("valid multiaddress") + ) + .is_err()); + assert!(WebSocketListener::get_socket_address( + &"/p2p/12D3KooWT2ouvz5uMmCvHJGzAGRHiqDts5hzXR7NdoQ27pGdzp9Q" + .parse() + .expect("valid multiaddress") + ) + .is_err()); + assert!(WebSocketListener::get_socket_address( + &"/dns/hello.world/tcp/8888/p2p/12D3KooWT2ouvz5uMmCvHJGzAGRHiqDts5hzXR7NdoQ27pGdzp9Q" + .parse() + .expect("valid multiaddress") + ) + .is_err()); + assert!(WebSocketListener::get_socket_address( &"/dns6/hello.world/tcp/8888/ws/p2p/12D3KooWT2ouvz5uMmCvHJGzAGRHiqDts5hzXR7NdoQ27pGdzp9Q" .parse() .expect("valid multiaddress") ) .is_ok()); - assert!(WebSocketListener::get_socket_address( + assert!(WebSocketListener::get_socket_address( &"/dns4/hello.world/tcp/8888/ws/p2p/12D3KooWT2ouvz5uMmCvHJGzAGRHiqDts5hzXR7NdoQ27pGdzp9Q" .parse() .expect("valid multiaddress") ) .is_ok()); - assert!(WebSocketListener::get_socket_address( + assert!(WebSocketListener::get_socket_address( &"/dns6/hello.world/tcp/8888/ws/p2p/12D3KooWT2ouvz5uMmCvHJGzAGRHiqDts5hzXR7NdoQ27pGdzp9Q" .parse() .expect("valid multiaddress") ) .is_ok()); - } - - #[tokio::test] - async fn no_listeners() { - let (mut listener, _, _) = WebSocketListener::new(Vec::new(), true); - - futures::future::poll_fn(|cx| match listener.poll_next_unpin(cx) { - Poll::Pending => Poll::Ready(()), - event => panic!("unexpected event: {event:?}"), - }) - .await; - } - - #[tokio::test] - async fn one_listener() { - let address: Multiaddr = "/ip6/::1/tcp/0/ws".parse().unwrap(); - let (mut listener, listen_addresses, _) = - WebSocketListener::new(vec![address.clone()], true); - let Some(Protocol::Tcp(port)) = - listen_addresses.iter().next().unwrap().clone().iter().skip(1).next() - else { - panic!("invalid address"); - }; - - let (res1, res2) = - tokio::join!(listener.next(), TcpStream::connect(format!("[::1]:{port}"))); - - assert!(res1.unwrap().is_ok() && res2.is_ok()); - } - - #[tokio::test] - async fn two_listeners() { - let address1: Multiaddr = "/ip6/::1/tcp/0/ws".parse().unwrap(); - let address2: Multiaddr = "/ip4/127.0.0.1/tcp/0/ws".parse().unwrap(); - let (mut listener, listen_addresses, _) = - WebSocketListener::new(vec![address1, address2], true); - - let Some(Protocol::Tcp(port1)) = - listen_addresses.iter().next().unwrap().clone().iter().skip(1).next() - else { - panic!("invalid address"); - }; - - let Some(Protocol::Tcp(port2)) = - listen_addresses.iter().skip(1).next().unwrap().clone().iter().skip(1).next() - else { - panic!("invalid address"); - }; - - tokio::spawn(async move { while let Some(_) = listener.next().await {} }); - - let (res1, res2) = tokio::join!( - TcpStream::connect(format!("[::1]:{port1}")), - TcpStream::connect(format!("127.0.0.1:{port2}")) - ); - - assert!(res1.is_ok() && res2.is_ok()); - } - - #[tokio::test] - async fn local_dial_address() { - let dial_addresses = DialAddresses::Reuse { - listen_addresses: Arc::new(vec![ - "[2001:7d0:84aa:3900:2a5d:9e85::]:8888".parse().unwrap(), - "92.168.127.1:9999".parse().unwrap(), - ]), - }; - - assert_eq!( - dial_addresses.local_dial_address(&IpAddr::V4(Ipv4Addr::new(192, 168, 0, 1))), - Ok(Some(SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 9999))), - ); - - assert_eq!( - dial_addresses.local_dial_address(&IpAddr::V6(Ipv6Addr::new(0, 1, 2, 3, 4, 5, 6, 7))), - Ok(Some(SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 8888))), - ); - } + } + + #[tokio::test] + async fn no_listeners() { + let (mut listener, _, _) = WebSocketListener::new(Vec::new(), true); + + futures::future::poll_fn(|cx| match listener.poll_next_unpin(cx) { + Poll::Pending => Poll::Ready(()), + event => panic!("unexpected event: {event:?}"), + }) + .await; + } + + #[tokio::test] + async fn one_listener() { + let address: Multiaddr = "/ip6/::1/tcp/0/ws".parse().unwrap(); + let (mut listener, listen_addresses, _) = + WebSocketListener::new(vec![address.clone()], true); + let Some(Protocol::Tcp(port)) = + listen_addresses.iter().next().unwrap().clone().iter().skip(1).next() + else { + panic!("invalid address"); + }; + + let (res1, res2) = + tokio::join!(listener.next(), TcpStream::connect(format!("[::1]:{port}"))); + + assert!(res1.unwrap().is_ok() && res2.is_ok()); + } + + #[tokio::test] + async fn two_listeners() { + let address1: Multiaddr = "/ip6/::1/tcp/0/ws".parse().unwrap(); + let address2: Multiaddr = "/ip4/127.0.0.1/tcp/0/ws".parse().unwrap(); + let (mut listener, listen_addresses, _) = + WebSocketListener::new(vec![address1, address2], true); + + let Some(Protocol::Tcp(port1)) = + listen_addresses.iter().next().unwrap().clone().iter().skip(1).next() + else { + panic!("invalid address"); + }; + + let Some(Protocol::Tcp(port2)) = + listen_addresses.iter().skip(1).next().unwrap().clone().iter().skip(1).next() + else { + panic!("invalid address"); + }; + + tokio::spawn(async move { while let Some(_) = listener.next().await {} }); + + let (res1, res2) = tokio::join!( + TcpStream::connect(format!("[::1]:{port1}")), + TcpStream::connect(format!("127.0.0.1:{port2}")) + ); + + assert!(res1.is_ok() && res2.is_ok()); + } + + #[tokio::test] + async fn local_dial_address() { + let dial_addresses = DialAddresses::Reuse { + listen_addresses: Arc::new(vec![ + "[2001:7d0:84aa:3900:2a5d:9e85::]:8888".parse().unwrap(), + "92.168.127.1:9999".parse().unwrap(), + ]), + }; + + assert_eq!( + dial_addresses.local_dial_address(&IpAddr::V4(Ipv4Addr::new(192, 168, 0, 1))), + Ok(Some(SocketAddr::new( + IpAddr::V4(Ipv4Addr::UNSPECIFIED), + 9999 + ))), + ); + + assert_eq!( + dial_addresses.local_dial_address(&IpAddr::V6(Ipv6Addr::new(0, 1, 2, 3, 4, 5, 6, 7))), + Ok(Some(SocketAddr::new( + IpAddr::V6(Ipv6Addr::UNSPECIFIED), + 8888 + ))), + ); + } } diff --git a/src/transport/websocket/mod.rs b/src/transport/websocket/mod.rs index 6dad6720..35a4ce58 100644 --- a/src/transport/websocket/mod.rs +++ b/src/transport/websocket/mod.rs @@ -21,19 +21,19 @@ //! WebSocket transport. use crate::{ - config::Role, - error::{AddressError, Error}, - transport::{ - manager::TransportHandle, - websocket::{ - config::Config, - connection::{NegotiatedConnection, WebSocketConnection}, - listener::{AddressType, DialAddresses, WebSocketListener}, - }, - Transport, TransportBuilder, TransportEvent, - }, - types::ConnectionId, - PeerId, + config::Role, + error::{AddressError, Error}, + transport::{ + manager::TransportHandle, + websocket::{ + config::Config, + connection::{NegotiatedConnection, WebSocketConnection}, + listener::{AddressType, DialAddresses, WebSocketListener}, + }, + Transport, TransportBuilder, TransportEvent, + }, + types::ConnectionId, + PeerId, }; use futures::{future::BoxFuture, stream::FuturesUnordered, Stream, StreamExt}; @@ -42,17 +42,17 @@ use socket2::{Domain, Socket, Type}; use tokio::net::TcpStream; use tokio_tungstenite::{MaybeTlsStream, WebSocketStream}; use trust_dns_resolver::{ - config::{ResolverConfig, ResolverOpts}, - TokioAsyncResolver, + config::{ResolverConfig, ResolverOpts}, + TokioAsyncResolver, }; use url::Url; use std::{ - collections::{HashMap, HashSet}, - net::SocketAddr, - pin::Pin, - task::{Context, Poll}, - time::Duration, + collections::{HashMap, HashSet}, + net::SocketAddr, + pin::Pin, + task::{Context, Poll}, + time::Duration, }; pub(crate) use substream::Substream; @@ -66,17 +66,20 @@ pub mod config; #[derive(Debug)] pub(super) struct WebSocketError { - /// Error. - error: Error, + /// Error. + error: Error, - /// Connection ID. - connection_id: Option, + /// Connection ID. + connection_id: Option, } impl WebSocketError { - pub fn new(error: Error, connection_id: Option) -> Self { - Self { error, connection_id } - } + pub fn new(error: Error, connection_id: Option) -> Self { + Self { + error, + connection_id, + } + } } /// Logging target for the file. @@ -84,535 +87,542 @@ const LOG_TARGET: &str = "litep2p::websocket"; /// WebSocket transport. pub(crate) struct WebSocketTransport { - /// Transport context. - context: TransportHandle, - - /// Transport configuration. - config: Config, - - /// WebSocket listener. - listener: WebSocketListener, - - /// Dial addresses. - dial_addresses: DialAddresses, - - /// Pending dials. - pending_dials: HashMap, - - /// Pending connections. - pending_connections: - FuturesUnordered>>, - - /// Pending raw, unnegotiated connections. - pending_raw_connections: FuturesUnordered< - BoxFuture< - 'static, - Result< - (ConnectionId, Multiaddr, WebSocketStream>), - ConnectionId, - >, - >, - >, - - /// Opened raw connection, waiting for approval/rejection from `TransportManager`. - opened_raw: HashMap>, Multiaddr)>, - - /// Canceled raw connections. - canceled: HashSet, - - /// Negotiated connections waiting validation. - pending_open: HashMap, + /// Transport context. + context: TransportHandle, + + /// Transport configuration. + config: Config, + + /// WebSocket listener. + listener: WebSocketListener, + + /// Dial addresses. + dial_addresses: DialAddresses, + + /// Pending dials. + pending_dials: HashMap, + + /// Pending connections. + pending_connections: + FuturesUnordered>>, + + /// Pending raw, unnegotiated connections. + pending_raw_connections: FuturesUnordered< + BoxFuture< + 'static, + Result< + ( + ConnectionId, + Multiaddr, + WebSocketStream>, + ), + ConnectionId, + >, + >, + >, + + /// Opened raw connection, waiting for approval/rejection from `TransportManager`. + opened_raw: HashMap>, Multiaddr)>, + + /// Canceled raw connections. + canceled: HashSet, + + /// Negotiated connections waiting validation. + pending_open: HashMap, } impl WebSocketTransport { - /// Convert `Multiaddr` into `url::Url` - fn multiaddr_into_url(address: Multiaddr) -> crate::Result<(Url, PeerId)> { - let mut protocol_stack = address.iter(); - - let dial_address = match protocol_stack - .next() - .ok_or_else(|| Error::TransportNotSupported(address.clone()))? - { - Protocol::Ip4(address) => address.to_string(), - Protocol::Ip6(address) => format!("[{}]", address.to_string()), - Protocol::Dns(address) | Protocol::Dns4(address) | Protocol::Dns6(address) => - address.to_string(), - - _ => return Err(Error::TransportNotSupported(address)), - }; - - let url = match protocol_stack - .next() - .ok_or_else(|| Error::TransportNotSupported(address.clone()))? - { - Protocol::Tcp(port) => match protocol_stack.next() { - Some(Protocol::Ws(_)) => format!("ws://{dial_address}:{port}/"), - Some(Protocol::Wss(_)) => format!("wss://{dial_address}:{port}/"), - _ => return Err(Error::TransportNotSupported(address.clone())), - }, - _ => return Err(Error::TransportNotSupported(address)), - }; - - let peer = match protocol_stack.next() { - Some(Protocol::P2p(multihash)) => PeerId::from_multihash(multihash)?, - protocol => { - tracing::warn!( - target: LOG_TARGET, - ?protocol, - "invalid protocol, expected `Protocol::Ws`/`Protocol::Wss`", - ); - return Err(Error::AddressError(AddressError::PeerIdMissing)); - }, - }; - - tracing::trace!(target: LOG_TARGET, ?url, "parse address"); - - url::Url::parse(&url).map(|url| (url, peer)).map_err(|_| Error::InvalidData) - } - - /// Dial remote peer over `address`. - async fn dial_peer( - address: Multiaddr, - dial_addresses: DialAddresses, - connection_open_timeout: Duration, - ) -> crate::Result<(Multiaddr, WebSocketStream>)> { - let (url, _) = Self::multiaddr_into_url(address.clone())?; - let (socket_address, _) = WebSocketListener::get_socket_address(&address)?; - - let remote_address = match socket_address { - AddressType::Socket(address) => address, - AddressType::Dns(url, port) => { - let address = address.clone(); - let future = async move { - match TokioAsyncResolver::tokio( - ResolverConfig::default(), - ResolverOpts::default(), - ) - .lookup_ip(url.clone()) - .await - { - // TODO: ugly - Ok(lookup) => { - let mut iter = lookup.iter(); - while let Some(ip) = iter.next() { - match ( - address.iter().next().expect("protocol to exist"), - ip.is_ipv4(), - ) { - (Protocol::Dns(_), true) | - (Protocol::Dns4(_), true) | - (Protocol::Dns6(_), false) => { - tracing::trace!( - target: LOG_TARGET, - ?address, - ?ip, - "address resolved", - ); - - return Ok(SocketAddr::new(ip, port)); - }, - _ => {}, - } - } - - Err(Error::Unknown) - }, - Err(_) => Err(Error::Unknown), - } - }; - - match tokio::time::timeout(connection_open_timeout, future).await { - Err(_) => return Err(Error::Timeout), - Ok(Err(error)) => return Err(error), - Ok(Ok(address)) => address, - } - }, - }; - - let domain = match remote_address.is_ipv4() { - true => Domain::IPV4, - false => Domain::IPV6, - }; - let socket = Socket::new(domain, Type::STREAM, Some(socket2::Protocol::TCP))?; - if remote_address.is_ipv6() { - socket.set_only_v6(true)?; - } - socket.set_nonblocking(true)?; - - match dial_addresses.local_dial_address(&remote_address.ip()) { - Ok(Some(dial_address)) => { - socket.set_reuse_address(true)?; - #[cfg(unix)] - socket.set_reuse_port(true)?; - socket.bind(&dial_address.into())?; - }, - Ok(None) => {}, - Err(()) => { - tracing::debug!( - target: LOG_TARGET, - ?remote_address, - "tcp listener not enabled for remote address, using ephemeral port", - ); - }, - } - - let future = async move { - match socket.connect(&remote_address.into()) { - Ok(()) => {}, - Err(error) if error.raw_os_error() == Some(libc::EINPROGRESS) => {}, - Err(error) if error.kind() == std::io::ErrorKind::WouldBlock => {}, - Err(error) => return Err(Error::Other(error.to_string())), - } - - let stream = TcpStream::try_from(Into::::into(socket)) - .map_err(|error| Error::Other(error.to_string()))?; - stream.writable().await.map_err(|error| Error::Other(error.to_string()))?; - - if let Some(error) = - stream.take_error().map_err(|error| Error::Other(error.to_string()))? - { - return Err(Error::Other(error.to_string())); - } - - Ok((address, tokio_tungstenite::client_async_tls(url, stream).await?.0)) - }; - - match tokio::time::timeout(connection_open_timeout, future).await { - Err(_) => Err(Error::Timeout), - Ok(Err(error)) => Err(error.into()), - Ok(Ok((address, stream))) => Ok((address, stream)), - } - } + /// Convert `Multiaddr` into `url::Url` + fn multiaddr_into_url(address: Multiaddr) -> crate::Result<(Url, PeerId)> { + let mut protocol_stack = address.iter(); + + let dial_address = match protocol_stack + .next() + .ok_or_else(|| Error::TransportNotSupported(address.clone()))? + { + Protocol::Ip4(address) => address.to_string(), + Protocol::Ip6(address) => format!("[{}]", address.to_string()), + Protocol::Dns(address) | Protocol::Dns4(address) | Protocol::Dns6(address) => + address.to_string(), + + _ => return Err(Error::TransportNotSupported(address)), + }; + + let url = match protocol_stack + .next() + .ok_or_else(|| Error::TransportNotSupported(address.clone()))? + { + Protocol::Tcp(port) => match protocol_stack.next() { + Some(Protocol::Ws(_)) => format!("ws://{dial_address}:{port}/"), + Some(Protocol::Wss(_)) => format!("wss://{dial_address}:{port}/"), + _ => return Err(Error::TransportNotSupported(address.clone())), + }, + _ => return Err(Error::TransportNotSupported(address)), + }; + + let peer = match protocol_stack.next() { + Some(Protocol::P2p(multihash)) => PeerId::from_multihash(multihash)?, + protocol => { + tracing::warn!( + target: LOG_TARGET, + ?protocol, + "invalid protocol, expected `Protocol::Ws`/`Protocol::Wss`", + ); + return Err(Error::AddressError(AddressError::PeerIdMissing)); + } + }; + + tracing::trace!(target: LOG_TARGET, ?url, "parse address"); + + url::Url::parse(&url).map(|url| (url, peer)).map_err(|_| Error::InvalidData) + } + + /// Dial remote peer over `address`. + async fn dial_peer( + address: Multiaddr, + dial_addresses: DialAddresses, + connection_open_timeout: Duration, + ) -> crate::Result<(Multiaddr, WebSocketStream>)> { + let (url, _) = Self::multiaddr_into_url(address.clone())?; + let (socket_address, _) = WebSocketListener::get_socket_address(&address)?; + + let remote_address = match socket_address { + AddressType::Socket(address) => address, + AddressType::Dns(url, port) => { + let address = address.clone(); + let future = async move { + match TokioAsyncResolver::tokio( + ResolverConfig::default(), + ResolverOpts::default(), + ) + .lookup_ip(url.clone()) + .await + { + // TODO: ugly + Ok(lookup) => { + let mut iter = lookup.iter(); + while let Some(ip) = iter.next() { + match ( + address.iter().next().expect("protocol to exist"), + ip.is_ipv4(), + ) { + (Protocol::Dns(_), true) + | (Protocol::Dns4(_), true) + | (Protocol::Dns6(_), false) => { + tracing::trace!( + target: LOG_TARGET, + ?address, + ?ip, + "address resolved", + ); + + return Ok(SocketAddr::new(ip, port)); + } + _ => {} + } + } + + Err(Error::Unknown) + } + Err(_) => Err(Error::Unknown), + } + }; + + match tokio::time::timeout(connection_open_timeout, future).await { + Err(_) => return Err(Error::Timeout), + Ok(Err(error)) => return Err(error), + Ok(Ok(address)) => address, + } + } + }; + + let domain = match remote_address.is_ipv4() { + true => Domain::IPV4, + false => Domain::IPV6, + }; + let socket = Socket::new(domain, Type::STREAM, Some(socket2::Protocol::TCP))?; + if remote_address.is_ipv6() { + socket.set_only_v6(true)?; + } + socket.set_nonblocking(true)?; + + match dial_addresses.local_dial_address(&remote_address.ip()) { + Ok(Some(dial_address)) => { + socket.set_reuse_address(true)?; + #[cfg(unix)] + socket.set_reuse_port(true)?; + socket.bind(&dial_address.into())?; + } + Ok(None) => {} + Err(()) => { + tracing::debug!( + target: LOG_TARGET, + ?remote_address, + "tcp listener not enabled for remote address, using ephemeral port", + ); + } + } + + let future = async move { + match socket.connect(&remote_address.into()) { + Ok(()) => {} + Err(error) if error.raw_os_error() == Some(libc::EINPROGRESS) => {} + Err(error) if error.kind() == std::io::ErrorKind::WouldBlock => {} + Err(error) => return Err(Error::Other(error.to_string())), + } + + let stream = TcpStream::try_from(Into::::into(socket)) + .map_err(|error| Error::Other(error.to_string()))?; + stream.writable().await.map_err(|error| Error::Other(error.to_string()))?; + + if let Some(error) = + stream.take_error().map_err(|error| Error::Other(error.to_string()))? + { + return Err(Error::Other(error.to_string())); + } + + Ok(( + address, + tokio_tungstenite::client_async_tls(url, stream).await?.0, + )) + }; + + match tokio::time::timeout(connection_open_timeout, future).await { + Err(_) => Err(Error::Timeout), + Ok(Err(error)) => Err(error.into()), + Ok(Ok((address, stream))) => Ok((address, stream)), + } + } } impl TransportBuilder for WebSocketTransport { - type Config = Config; - type Transport = WebSocketTransport; - - /// Create new [`Transport`] object. - fn new( - context: TransportHandle, - mut config: Self::Config, - ) -> crate::Result<(Self, Vec)> - where - Self: Sized, - { - tracing::debug!( - target: LOG_TARGET, - listen_addresses = ?config.listen_addresses, - "start websocket transport", - ); - let (listener, listen_addresses, dial_addresses) = WebSocketListener::new( - std::mem::replace(&mut config.listen_addresses, Vec::new()), - config.reuse_port, - ); - - Ok(( - Self { - listener, - config, - context, - dial_addresses, - canceled: HashSet::new(), - opened_raw: HashMap::new(), - pending_open: HashMap::new(), - pending_dials: HashMap::new(), - pending_connections: FuturesUnordered::new(), - pending_raw_connections: FuturesUnordered::new(), - }, - listen_addresses, - )) - } + type Config = Config; + type Transport = WebSocketTransport; + + /// Create new [`Transport`] object. + fn new( + context: TransportHandle, + mut config: Self::Config, + ) -> crate::Result<(Self, Vec)> + where + Self: Sized, + { + tracing::debug!( + target: LOG_TARGET, + listen_addresses = ?config.listen_addresses, + "start websocket transport", + ); + let (listener, listen_addresses, dial_addresses) = WebSocketListener::new( + std::mem::replace(&mut config.listen_addresses, Vec::new()), + config.reuse_port, + ); + + Ok(( + Self { + listener, + config, + context, + dial_addresses, + canceled: HashSet::new(), + opened_raw: HashMap::new(), + pending_open: HashMap::new(), + pending_dials: HashMap::new(), + pending_connections: FuturesUnordered::new(), + pending_raw_connections: FuturesUnordered::new(), + }, + listen_addresses, + )) + } } impl Transport for WebSocketTransport { - fn dial(&mut self, connection_id: ConnectionId, address: Multiaddr) -> crate::Result<()> { - let yamux_config = self.config.yamux_config.clone(); - let keypair = self.context.keypair.clone(); - let (ws_address, peer) = Self::multiaddr_into_url(address.clone())?; - let connection_open_timeout = self.config.connection_open_timeout; - let max_read_ahead_factor = self.config.noise_read_ahead_frame_count; - let max_write_buffer_size = self.config.noise_write_buffer_size; - let dial_addresses = self.dial_addresses.clone(); - self.pending_dials.insert(connection_id, address.clone()); - - tracing::debug!(target: LOG_TARGET, ?connection_id, ?address, "open connection"); - - self.pending_connections.push(Box::pin(async move { - match tokio::time::timeout(connection_open_timeout, async move { - let (_, stream) = WebSocketTransport::dial_peer( - address.clone(), - dial_addresses, - connection_open_timeout, - ) - .await - .map_err(|error| WebSocketError::new(error, Some(connection_id)))?; - - WebSocketConnection::open_connection( - connection_id, - keypair, - stream, - address, - peer, - ws_address, - yamux_config, - max_read_ahead_factor, - max_write_buffer_size, - ) - .await - .map_err(|error| WebSocketError::new(error, Some(connection_id))) - }) - .await - { - Err(_) => Err(WebSocketError::new(Error::Timeout, Some(connection_id))), - Ok(Err(error)) => Err(error), - Ok(Ok(result)) => Ok(result), - } - })); - - Ok(()) - } - - fn accept(&mut self, connection_id: ConnectionId) -> crate::Result<()> { - let context = self - .pending_open - .remove(&connection_id) - .ok_or(Error::ConnectionDoesntExist(connection_id))?; - let protocol_set = self.context.protocol_set(connection_id); - let bandwidth_sink = self.context.bandwidth_sink.clone(); - let substream_open_timeout = self.config.substream_open_timeout; - - tracing::trace!( - target: LOG_TARGET, - ?connection_id, - "start connection", - ); - - self.context.executor.run(Box::pin(async move { - if let Err(error) = WebSocketConnection::new( - context, - protocol_set, - bandwidth_sink, - substream_open_timeout, - ) - .start() - .await - { - tracing::debug!( - target: LOG_TARGET, - ?connection_id, - ?error, - "connection exited with error", - ); - } - })); - - Ok(()) - } - - fn reject(&mut self, connection_id: ConnectionId) -> crate::Result<()> { - self.canceled.insert(connection_id); - self.pending_open - .remove(&connection_id) - .map_or(Err(Error::ConnectionDoesntExist(connection_id)), |_| Ok(())) - } - - fn open( - &mut self, - connection_id: ConnectionId, - addresses: Vec, - ) -> crate::Result<()> { - let mut futures: FuturesUnordered<_> = addresses - .into_iter() - .map(|address| { - let connection_open_timeout = self.config.connection_open_timeout; - let dial_addresses = self.dial_addresses.clone(); - - async move { - WebSocketTransport::dial_peer(address, dial_addresses, connection_open_timeout) - .await - } - }) - .collect(); - - self.pending_raw_connections.push(Box::pin(async move { - while let Some(result) = futures.next().await { - match result { - Ok((address, stream)) => return Ok((connection_id, address, stream)), - Err(error) => tracing::debug!( - target: LOG_TARGET, - ?connection_id, - ?error, - "failed to open connection", - ), - } - } - - Err(connection_id) - })); - - Ok(()) - } - - fn negotiate(&mut self, connection_id: ConnectionId) -> crate::Result<()> { - let (stream, address) = self - .opened_raw - .remove(&connection_id) - .ok_or(Error::ConnectionDoesntExist(connection_id))?; - - let peer = match address.iter().find(|protocol| std::matches!(protocol, Protocol::P2p(_))) { - Some(Protocol::P2p(multihash)) => PeerId::from_multihash(multihash)?, - _ => return Err(Error::InvalidState), - }; - let yamux_config = self.config.yamux_config.clone(); - let max_read_ahead_factor = self.config.noise_read_ahead_frame_count; - let max_write_buffer_size = self.config.noise_write_buffer_size; - let connection_open_timeout = self.config.connection_open_timeout; - let keypair = self.context.keypair.clone(); - - tracing::trace!( - target: LOG_TARGET, - ?peer, - ?connection_id, - ?address, - "negotiate connection", - ); - - self.pending_dials.insert(connection_id, address.clone()); - self.pending_connections.push(Box::pin(async move { - match tokio::time::timeout(connection_open_timeout, async move { - WebSocketConnection::negotiate_connection( - stream, - Some(peer), - Role::Dialer, - address, - connection_id, - keypair, - yamux_config, - max_read_ahead_factor, - max_write_buffer_size, - ) - .await - .map_err(|error| WebSocketError::new(error, Some(connection_id))) - }) - .await - { - Err(_) => Err(WebSocketError::new(Error::Timeout, Some(connection_id))), - Ok(Err(error)) => Err(error), - Ok(Ok(connection)) => Ok(connection), - } - })); - - Ok(()) - } - - fn cancel(&mut self, connection_id: ConnectionId) { - self.canceled.insert(connection_id); - } + fn dial(&mut self, connection_id: ConnectionId, address: Multiaddr) -> crate::Result<()> { + let yamux_config = self.config.yamux_config.clone(); + let keypair = self.context.keypair.clone(); + let (ws_address, peer) = Self::multiaddr_into_url(address.clone())?; + let connection_open_timeout = self.config.connection_open_timeout; + let max_read_ahead_factor = self.config.noise_read_ahead_frame_count; + let max_write_buffer_size = self.config.noise_write_buffer_size; + let dial_addresses = self.dial_addresses.clone(); + self.pending_dials.insert(connection_id, address.clone()); + + tracing::debug!(target: LOG_TARGET, ?connection_id, ?address, "open connection"); + + self.pending_connections.push(Box::pin(async move { + match tokio::time::timeout(connection_open_timeout, async move { + let (_, stream) = WebSocketTransport::dial_peer( + address.clone(), + dial_addresses, + connection_open_timeout, + ) + .await + .map_err(|error| WebSocketError::new(error, Some(connection_id)))?; + + WebSocketConnection::open_connection( + connection_id, + keypair, + stream, + address, + peer, + ws_address, + yamux_config, + max_read_ahead_factor, + max_write_buffer_size, + ) + .await + .map_err(|error| WebSocketError::new(error, Some(connection_id))) + }) + .await + { + Err(_) => Err(WebSocketError::new(Error::Timeout, Some(connection_id))), + Ok(Err(error)) => Err(error), + Ok(Ok(result)) => Ok(result), + } + })); + + Ok(()) + } + + fn accept(&mut self, connection_id: ConnectionId) -> crate::Result<()> { + let context = self + .pending_open + .remove(&connection_id) + .ok_or(Error::ConnectionDoesntExist(connection_id))?; + let protocol_set = self.context.protocol_set(connection_id); + let bandwidth_sink = self.context.bandwidth_sink.clone(); + let substream_open_timeout = self.config.substream_open_timeout; + + tracing::trace!( + target: LOG_TARGET, + ?connection_id, + "start connection", + ); + + self.context.executor.run(Box::pin(async move { + if let Err(error) = WebSocketConnection::new( + context, + protocol_set, + bandwidth_sink, + substream_open_timeout, + ) + .start() + .await + { + tracing::debug!( + target: LOG_TARGET, + ?connection_id, + ?error, + "connection exited with error", + ); + } + })); + + Ok(()) + } + + fn reject(&mut self, connection_id: ConnectionId) -> crate::Result<()> { + self.canceled.insert(connection_id); + self.pending_open + .remove(&connection_id) + .map_or(Err(Error::ConnectionDoesntExist(connection_id)), |_| Ok(())) + } + + fn open( + &mut self, + connection_id: ConnectionId, + addresses: Vec, + ) -> crate::Result<()> { + let mut futures: FuturesUnordered<_> = addresses + .into_iter() + .map(|address| { + let connection_open_timeout = self.config.connection_open_timeout; + let dial_addresses = self.dial_addresses.clone(); + + async move { + WebSocketTransport::dial_peer(address, dial_addresses, connection_open_timeout) + .await + } + }) + .collect(); + + self.pending_raw_connections.push(Box::pin(async move { + while let Some(result) = futures.next().await { + match result { + Ok((address, stream)) => return Ok((connection_id, address, stream)), + Err(error) => tracing::debug!( + target: LOG_TARGET, + ?connection_id, + ?error, + "failed to open connection", + ), + } + } + + Err(connection_id) + })); + + Ok(()) + } + + fn negotiate(&mut self, connection_id: ConnectionId) -> crate::Result<()> { + let (stream, address) = self + .opened_raw + .remove(&connection_id) + .ok_or(Error::ConnectionDoesntExist(connection_id))?; + + let peer = match address.iter().find(|protocol| std::matches!(protocol, Protocol::P2p(_))) { + Some(Protocol::P2p(multihash)) => PeerId::from_multihash(multihash)?, + _ => return Err(Error::InvalidState), + }; + let yamux_config = self.config.yamux_config.clone(); + let max_read_ahead_factor = self.config.noise_read_ahead_frame_count; + let max_write_buffer_size = self.config.noise_write_buffer_size; + let connection_open_timeout = self.config.connection_open_timeout; + let keypair = self.context.keypair.clone(); + + tracing::trace!( + target: LOG_TARGET, + ?peer, + ?connection_id, + ?address, + "negotiate connection", + ); + + self.pending_dials.insert(connection_id, address.clone()); + self.pending_connections.push(Box::pin(async move { + match tokio::time::timeout(connection_open_timeout, async move { + WebSocketConnection::negotiate_connection( + stream, + Some(peer), + Role::Dialer, + address, + connection_id, + keypair, + yamux_config, + max_read_ahead_factor, + max_write_buffer_size, + ) + .await + .map_err(|error| WebSocketError::new(error, Some(connection_id))) + }) + .await + { + Err(_) => Err(WebSocketError::new(Error::Timeout, Some(connection_id))), + Ok(Err(error)) => Err(error), + Ok(Ok(connection)) => Ok(connection), + } + })); + + Ok(()) + } + + fn cancel(&mut self, connection_id: ConnectionId) { + self.canceled.insert(connection_id); + } } impl Stream for WebSocketTransport { - type Item = TransportEvent; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - while let Poll::Ready(Some(connection)) = self.listener.poll_next_unpin(cx) { - match connection { - Err(_) => return Poll::Ready(None), - Ok((stream, address)) => { - let connection_id = self.context.next_connection_id(); - let keypair = self.context.keypair.clone(); - let yamux_config = self.config.yamux_config.clone(); - let connection_open_timeout = self.config.connection_open_timeout; - let max_read_ahead_factor = self.config.noise_read_ahead_frame_count; - let max_write_buffer_size = self.config.noise_write_buffer_size; - let address = Multiaddr::empty() - .with(Protocol::from(address.ip())) - .with(Protocol::Tcp(address.port())) - .with(Protocol::Ws(std::borrow::Cow::Owned("/".to_string()))); - - self.pending_connections.push(Box::pin(async move { - match tokio::time::timeout(connection_open_timeout, async move { - WebSocketConnection::accept_connection( - stream, - connection_id, - keypair, - address, - yamux_config, - max_read_ahead_factor, - max_write_buffer_size, - ) - .await - .map_err(|error| WebSocketError::new(error, None)) - }) - .await - { - Err(_) => Err(WebSocketError::new(Error::Timeout, None)), - Ok(Err(error)) => Err(error), - Ok(Ok(result)) => Ok(result), - } - })); - }, - } - } - - while let Poll::Ready(Some(result)) = self.pending_raw_connections.poll_next_unpin(cx) { - match result { - Ok((connection_id, address, stream)) => { - tracing::trace!( - target: LOG_TARGET, - ?connection_id, - ?address, - canceled = self.canceled.contains(&connection_id), - "connection opened", - ); - - if !self.canceled.remove(&connection_id) { - self.opened_raw.insert(connection_id, (stream, address.clone())); - - return Poll::Ready(Some(TransportEvent::ConnectionOpened { - connection_id, - address, - })); - } - }, - Err(connection_id) => - if !self.canceled.remove(&connection_id) { - return Poll::Ready(Some(TransportEvent::OpenFailure { connection_id })); - }, - } - } - - while let Poll::Ready(Some(connection)) = self.pending_connections.poll_next_unpin(cx) { - match connection { - Ok(connection) => { - let peer = connection.peer(); - let endpoint = connection.endpoint(); - self.pending_open.insert(connection.connection_id(), connection); - - return Poll::Ready(Some(TransportEvent::ConnectionEstablished { - peer, - endpoint, - })); - }, - Err(error) => match error.connection_id { - Some(connection_id) => match self.pending_dials.remove(&connection_id) { - Some(address) => - return Poll::Ready(Some(TransportEvent::DialFailure { - connection_id, - address, - error: error.error, - })), - None => { - tracing::debug!(target: LOG_TARGET, ?error, "failed to establish connection") - }, - }, - None => { - tracing::debug!(target: LOG_TARGET, ?error, "failed to establish connection") - }, - }, - } - } - - Poll::Pending - } + type Item = TransportEvent; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + while let Poll::Ready(Some(connection)) = self.listener.poll_next_unpin(cx) { + match connection { + Err(_) => return Poll::Ready(None), + Ok((stream, address)) => { + let connection_id = self.context.next_connection_id(); + let keypair = self.context.keypair.clone(); + let yamux_config = self.config.yamux_config.clone(); + let connection_open_timeout = self.config.connection_open_timeout; + let max_read_ahead_factor = self.config.noise_read_ahead_frame_count; + let max_write_buffer_size = self.config.noise_write_buffer_size; + let address = Multiaddr::empty() + .with(Protocol::from(address.ip())) + .with(Protocol::Tcp(address.port())) + .with(Protocol::Ws(std::borrow::Cow::Owned("/".to_string()))); + + self.pending_connections.push(Box::pin(async move { + match tokio::time::timeout(connection_open_timeout, async move { + WebSocketConnection::accept_connection( + stream, + connection_id, + keypair, + address, + yamux_config, + max_read_ahead_factor, + max_write_buffer_size, + ) + .await + .map_err(|error| WebSocketError::new(error, None)) + }) + .await + { + Err(_) => Err(WebSocketError::new(Error::Timeout, None)), + Ok(Err(error)) => Err(error), + Ok(Ok(result)) => Ok(result), + } + })); + } + } + } + + while let Poll::Ready(Some(result)) = self.pending_raw_connections.poll_next_unpin(cx) { + match result { + Ok((connection_id, address, stream)) => { + tracing::trace!( + target: LOG_TARGET, + ?connection_id, + ?address, + canceled = self.canceled.contains(&connection_id), + "connection opened", + ); + + if !self.canceled.remove(&connection_id) { + self.opened_raw.insert(connection_id, (stream, address.clone())); + + return Poll::Ready(Some(TransportEvent::ConnectionOpened { + connection_id, + address, + })); + } + } + Err(connection_id) => + if !self.canceled.remove(&connection_id) { + return Poll::Ready(Some(TransportEvent::OpenFailure { connection_id })); + }, + } + } + + while let Poll::Ready(Some(connection)) = self.pending_connections.poll_next_unpin(cx) { + match connection { + Ok(connection) => { + let peer = connection.peer(); + let endpoint = connection.endpoint(); + self.pending_open.insert(connection.connection_id(), connection); + + return Poll::Ready(Some(TransportEvent::ConnectionEstablished { + peer, + endpoint, + })); + } + Err(error) => match error.connection_id { + Some(connection_id) => match self.pending_dials.remove(&connection_id) { + Some(address) => + return Poll::Ready(Some(TransportEvent::DialFailure { + connection_id, + address, + error: error.error, + })), + None => { + tracing::debug!(target: LOG_TARGET, ?error, "failed to establish connection") + } + }, + None => { + tracing::debug!(target: LOG_TARGET, ?error, "failed to establish connection") + } + }, + } + } + + Poll::Pending + } } diff --git a/src/transport/websocket/stream.rs b/src/transport/websocket/stream.rs index f409c640..2705c302 100644 --- a/src/transport/websocket/stream.rs +++ b/src/transport/websocket/stream.rs @@ -27,158 +27,158 @@ use tokio::io::{AsyncRead, AsyncWrite}; use tokio_tungstenite::{tungstenite::Message, WebSocketStream}; use std::{ - pin::Pin, - task::{Context, Poll}, + pin::Pin, + task::{Context, Poll}, }; // TODO: add tests /// Send state. enum State { - /// State is poisoned. - Poisoned, + /// State is poisoned. + Poisoned, - /// Sink is accepting input. - ReadyToSend, + /// Sink is accepting input. + ReadyToSend, - /// Sink is ready to send. - ReadyPending { to_write: Vec }, + /// Sink is ready to send. + ReadyPending { to_write: Vec }, - /// Flush is pending for the sink. - FlushPending, + /// Flush is pending for the sink. + FlushPending, } /// Buffered stream which implements `AsyncRead + AsyncWrite` pub(super) struct BufferedStream { - /// Write buffer. - write_buffer: Vec, + /// Write buffer. + write_buffer: Vec, - /// Write pointer. - write_ptr: usize, + /// Write pointer. + write_ptr: usize, - // Read buffer. - read_buffer: Option, + // Read buffer. + read_buffer: Option, - /// Underlying WebSocket stream. - stream: WebSocketStream, + /// Underlying WebSocket stream. + stream: WebSocketStream, - /// Read state. - state: State, + /// Read state. + state: State, } impl BufferedStream { - /// Create new [`BufferedStream`]. - pub(super) fn new(stream: WebSocketStream) -> Self { - Self { - write_buffer: Vec::with_capacity(2000), - read_buffer: None, - write_ptr: 0usize, - stream, - state: State::ReadyToSend, - } - } + /// Create new [`BufferedStream`]. + pub(super) fn new(stream: WebSocketStream) -> Self { + Self { + write_buffer: Vec::with_capacity(2000), + read_buffer: None, + write_ptr: 0usize, + stream, + state: State::ReadyToSend, + } + } } impl futures::AsyncWrite for BufferedStream { - fn poll_write( - mut self: Pin<&mut Self>, - _cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - self.write_buffer.extend_from_slice(buf); - self.write_ptr += buf.len(); - - Poll::Ready(Ok(buf.len())) - } - - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - if self.write_buffer.is_empty() { - return self - .stream - .poll_ready_unpin(cx) - .map_err(|_| std::io::ErrorKind::UnexpectedEof.into()); - } - - loop { - match std::mem::replace(&mut self.state, State::Poisoned) { - State::ReadyToSend => { - let message = self.write_buffer[..self.write_ptr].to_vec(); - self.state = State::ReadyPending { to_write: message }; - - match futures::ready!(self.stream.poll_ready_unpin(cx)) { - Ok(()) => continue, - Err(_error) => { - return Poll::Ready(Err(std::io::ErrorKind::UnexpectedEof.into())); - }, - } - }, - State::ReadyPending { to_write } => { - match self.stream.start_send_unpin(Message::Binary(to_write.clone())) { - Ok(_) => { - self.state = State::FlushPending; - continue; - }, - Err(_error) => - return Poll::Ready(Err(std::io::ErrorKind::UnexpectedEof.into())), - } - }, - State::FlushPending => match futures::ready!(self.stream.poll_flush_unpin(cx)) { - Ok(_res) => { - // TODO: optimize - self.state = State::ReadyToSend; - self.write_ptr = 0; - self.write_buffer = Vec::with_capacity(2000); - return Poll::Ready(Ok(())); - }, - Err(_) => return Poll::Ready(Err(std::io::ErrorKind::UnexpectedEof.into())), - }, - State::Poisoned => - return Poll::Ready(Err(std::io::ErrorKind::UnexpectedEof.into())), - } - } - } - - fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match futures::ready!(self.stream.poll_close_unpin(cx)) { - Ok(_) => Poll::Ready(Ok(())), - Err(_) => return Poll::Ready(Err(std::io::ErrorKind::PermissionDenied.into())), - } - } + fn poll_write( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + self.write_buffer.extend_from_slice(buf); + self.write_ptr += buf.len(); + + Poll::Ready(Ok(buf.len())) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + if self.write_buffer.is_empty() { + return self + .stream + .poll_ready_unpin(cx) + .map_err(|_| std::io::ErrorKind::UnexpectedEof.into()); + } + + loop { + match std::mem::replace(&mut self.state, State::Poisoned) { + State::ReadyToSend => { + let message = self.write_buffer[..self.write_ptr].to_vec(); + self.state = State::ReadyPending { to_write: message }; + + match futures::ready!(self.stream.poll_ready_unpin(cx)) { + Ok(()) => continue, + Err(_error) => { + return Poll::Ready(Err(std::io::ErrorKind::UnexpectedEof.into())); + } + } + } + State::ReadyPending { to_write } => { + match self.stream.start_send_unpin(Message::Binary(to_write.clone())) { + Ok(_) => { + self.state = State::FlushPending; + continue; + } + Err(_error) => + return Poll::Ready(Err(std::io::ErrorKind::UnexpectedEof.into())), + } + } + State::FlushPending => match futures::ready!(self.stream.poll_flush_unpin(cx)) { + Ok(_res) => { + // TODO: optimize + self.state = State::ReadyToSend; + self.write_ptr = 0; + self.write_buffer = Vec::with_capacity(2000); + return Poll::Ready(Ok(())); + } + Err(_) => return Poll::Ready(Err(std::io::ErrorKind::UnexpectedEof.into())), + }, + State::Poisoned => + return Poll::Ready(Err(std::io::ErrorKind::UnexpectedEof.into())), + } + } + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match futures::ready!(self.stream.poll_close_unpin(cx)) { + Ok(_) => Poll::Ready(Ok(())), + Err(_) => return Poll::Ready(Err(std::io::ErrorKind::PermissionDenied.into())), + } + } } impl futures::AsyncRead for BufferedStream { - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { - loop { - if self.read_buffer.is_none() { - match self.stream.poll_next_unpin(cx) { - Poll::Ready(Some(Ok(chunk))) => match chunk { - Message::Binary(chunk) => self.read_buffer.replace(chunk.into()), - _event => return Poll::Ready(Err(std::io::ErrorKind::Unsupported.into())), - }, - Poll::Ready(Some(Err(_error))) => - return Poll::Ready(Err(std::io::ErrorKind::UnexpectedEof.into())), - Poll::Ready(None) => return Poll::Ready(Ok(0)), - Poll::Pending => return Poll::Pending, - }; - } - - let buffer = self.read_buffer.as_mut().expect("buffer to exist"); - let bytes_read = buf.len().min(buffer.len()); - let _orig_size = buffer.len(); - buf[..bytes_read].copy_from_slice(&buffer[..bytes_read]); - - buffer.advance(bytes_read); - - // TODO: this can't be correct - if !buffer.is_empty() || bytes_read != 0 { - return Poll::Ready(Ok(bytes_read.into())); - } else { - self.read_buffer.take(); - } - } - } + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + loop { + if self.read_buffer.is_none() { + match self.stream.poll_next_unpin(cx) { + Poll::Ready(Some(Ok(chunk))) => match chunk { + Message::Binary(chunk) => self.read_buffer.replace(chunk.into()), + _event => return Poll::Ready(Err(std::io::ErrorKind::Unsupported.into())), + }, + Poll::Ready(Some(Err(_error))) => + return Poll::Ready(Err(std::io::ErrorKind::UnexpectedEof.into())), + Poll::Ready(None) => return Poll::Ready(Ok(0)), + Poll::Pending => return Poll::Pending, + }; + } + + let buffer = self.read_buffer.as_mut().expect("buffer to exist"); + let bytes_read = buf.len().min(buffer.len()); + let _orig_size = buffer.len(); + buf[..bytes_read].copy_from_slice(&buffer[..bytes_read]); + + buffer.advance(bytes_read); + + // TODO: this can't be correct + if !buffer.is_empty() || bytes_read != 0 { + return Poll::Ready(Ok(bytes_read.into())); + } else { + self.read_buffer.take(); + } + } + } } diff --git a/src/transport/websocket/substream.rs b/src/transport/websocket/substream.rs index 2875e9e4..427b8c87 100644 --- a/src/transport/websocket/substream.rs +++ b/src/transport/websocket/substream.rs @@ -24,75 +24,79 @@ use tokio::io::{AsyncRead, AsyncWrite}; use tokio_util::compat::Compat; use std::{ - io, - pin::Pin, - task::{Context, Poll}, + io, + pin::Pin, + task::{Context, Poll}, }; /// Substream that holds the inner substream provided by the transport /// and a permit which keeps the connection open. #[derive(Debug)] pub struct Substream { - /// Underlying socket. - io: Compat, + /// Underlying socket. + io: Compat, - /// Bandwidth sink. - bandwidth_sink: BandwidthSink, + /// Bandwidth sink. + bandwidth_sink: BandwidthSink, - /// Connection permit. - _permit: Permit, + /// Connection permit. + _permit: Permit, } impl Substream { - /// Create new [`Substream`]. - pub fn new( - io: Compat, - bandwidth_sink: BandwidthSink, - _permit: Permit, - ) -> Self { - Self { io, bandwidth_sink, _permit } - } + /// Create new [`Substream`]. + pub fn new( + io: Compat, + bandwidth_sink: BandwidthSink, + _permit: Permit, + ) -> Self { + Self { + io, + bandwidth_sink, + _permit, + } + } } impl AsyncRead for Substream { - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut tokio::io::ReadBuf<'_>, - ) -> Poll> { - match futures::ready!(Pin::new(&mut self.io).poll_read(cx, buf)) { - Err(error) => Poll::Ready(Err(error)), - Ok(res) => { - self.bandwidth_sink.increase_inbound(buf.filled().len()); - Poll::Ready(Ok(res)) - }, - } - } + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll> { + match futures::ready!(Pin::new(&mut self.io).poll_read(cx, buf)) { + Err(error) => Poll::Ready(Err(error)), + Ok(res) => { + self.bandwidth_sink.increase_inbound(buf.filled().len()); + Poll::Ready(Ok(res)) + } + } + } } impl AsyncWrite for Substream { - fn poll_write( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - match futures::ready!(Pin::new(&mut self.io).poll_write(cx, buf)) { - Err(error) => Poll::Ready(Err(error)), - Ok(nwritten) => { - self.bandwidth_sink.increase_outbound(nwritten); - Poll::Ready(Ok(nwritten)) - }, - } - } + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + match futures::ready!(Pin::new(&mut self.io).poll_write(cx, buf)) { + Err(error) => Poll::Ready(Err(error)), + Ok(nwritten) => { + self.bandwidth_sink.increase_outbound(nwritten); + Poll::Ready(Ok(nwritten)) + } + } + } - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.io).poll_flush(cx) - } + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.io).poll_flush(cx) + } - fn poll_shutdown( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - Pin::new(&mut self.io).poll_shutdown(cx) - } + fn poll_shutdown( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + Pin::new(&mut self.io).poll_shutdown(cx) + } } diff --git a/src/types.rs b/src/types.rs index f75b3a09..5b125129 100644 --- a/src/types.rs +++ b/src/types.rs @@ -29,15 +29,15 @@ pub mod protocol; pub struct SubstreamId(usize); impl SubstreamId { - /// Create new [`SubstreamId`]. - pub fn new() -> Self { - SubstreamId(0usize) - } + /// Create new [`SubstreamId`]. + pub fn new() -> Self { + SubstreamId(0usize) + } - /// Get [`SubstreamId`] from a number that can be converted into a `usize`. - pub fn from>(value: T) -> Self { - SubstreamId(value.into()) - } + /// Get [`SubstreamId`] from a number that can be converted into a `usize`. + pub fn from>(value: T) -> Self { + SubstreamId(value.into()) + } } /// Request ID. @@ -45,10 +45,10 @@ impl SubstreamId { pub struct RequestId(usize); impl RequestId { - /// Get [`RequestId`] from a number that can be converted into a `usize`. - pub fn from>(value: T) -> Self { - RequestId(value.into()) - } + /// Get [`RequestId`] from a number that can be converted into a `usize`. + pub fn from>(value: T) -> Self { + RequestId(value.into()) + } } /// Connection ID. @@ -56,19 +56,19 @@ impl RequestId { pub struct ConnectionId(usize); impl ConnectionId { - /// Create new [`ConnectionId`]. - pub fn new() -> Self { - ConnectionId(0usize) - } + /// Create new [`ConnectionId`]. + pub fn new() -> Self { + ConnectionId(0usize) + } - /// Generate random `ConnectionId`. - pub fn random() -> Self { - ConnectionId(rand::thread_rng().gen::()) - } + /// Generate random `ConnectionId`. + pub fn random() -> Self { + ConnectionId(rand::thread_rng().gen::()) + } } impl From for ConnectionId { - fn from(value: usize) -> Self { - ConnectionId(value) - } + fn from(value: usize) -> Self { + ConnectionId(value) + } } diff --git a/src/types/protocol.rs b/src/types/protocol.rs index 8ae6766b..adbfe8b1 100644 --- a/src/types/protocol.rs +++ b/src/types/protocol.rs @@ -21,79 +21,79 @@ //! Protocol name. use std::{ - fmt::Display, - hash::{Hash, Hasher}, - sync::Arc, + fmt::Display, + hash::{Hash, Hasher}, + sync::Arc, }; /// Protocol name. #[derive(Debug, Clone)] pub enum ProtocolName { - Static(&'static str), - Allocated(Arc), + Static(&'static str), + Allocated(Arc), } impl From<&'static str> for ProtocolName { - fn from(protocol: &'static str) -> Self { - ProtocolName::Static(protocol) - } + fn from(protocol: &'static str) -> Self { + ProtocolName::Static(protocol) + } } impl Display for ProtocolName { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::Static(protocol) => protocol.fmt(f), - Self::Allocated(protocol) => protocol.fmt(f), - } - } + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Static(protocol) => protocol.fmt(f), + Self::Allocated(protocol) => protocol.fmt(f), + } + } } impl From for ProtocolName { - fn from(protocol: String) -> Self { - ProtocolName::Allocated(Arc::from(protocol)) - } + fn from(protocol: String) -> Self { + ProtocolName::Allocated(Arc::from(protocol)) + } } impl From> for ProtocolName { - fn from(protocol: Arc) -> Self { - Self::Allocated(protocol) - } + fn from(protocol: Arc) -> Self { + Self::Allocated(protocol) + } } impl std::ops::Deref for ProtocolName { - type Target = str; + type Target = str; - fn deref(&self) -> &Self::Target { - match self { - Self::Static(protocol) => protocol, - Self::Allocated(protocol) => protocol, - } - } + fn deref(&self) -> &Self::Target { + match self { + Self::Static(protocol) => protocol, + Self::Allocated(protocol) => protocol, + } + } } impl Hash for ProtocolName { - fn hash(&self, state: &mut H) { - (self as &str).hash(state) - } + fn hash(&self, state: &mut H) { + (self as &str).hash(state) + } } impl PartialEq for ProtocolName { - fn eq(&self, other: &Self) -> bool { - (self as &str) == (other as &str) - } + fn eq(&self, other: &Self) -> bool { + (self as &str) == (other as &str) + } } impl Eq for ProtocolName {} #[cfg(test)] mod tests { - use super::*; + use super::*; - #[test] - fn make_protocol() { - let protocol1 = ProtocolName::from(Arc::from(String::from("/protocol/1"))); - let protocol2 = ProtocolName::from("/protocol/1"); + #[test] + fn make_protocol() { + let protocol1 = ProtocolName::from(Arc::from(String::from("/protocol/1"))); + let protocol2 = ProtocolName::from("/protocol/1"); - assert_eq!(protocol1, protocol2); - } + assert_eq!(protocol1, protocol2); + } } diff --git a/src/yamux/chunks.rs b/src/yamux/chunks.rs index ec85b00f..19515dc7 100644 --- a/src/yamux/chunks.rs +++ b/src/yamux/chunks.rs @@ -17,40 +17,45 @@ use std::{collections::VecDeque, io}; /// [`Chunk`] elements. #[derive(Debug)] pub(crate) struct Chunks { - seq: VecDeque, - len: usize, + seq: VecDeque, + len: usize, } impl Chunks { - /// A new empty chunk list. - pub(crate) fn new() -> Self { - Chunks { seq: VecDeque::new(), len: 0 } - } + /// A new empty chunk list. + pub(crate) fn new() -> Self { + Chunks { + seq: VecDeque::new(), + len: 0, + } + } - /// The total length of bytes yet-to-be-read in all `Chunk`s. - pub(crate) fn len(&self) -> usize { - self.len - self.seq.front().map(|c| c.offset()).unwrap_or(0) - } + /// The total length of bytes yet-to-be-read in all `Chunk`s. + pub(crate) fn len(&self) -> usize { + self.len - self.seq.front().map(|c| c.offset()).unwrap_or(0) + } - /// Add another chunk of bytes to the end. - pub(crate) fn push(&mut self, x: Vec) { - self.len += x.len(); - if !x.is_empty() { - self.seq.push_back(Chunk { cursor: io::Cursor::new(x) }) - } - } + /// Add another chunk of bytes to the end. + pub(crate) fn push(&mut self, x: Vec) { + self.len += x.len(); + if !x.is_empty() { + self.seq.push_back(Chunk { + cursor: io::Cursor::new(x), + }) + } + } - /// Remove and return the first chunk. - pub(crate) fn pop(&mut self) -> Option { - let chunk = self.seq.pop_front(); - self.len -= chunk.as_ref().map(|c| c.len() + c.offset()).unwrap_or(0); - chunk - } + /// Remove and return the first chunk. + pub(crate) fn pop(&mut self) -> Option { + let chunk = self.seq.pop_front(); + self.len -= chunk.as_ref().map(|c| c.len() + c.offset()).unwrap_or(0); + chunk + } - /// Get a mutable reference to the first chunk. - pub(crate) fn front_mut(&mut self) -> Option<&mut Chunk> { - self.seq.front_mut() - } + /// Get a mutable reference to the first chunk. + pub(crate) fn front_mut(&mut self) -> Option<&mut Chunk> { + self.seq.front_mut() + } } /// A `Chunk` wraps a `std::io::Cursor>`. @@ -59,48 +64,48 @@ impl Chunks { /// vector can be consumed in steps. #[derive(Debug)] pub(crate) struct Chunk { - cursor: io::Cursor>, + cursor: io::Cursor>, } impl Chunk { - /// Is this chunk empty? - pub(crate) fn is_empty(&self) -> bool { - self.len() == 0 - } + /// Is this chunk empty? + pub(crate) fn is_empty(&self) -> bool { + self.len() == 0 + } - /// The remaining number of bytes in this `Chunk`. - pub(crate) fn len(&self) -> usize { - self.cursor.get_ref().len() - self.offset() - } + /// The remaining number of bytes in this `Chunk`. + pub(crate) fn len(&self) -> usize { + self.cursor.get_ref().len() - self.offset() + } - /// The sum of bytes that the cursor has been `advance`d over. - pub(crate) fn offset(&self) -> usize { - self.cursor.position() as usize - } + /// The sum of bytes that the cursor has been `advance`d over. + pub(crate) fn offset(&self) -> usize { + self.cursor.position() as usize + } - /// Move the cursor position by `amount` bytes. - /// - /// The `AsRef<[u8]>` impl of `Chunk` provides a byte-slice view - /// from the current position to the end. - pub(crate) fn advance(&mut self, amount: usize) { - assert!({ - // the new position must not exceed the vector's length - let pos = self.offset().checked_add(amount); - let max = self.cursor.get_ref().len(); - pos.is_some() && pos <= Some(max) - }); + /// Move the cursor position by `amount` bytes. + /// + /// The `AsRef<[u8]>` impl of `Chunk` provides a byte-slice view + /// from the current position to the end. + pub(crate) fn advance(&mut self, amount: usize) { + assert!({ + // the new position must not exceed the vector's length + let pos = self.offset().checked_add(amount); + let max = self.cursor.get_ref().len(); + pos.is_some() && pos <= Some(max) + }); - self.cursor.set_position(self.cursor.position() + amount as u64); - } + self.cursor.set_position(self.cursor.position() + amount as u64); + } - /// Consume `self` and return the inner vector. - pub(crate) fn into_vec(self) -> Vec { - self.cursor.into_inner() - } + /// Consume `self` and return the inner vector. + pub(crate) fn into_vec(self) -> Vec { + self.cursor.into_inner() + } } impl AsRef<[u8]> for Chunk { - fn as_ref(&self) -> &[u8] { - &self.cursor.get_ref()[self.offset()..] - } + fn as_ref(&self) -> &[u8] { + &self.cursor.get_ref()[self.offset()..] + } } diff --git a/src/yamux/connection.rs b/src/yamux/connection.rs index 9932d62e..6ae85009 100644 --- a/src/yamux/connection.rs +++ b/src/yamux/connection.rs @@ -85,31 +85,31 @@ mod closing; mod stream; use crate::yamux::{ - error::ConnectionError, - frame::{ - self, - header::{self, Data, GoAway, Header, Ping, StreamId, Tag, WindowUpdate, CONNECTION_ID}, - Frame, - }, - tagged_stream::TaggedStream, - Config, Result, WindowUpdateMode, DEFAULT_CREDIT, MAX_ACK_BACKLOG, + error::ConnectionError, + frame::{ + self, + header::{self, Data, GoAway, Header, Ping, StreamId, Tag, WindowUpdate, CONNECTION_ID}, + Frame, + }, + tagged_stream::TaggedStream, + Config, Result, WindowUpdateMode, DEFAULT_CREDIT, MAX_ACK_BACKLOG, }; use cleanup::Cleanup; use closing::Closing; use futures::{ - channel::mpsc, - future::Either, - prelude::*, - sink::SinkExt, - stream::{Fuse, SelectAll}, + channel::mpsc, + future::Either, + prelude::*, + sink::SinkExt, + stream::{Fuse, SelectAll}, }; use nohash_hasher::IntMap; use parking_lot::Mutex; use std::{ - collections::VecDeque, - fmt, - sync::Arc, - task::{Context, Poll, Waker}, + collections::VecDeque, + fmt, + sync::Arc, + task::{Context, Poll, Waker}, }; pub use stream::{Packet, State, Stream}; @@ -120,10 +120,10 @@ const LOG_TARGET: &str = "litep2p::yamux"; /// How the connection is used. #[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)] pub enum Mode { - /// Client to server connection. - Client, - /// Server to client connection. - Server, + /// Client to server connection. + Client, + /// Server to client connection. + Server, } /// The connection identifier. @@ -133,219 +133,221 @@ pub enum Mode { pub(crate) struct Id(u32); impl Id { - /// Create a random connection ID. - pub(crate) fn random() -> Self { - Id(rand::random()) - } + /// Create a random connection ID. + pub(crate) fn random() -> Self { + Id(rand::random()) + } } impl fmt::Debug for Id { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{:08x}", self.0) - } + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{:08x}", self.0) + } } impl fmt::Display for Id { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{:08x}", self.0) - } + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{:08x}", self.0) + } } #[derive(Debug)] pub struct Connection { - inner: ConnectionState, + inner: ConnectionState, } impl Connection { - pub fn new(socket: T, cfg: Config, mode: Mode) -> Self { - Self { inner: ConnectionState::Active(Active::new(socket, cfg, mode)) } - } - - /// Poll for a new outbound stream. - /// - /// This function will fail if the current state does not allow opening new outbound streams. - pub fn poll_new_outbound(&mut self, cx: &mut Context<'_>) -> Poll> { - loop { - match std::mem::replace(&mut self.inner, ConnectionState::Poisoned) { - ConnectionState::Active(mut active) => match active.poll_new_outbound(cx) { - Poll::Ready(Ok(stream)) => { - self.inner = ConnectionState::Active(active); - return Poll::Ready(Ok(stream)); - }, - Poll::Pending => { - self.inner = ConnectionState::Active(active); - return Poll::Pending; - }, - Poll::Ready(Err(e)) => { - self.inner = ConnectionState::Cleanup(active.cleanup(e)); - continue; - }, - }, - ConnectionState::Closing(mut inner) => match inner.poll_unpin(cx) { - Poll::Ready(Ok(())) => { - self.inner = ConnectionState::Closed; - return Poll::Ready(Err(ConnectionError::Closed)); - }, - Poll::Ready(Err(e)) => { - self.inner = ConnectionState::Closed; - return Poll::Ready(Err(e)); - }, - Poll::Pending => { - self.inner = ConnectionState::Closing(inner); - return Poll::Pending; - }, - }, - ConnectionState::Cleanup(mut inner) => match inner.poll_unpin(cx) { - Poll::Ready(e) => { - self.inner = ConnectionState::Closed; - return Poll::Ready(Err(e)); - }, - Poll::Pending => { - self.inner = ConnectionState::Cleanup(inner); - return Poll::Pending; - }, - }, - ConnectionState::Closed => { - self.inner = ConnectionState::Closed; - return Poll::Ready(Err(ConnectionError::Closed)); - }, - ConnectionState::Poisoned => unreachable!(), - } - } - } - - /// Poll for the next inbound stream. - /// - /// If this function returns `None`, the underlying connection is closed. - pub fn poll_next_inbound(&mut self, cx: &mut Context<'_>) -> Poll>> { - loop { - match std::mem::replace(&mut self.inner, ConnectionState::Poisoned) { - ConnectionState::Active(mut active) => match active.poll(cx) { - Poll::Ready(Ok(stream)) => { - self.inner = ConnectionState::Active(active); - return Poll::Ready(Some(Ok(stream))); - }, - Poll::Ready(Err(e)) => { - self.inner = ConnectionState::Cleanup(active.cleanup(e)); - continue; - }, - Poll::Pending => { - self.inner = ConnectionState::Active(active); - return Poll::Pending; - }, - }, - ConnectionState::Closing(mut closing) => match closing.poll_unpin(cx) { - Poll::Ready(Ok(())) => { - self.inner = ConnectionState::Closed; - return Poll::Ready(None); - }, - Poll::Ready(Err(e)) => { - self.inner = ConnectionState::Closed; - return Poll::Ready(Some(Err(e))); - }, - Poll::Pending => { - self.inner = ConnectionState::Closing(closing); - return Poll::Pending; - }, - }, - ConnectionState::Cleanup(mut cleanup) => match cleanup.poll_unpin(cx) { - Poll::Ready(ConnectionError::Closed) => { - self.inner = ConnectionState::Closed; - return Poll::Ready(None); - }, - Poll::Ready(other) => { - self.inner = ConnectionState::Closed; - return Poll::Ready(Some(Err(other))); - }, - Poll::Pending => { - self.inner = ConnectionState::Cleanup(cleanup); - return Poll::Pending; - }, - }, - ConnectionState::Closed => { - self.inner = ConnectionState::Closed; - return Poll::Ready(None); - }, - ConnectionState::Poisoned => unreachable!(), - } - } - } - - /// Close the connection. - pub fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll> { - loop { - match std::mem::replace(&mut self.inner, ConnectionState::Poisoned) { - ConnectionState::Active(active) => { - self.inner = ConnectionState::Closing(active.close()); - }, - ConnectionState::Closing(mut inner) => match inner.poll_unpin(cx)? { - Poll::Ready(()) => { - self.inner = ConnectionState::Closed; - }, - Poll::Pending => { - self.inner = ConnectionState::Closing(inner); - return Poll::Pending; - }, - }, - ConnectionState::Cleanup(mut cleanup) => match cleanup.poll_unpin(cx) { - Poll::Ready(reason) => { - tracing::warn!(target: LOG_TARGET, "Failure while closing connection: {}", reason); - self.inner = ConnectionState::Closed; - return Poll::Ready(Ok(())); - }, - Poll::Pending => { - self.inner = ConnectionState::Cleanup(cleanup); - return Poll::Pending; - }, - }, - ConnectionState::Closed => { - self.inner = ConnectionState::Closed; - return Poll::Ready(Ok(())); - }, - ConnectionState::Poisoned => { - unreachable!() - }, - } - } - } + pub fn new(socket: T, cfg: Config, mode: Mode) -> Self { + Self { + inner: ConnectionState::Active(Active::new(socket, cfg, mode)), + } + } + + /// Poll for a new outbound stream. + /// + /// This function will fail if the current state does not allow opening new outbound streams. + pub fn poll_new_outbound(&mut self, cx: &mut Context<'_>) -> Poll> { + loop { + match std::mem::replace(&mut self.inner, ConnectionState::Poisoned) { + ConnectionState::Active(mut active) => match active.poll_new_outbound(cx) { + Poll::Ready(Ok(stream)) => { + self.inner = ConnectionState::Active(active); + return Poll::Ready(Ok(stream)); + } + Poll::Pending => { + self.inner = ConnectionState::Active(active); + return Poll::Pending; + } + Poll::Ready(Err(e)) => { + self.inner = ConnectionState::Cleanup(active.cleanup(e)); + continue; + } + }, + ConnectionState::Closing(mut inner) => match inner.poll_unpin(cx) { + Poll::Ready(Ok(())) => { + self.inner = ConnectionState::Closed; + return Poll::Ready(Err(ConnectionError::Closed)); + } + Poll::Ready(Err(e)) => { + self.inner = ConnectionState::Closed; + return Poll::Ready(Err(e)); + } + Poll::Pending => { + self.inner = ConnectionState::Closing(inner); + return Poll::Pending; + } + }, + ConnectionState::Cleanup(mut inner) => match inner.poll_unpin(cx) { + Poll::Ready(e) => { + self.inner = ConnectionState::Closed; + return Poll::Ready(Err(e)); + } + Poll::Pending => { + self.inner = ConnectionState::Cleanup(inner); + return Poll::Pending; + } + }, + ConnectionState::Closed => { + self.inner = ConnectionState::Closed; + return Poll::Ready(Err(ConnectionError::Closed)); + } + ConnectionState::Poisoned => unreachable!(), + } + } + } + + /// Poll for the next inbound stream. + /// + /// If this function returns `None`, the underlying connection is closed. + pub fn poll_next_inbound(&mut self, cx: &mut Context<'_>) -> Poll>> { + loop { + match std::mem::replace(&mut self.inner, ConnectionState::Poisoned) { + ConnectionState::Active(mut active) => match active.poll(cx) { + Poll::Ready(Ok(stream)) => { + self.inner = ConnectionState::Active(active); + return Poll::Ready(Some(Ok(stream))); + } + Poll::Ready(Err(e)) => { + self.inner = ConnectionState::Cleanup(active.cleanup(e)); + continue; + } + Poll::Pending => { + self.inner = ConnectionState::Active(active); + return Poll::Pending; + } + }, + ConnectionState::Closing(mut closing) => match closing.poll_unpin(cx) { + Poll::Ready(Ok(())) => { + self.inner = ConnectionState::Closed; + return Poll::Ready(None); + } + Poll::Ready(Err(e)) => { + self.inner = ConnectionState::Closed; + return Poll::Ready(Some(Err(e))); + } + Poll::Pending => { + self.inner = ConnectionState::Closing(closing); + return Poll::Pending; + } + }, + ConnectionState::Cleanup(mut cleanup) => match cleanup.poll_unpin(cx) { + Poll::Ready(ConnectionError::Closed) => { + self.inner = ConnectionState::Closed; + return Poll::Ready(None); + } + Poll::Ready(other) => { + self.inner = ConnectionState::Closed; + return Poll::Ready(Some(Err(other))); + } + Poll::Pending => { + self.inner = ConnectionState::Cleanup(cleanup); + return Poll::Pending; + } + }, + ConnectionState::Closed => { + self.inner = ConnectionState::Closed; + return Poll::Ready(None); + } + ConnectionState::Poisoned => unreachable!(), + } + } + } + + /// Close the connection. + pub fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll> { + loop { + match std::mem::replace(&mut self.inner, ConnectionState::Poisoned) { + ConnectionState::Active(active) => { + self.inner = ConnectionState::Closing(active.close()); + } + ConnectionState::Closing(mut inner) => match inner.poll_unpin(cx)? { + Poll::Ready(()) => { + self.inner = ConnectionState::Closed; + } + Poll::Pending => { + self.inner = ConnectionState::Closing(inner); + return Poll::Pending; + } + }, + ConnectionState::Cleanup(mut cleanup) => match cleanup.poll_unpin(cx) { + Poll::Ready(reason) => { + tracing::warn!(target: LOG_TARGET, "Failure while closing connection: {}", reason); + self.inner = ConnectionState::Closed; + return Poll::Ready(Ok(())); + } + Poll::Pending => { + self.inner = ConnectionState::Cleanup(cleanup); + return Poll::Pending; + } + }, + ConnectionState::Closed => { + self.inner = ConnectionState::Closed; + return Poll::Ready(Ok(())); + } + ConnectionState::Poisoned => { + unreachable!() + } + } + } + } } impl Drop for Connection { - fn drop(&mut self) { - match &mut self.inner { - ConnectionState::Active(active) => active.drop_all_streams(), - ConnectionState::Closing(_) => {}, - ConnectionState::Cleanup(_) => {}, - ConnectionState::Closed => {}, - ConnectionState::Poisoned => {}, - } - } + fn drop(&mut self) { + match &mut self.inner { + ConnectionState::Active(active) => active.drop_all_streams(), + ConnectionState::Closing(_) => {} + ConnectionState::Cleanup(_) => {} + ConnectionState::Closed => {} + ConnectionState::Poisoned => {} + } + } } enum ConnectionState { - /// The connection is alive and healthy. - Active(Active), - /// Our user requested to shutdown the connection, we are working on it. - Closing(Closing), - /// An error occurred and we are cleaning up our resources. - Cleanup(Cleanup), - /// The connection is closed. - Closed, - /// Something went wrong during our state transitions. Should never happen unless there is a - /// bug. - Poisoned, + /// The connection is alive and healthy. + Active(Active), + /// Our user requested to shutdown the connection, we are working on it. + Closing(Closing), + /// An error occurred and we are cleaning up our resources. + Cleanup(Cleanup), + /// The connection is closed. + Closed, + /// Something went wrong during our state transitions. Should never happen unless there is a + /// bug. + Poisoned, } impl fmt::Debug for ConnectionState { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - ConnectionState::Active(_) => write!(f, "Active"), - ConnectionState::Closing(_) => write!(f, "Closing"), - ConnectionState::Cleanup(_) => write!(f, "Cleanup"), - ConnectionState::Closed => write!(f, "Closed"), - ConnectionState::Poisoned => write!(f, "Poisoned"), - } - } + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + ConnectionState::Active(_) => write!(f, "Active"), + ConnectionState::Closing(_) => write!(f, "Closing"), + ConnectionState::Cleanup(_) => write!(f, "Cleanup"), + ConnectionState::Closed => write!(f, "Closed"), + ConnectionState::Poisoned => write!(f, "Poisoned"), + } + } } /// A Yamux connection object. @@ -354,622 +356,628 @@ impl fmt::Debug for ConnectionState { /// [`Connection::poll_next_inbound`] method which must be called repeatedly /// until `Ok(None)` signals EOF or an error is encountered. struct Active { - id: Id, - mode: Mode, - config: Arc, - socket: Fuse>, - next_id: u32, - - streams: IntMap>>, - stream_receivers: SelectAll>>, - no_streams_waker: Option, - - pending_frames: VecDeque>, - new_outbound_stream_waker: Option, + id: Id, + mode: Mode, + config: Arc, + socket: Fuse>, + next_id: u32, + + streams: IntMap>>, + stream_receivers: SelectAll>>, + no_streams_waker: Option, + + pending_frames: VecDeque>, + new_outbound_stream_waker: Option, } /// `Stream` to `Connection` commands. #[derive(Debug)] pub(crate) enum StreamCommand { - /// A new frame should be sent to the remote. - SendFrame(Frame>), - /// Close a stream. - CloseStream { ack: bool }, + /// A new frame should be sent to the remote. + SendFrame(Frame>), + /// Close a stream. + CloseStream { ack: bool }, } /// Possible actions as a result of incoming frame handling. #[derive(Debug)] enum Action { - /// Nothing to be done. - None, - /// A new stream has been opened by the remote. - New(Stream, Option>), - /// A window update should be sent to the remote. - Update(Frame), - /// A ping should be answered. - Ping(Frame), - /// A stream should be reset. - Reset(Frame), - /// The connection should be terminated. - Terminate(Frame), + /// Nothing to be done. + None, + /// A new stream has been opened by the remote. + New(Stream, Option>), + /// A window update should be sent to the remote. + Update(Frame), + /// A ping should be answered. + Ping(Frame), + /// A stream should be reset. + Reset(Frame), + /// The connection should be terminated. + Terminate(Frame), } impl fmt::Debug for Active { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.debug_struct("Connection") - .field("id", &self.id) - .field("mode", &self.mode) - .field("streams", &self.streams.len()) - .field("next_id", &self.next_id) - .finish() - } + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_struct("Connection") + .field("id", &self.id) + .field("mode", &self.mode) + .field("streams", &self.streams.len()) + .field("next_id", &self.next_id) + .finish() + } } impl fmt::Display for Active { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "(Connection {} {:?} (streams {}))", self.id, self.mode, self.streams.len()) - } + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "(Connection {} {:?} (streams {}))", + self.id, + self.mode, + self.streams.len() + ) + } } impl Active { - /// Create a new `Connection` from the given I/O resource. - fn new(socket: T, cfg: Config, mode: Mode) -> Self { - let id = Id::random(); - tracing::debug!(target: LOG_TARGET, "new connection: {} ({:?})", id, mode); - let socket = frame::Io::new(id, socket, cfg.max_buffer_size).fuse(); - Active { - id, - mode, - config: Arc::new(cfg), - socket, - streams: IntMap::default(), - stream_receivers: SelectAll::default(), - no_streams_waker: None, - next_id: match mode { - Mode::Client => 1, - Mode::Server => 2, - }, - pending_frames: VecDeque::default(), - new_outbound_stream_waker: None, - } - } - - /// Gracefully close the connection to the remote. - fn close(self) -> Closing { - Closing::new(self.stream_receivers, self.pending_frames, self.socket) - } - - /// Cleanup all our resources. - /// - /// This should be called in the context of an unrecoverable error on the connection. - fn cleanup(mut self, error: ConnectionError) -> Cleanup { - self.drop_all_streams(); - - Cleanup::new(self.stream_receivers, error) - } - - fn poll(&mut self, cx: &mut Context<'_>) -> Poll> { - loop { - if self.socket.poll_ready_unpin(cx).is_ready() { - if let Some(frame) = self.pending_frames.pop_front() { - self.socket.start_send_unpin(frame)?; - continue; - } - } - - match self.socket.poll_flush_unpin(cx)? { - Poll::Ready(()) => {}, - Poll::Pending => {}, - } - - match self.stream_receivers.poll_next_unpin(cx) { - Poll::Ready(Some((_, Some(StreamCommand::SendFrame(frame))))) => { - self.on_send_frame(frame.into()); - continue; - }, - Poll::Ready(Some((id, Some(StreamCommand::CloseStream { ack })))) => { - self.on_close_stream(id, ack); - continue; - }, - Poll::Ready(Some((id, None))) => { - self.on_drop_stream(id); - continue; - }, - Poll::Ready(None) => { - self.no_streams_waker = Some(cx.waker().clone()); - }, - Poll::Pending => {}, - } - - match self.socket.poll_next_unpin(cx) { - Poll::Ready(Some(frame)) => { - if let Some(stream) = self.on_frame(frame?)? { - return Poll::Ready(Ok(stream)); - } - continue; - }, - Poll::Ready(None) => { - return Poll::Ready(Err(ConnectionError::Closed)); - }, - Poll::Pending => {}, - } - - // If we make it this far, at least one of the above must have registered a waker. - return Poll::Pending; - } - } - - fn poll_new_outbound(&mut self, cx: &mut Context<'_>) -> Poll> { - if self.streams.len() >= self.config.max_num_streams { - tracing::error!(target: LOG_TARGET, "{}: maximum number of streams reached", self.id); - return Poll::Ready(Err(ConnectionError::TooManyStreams)); - } - - if self.ack_backlog() >= MAX_ACK_BACKLOG { - tracing::debug!(target: LOG_TARGET, "{MAX_ACK_BACKLOG} streams waiting for ACK, registering task for wake-up until remote acknowledges at least one stream"); - self.new_outbound_stream_waker = Some(cx.waker().clone()); - return Poll::Pending; - } - - tracing::trace!(target: LOG_TARGET, "{}: creating new outbound stream", self.id); - - let id = self.next_stream_id()?; - let extra_credit = self.config.receive_window - DEFAULT_CREDIT; - - if extra_credit > 0 { - let mut frame = Frame::window_update(id, extra_credit); - frame.header_mut().syn(); - tracing::trace!(target: LOG_TARGET, "{}/{}: sending initial {}", self.id, id, frame.header()); - self.pending_frames.push_back(frame.into()); - } - - let mut stream = self.make_new_outbound_stream(id, self.config.receive_window); - - if extra_credit == 0 { - stream.set_flag(stream::Flag::Syn) - } - - tracing::debug!(target: LOG_TARGET, "{}: new outbound {} of {}", self.id, stream, self); - self.streams.insert(id, stream.clone_shared()); - - Poll::Ready(Ok(stream)) - } - - fn on_send_frame(&mut self, frame: Frame>) { - tracing::trace!(target: LOG_TARGET, - "{}/{}: sending: {}", - self.id, - frame.header().stream_id(), - frame.header() - ); - self.pending_frames.push_back(frame.into()); - } - - fn on_close_stream(&mut self, id: StreamId, ack: bool) { - tracing::trace!(target: LOG_TARGET, "{}/{}: sending close", self.id, id); - self.pending_frames.push_back(Frame::close_stream(id, ack).into()); - } - - fn on_drop_stream(&mut self, stream_id: StreamId) { - let s = self.streams.remove(&stream_id).expect("stream not found"); - - tracing::trace!(target: LOG_TARGET, "{}: removing dropped stream {}", self.id, stream_id); - let frame = { - let mut shared = s.lock(); - let frame = match shared.update_state(self.id, stream_id, State::Closed) { - // The stream was dropped without calling `poll_close`. - // We reset the stream to inform the remote of the closure. - State::Open { .. } => { - let mut header = Header::data(stream_id, 0); - header.rst(); - Some(Frame::new(header)) - }, - // The stream was dropped without calling `poll_close`. - // We have already received a FIN from remote and send one - // back which closes the stream for good. - State::RecvClosed => { - let mut header = Header::data(stream_id, 0); - header.fin(); - Some(Frame::new(header)) - }, - // The stream was properly closed. We already sent our FIN frame. - // The remote may be out of credit though and blocked on - // writing more data. We may need to reset the stream. - State::SendClosed => { - if self.config.window_update_mode == WindowUpdateMode::OnRead && - shared.window == 0 - { - // The remote may be waiting for a window update - // which we will never send, so reset the stream now. - let mut header = Header::data(stream_id, 0); - header.rst(); - Some(Frame::new(header)) - } else { - // The remote has either still credit or will be given more - // (due to an enqueued window update or because the update - // mode is `OnReceive`) or we already have inbound frames in - // the socket buffer which will be processed later. In any - // case we will reply with an RST in `Connection::on_data` - // because the stream will no longer be known. - None - } - }, - // The stream was properly closed. We already have sent our FIN frame. The - // remote end has already done so in the past. - State::Closed => None, - }; - if let Some(w) = shared.reader.take() { - w.wake() - } - if let Some(w) = shared.writer.take() { - w.wake() - } - frame - }; - if let Some(f) = frame { - tracing::trace!(target: LOG_TARGET, "{}/{}: sending: {}", self.id, stream_id, f.header()); - self.pending_frames.push_back(f.into()); - } - } - - /// Process the result of reading from the socket. - /// - /// Unless `frame` is `Ok(Some(_))` we will assume the connection got closed - /// and return a corresponding error, which terminates the connection. - /// Otherwise we process the frame and potentially return a new `Stream` - /// if one was opened by the remote. - fn on_frame(&mut self, frame: Frame<()>) -> Result> { - tracing::trace!(target: LOG_TARGET, "{}: received: {}", self.id, frame.header()); - - if frame.header().flags().contains(header::ACK) { - let id = frame.header().stream_id(); - if let Some(stream) = self.streams.get(&id) { - stream.lock().update_state(self.id, id, State::Open { acknowledged: true }); - } - if let Some(waker) = self.new_outbound_stream_waker.take() { - waker.wake(); - } - } - - let action = match frame.header().tag() { - Tag::Data => self.on_data(frame.into_data()), - Tag::WindowUpdate => self.on_window_update(&frame.into_window_update()), - Tag::Ping => self.on_ping(&frame.into_ping()), - Tag::GoAway => return Err(ConnectionError::Closed), - }; - match action { - Action::None => {}, - Action::New(stream, update) => { - tracing::trace!(target: LOG_TARGET, "{}: new inbound {} of {}", self.id, stream, self); - if let Some(f) = update { - tracing::trace!(target: LOG_TARGET, "{}/{}: sending update", self.id, f.header().stream_id()); - self.pending_frames.push_back(f.into()); - } - return Ok(Some(stream)); - }, - Action::Update(f) => { - tracing::trace!(target: LOG_TARGET, "{}: sending update: {:?}", self.id, f.header()); - self.pending_frames.push_back(f.into()); - }, - Action::Ping(f) => { - tracing::trace!(target: LOG_TARGET, "{}/{}: pong", self.id, f.header().stream_id()); - self.pending_frames.push_back(f.into()); - }, - Action::Reset(f) => { - tracing::trace!(target: LOG_TARGET, "{}/{}: sending reset", self.id, f.header().stream_id()); - self.pending_frames.push_back(f.into()); - }, - Action::Terminate(f) => { - tracing::trace!(target: LOG_TARGET, "{}: sending term", self.id); - self.pending_frames.push_back(f.into()); - }, - } - - Ok(None) - } - - fn on_data(&mut self, frame: Frame) -> Action { - let stream_id = frame.header().stream_id(); - - if frame.header().flags().contains(header::RST) { - // stream reset - if let Some(s) = self.streams.get_mut(&stream_id) { - let mut shared = s.lock(); - shared.update_state(self.id, stream_id, State::Closed); - if let Some(w) = shared.reader.take() { - w.wake() - } - if let Some(w) = shared.writer.take() { - w.wake() - } - } - return Action::None; - } - - let is_finish = frame.header().flags().contains(header::FIN); // half-close - - if frame.header().flags().contains(header::SYN) { - // new stream - if !self.is_valid_remote_id(stream_id, Tag::Data) { - tracing::error!(target: LOG_TARGET, "{}: invalid stream id {}", self.id, stream_id); - return Action::Terminate(Frame::protocol_error()); - } - if frame.body().len() > DEFAULT_CREDIT as usize { - tracing::error!(target: LOG_TARGET, - "{}/{}: 1st body of stream exceeds default credit", - self.id, - stream_id - ); - return Action::Terminate(Frame::protocol_error()); - } - if self.streams.contains_key(&stream_id) { - tracing::error!(target: LOG_TARGET, "{}/{}: stream already exists", self.id, stream_id); - return Action::Terminate(Frame::protocol_error()); - } - if self.streams.len() == self.config.max_num_streams { - tracing::error!(target: LOG_TARGET, "{}: maximum number of streams reached", self.id); - return Action::Terminate(Frame::internal_error()); - } - let mut stream = self.make_new_inbound_stream(stream_id, DEFAULT_CREDIT); - let mut window_update = None; - { - let mut shared = stream.shared(); - if is_finish { - shared.update_state(self.id, stream_id, State::RecvClosed); - } - shared.window = shared.window.saturating_sub(frame.body_len()); - shared.buffer.push(frame.into_body()); - - if matches!(self.config.window_update_mode, WindowUpdateMode::OnReceive) { - if let Some(credit) = shared.next_window_update() { - shared.window += credit; - let mut frame = Frame::window_update(stream_id, credit); - frame.header_mut().ack(); - window_update = Some(frame) - } - } - } - if window_update.is_none() { - stream.set_flag(stream::Flag::Ack) - } - self.streams.insert(stream_id, stream.clone_shared()); - return Action::New(stream, window_update); - } - - if let Some(s) = self.streams.get_mut(&stream_id) { - let mut shared = s.lock(); - if frame.body().len() > shared.window as usize { - tracing::error!(target: LOG_TARGET, - "{}/{}: frame body larger than window of stream", - self.id, - stream_id - ); - return Action::Terminate(Frame::protocol_error()); - } - if is_finish { - shared.update_state(self.id, stream_id, State::RecvClosed); - } - let max_buffer_size = self.config.max_buffer_size; - if shared.buffer.len() >= max_buffer_size { - tracing::error!(target: LOG_TARGET, - "{}/{}: buffer of stream grows beyond limit", - self.id, - stream_id - ); - let mut header = Header::data(stream_id, 0); - header.rst(); - return Action::Reset(Frame::new(header)); - } - shared.window = shared.window.saturating_sub(frame.body_len()); - shared.buffer.push(frame.into_body()); - if let Some(w) = shared.reader.take() { - w.wake() - } - if matches!(self.config.window_update_mode, WindowUpdateMode::OnReceive) { - if let Some(credit) = shared.next_window_update() { - shared.window += credit; - let frame = Frame::window_update(stream_id, credit); - return Action::Update(frame); - } - } - } else { - tracing::trace!(target: LOG_TARGET, - "{}/{}: data frame for unknown stream, possibly dropped earlier: {:?}", - self.id, - stream_id, - frame - ); - // We do not consider this a protocol violation and thus do not send a stream reset - // because we may still be processing pending `StreamCommand`s of this stream that were - // sent before it has been dropped and "garbage collected". Such a stream reset would - // interfere with the frames that still need to be sent, causing premature stream - // termination for the remote. - // - // See https://github.com/paritytech/yamux/issues/110 for details. - } - - Action::None - } - - fn on_window_update(&mut self, frame: &Frame) -> Action { - let stream_id = frame.header().stream_id(); - - if frame.header().flags().contains(header::RST) { - // stream reset - if let Some(s) = self.streams.get_mut(&stream_id) { - let mut shared = s.lock(); - shared.update_state(self.id, stream_id, State::Closed); - if let Some(w) = shared.reader.take() { - w.wake() - } - if let Some(w) = shared.writer.take() { - w.wake() - } - } - return Action::None; - } - - let is_finish = frame.header().flags().contains(header::FIN); // half-close - - if frame.header().flags().contains(header::SYN) { - // new stream - if !self.is_valid_remote_id(stream_id, Tag::WindowUpdate) { - tracing::error!(target: LOG_TARGET, "{}: invalid stream id {}", self.id, stream_id); - return Action::Terminate(Frame::protocol_error()); - } - if self.streams.contains_key(&stream_id) { - tracing::error!(target: LOG_TARGET, "{}/{}: stream already exists", self.id, stream_id); - return Action::Terminate(Frame::protocol_error()); - } - if self.streams.len() == self.config.max_num_streams { - tracing::error!(target: LOG_TARGET, "{}: maximum number of streams reached", self.id); - return Action::Terminate(Frame::protocol_error()); - } - - let credit = frame.header().credit() + DEFAULT_CREDIT; - let mut stream = self.make_new_inbound_stream(stream_id, credit); - stream.set_flag(stream::Flag::Ack); - - if is_finish { - stream.shared().update_state(self.id, stream_id, State::RecvClosed); - } - self.streams.insert(stream_id, stream.clone_shared()); - return Action::New(stream, None); - } - - if let Some(s) = self.streams.get_mut(&stream_id) { - let mut shared = s.lock(); - shared.credit += frame.header().credit(); - if is_finish { - shared.update_state(self.id, stream_id, State::RecvClosed); - } - if let Some(w) = shared.writer.take() { - w.wake() - } - } else { - tracing::trace!(target: LOG_TARGET, - "{}/{}: window update for unknown stream, possibly dropped earlier: {:?}", - self.id, - stream_id, - frame - ); - // We do not consider this a protocol violation and thus do not send a stream reset - // because we may still be processing pending `StreamCommand`s of this stream that were - // sent before it has been dropped and "garbage collected". Such a stream reset would - // interfere with the frames that still need to be sent, causing premature stream - // termination for the remote. - // - // See https://github.com/paritytech/yamux/issues/110 for details. - } - - Action::None - } - - fn on_ping(&mut self, frame: &Frame) -> Action { - let stream_id = frame.header().stream_id(); - if frame.header().flags().contains(header::ACK) { - // pong - return Action::None; - } - if stream_id == CONNECTION_ID || self.streams.contains_key(&stream_id) { - let mut hdr = Header::ping(frame.header().nonce()); - hdr.ack(); - return Action::Ping(Frame::new(hdr)); - } - tracing::trace!(target: LOG_TARGET, - "{}/{}: ping for unknown stream, possibly dropped earlier: {:?}", - self.id, - stream_id, - frame - ); - // We do not consider this a protocol violation and thus do not send a stream reset because - // we may still be processing pending `StreamCommand`s of this stream that were sent before - // it has been dropped and "garbage collected". Such a stream reset would interfere with the - // frames that still need to be sent, causing premature stream termination for the remote. - // - // See https://github.com/paritytech/yamux/issues/110 for details. - - Action::None - } - - fn make_new_inbound_stream(&mut self, id: StreamId, credit: u32) -> Stream { - let config = self.config.clone(); - - let (sender, receiver) = mpsc::channel(10); // 10 is an arbitrary number. - self.stream_receivers.push(TaggedStream::new(id, receiver)); - if let Some(waker) = self.no_streams_waker.take() { - waker.wake(); - } - - Stream::new_inbound(id, self.id, config, credit, sender) - } - - fn make_new_outbound_stream(&mut self, id: StreamId, window: u32) -> Stream { - let config = self.config.clone(); - - let (sender, receiver) = mpsc::channel(10); // 10 is an arbitrary number. - self.stream_receivers.push(TaggedStream::new(id, receiver)); - if let Some(waker) = self.no_streams_waker.take() { - waker.wake(); - } - - Stream::new_outbound(id, self.id, config, window, sender) - } - - fn next_stream_id(&mut self) -> Result { - let proposed = StreamId::new(self.next_id); - self.next_id = self.next_id.checked_add(2).ok_or(ConnectionError::NoMoreStreamIds)?; - match self.mode { - Mode::Client => assert!(proposed.is_client()), - Mode::Server => assert!(proposed.is_server()), - } - Ok(proposed) - } - - /// The ACK backlog is defined as the number of outbound streams that have not yet been - /// acknowledged. - fn ack_backlog(&mut self) -> usize { - self.streams - .iter() - // Whether this is an outbound stream. - // - // Clients use odd IDs and servers use even IDs. - // A stream is outbound if: - // - // - Its ID is odd and we are the client. - // - Its ID is even and we are the server. - .filter(|(id, _)| match self.mode { - Mode::Client => id.is_client(), - Mode::Server => id.is_server(), - }) - .filter(|(_, s)| s.lock().is_pending_ack()) - .count() - } - - // Check if the given stream ID is valid w.r.t. the provided tag and our connection mode. - fn is_valid_remote_id(&self, id: StreamId, tag: Tag) -> bool { - if tag == Tag::Ping || tag == Tag::GoAway { - return id.is_session(); - } - match self.mode { - Mode::Client => id.is_server(), - Mode::Server => id.is_client(), - } - } + /// Create a new `Connection` from the given I/O resource. + fn new(socket: T, cfg: Config, mode: Mode) -> Self { + let id = Id::random(); + tracing::debug!(target: LOG_TARGET, "new connection: {} ({:?})", id, mode); + let socket = frame::Io::new(id, socket, cfg.max_buffer_size).fuse(); + Active { + id, + mode, + config: Arc::new(cfg), + socket, + streams: IntMap::default(), + stream_receivers: SelectAll::default(), + no_streams_waker: None, + next_id: match mode { + Mode::Client => 1, + Mode::Server => 2, + }, + pending_frames: VecDeque::default(), + new_outbound_stream_waker: None, + } + } + + /// Gracefully close the connection to the remote. + fn close(self) -> Closing { + Closing::new(self.stream_receivers, self.pending_frames, self.socket) + } + + /// Cleanup all our resources. + /// + /// This should be called in the context of an unrecoverable error on the connection. + fn cleanup(mut self, error: ConnectionError) -> Cleanup { + self.drop_all_streams(); + + Cleanup::new(self.stream_receivers, error) + } + + fn poll(&mut self, cx: &mut Context<'_>) -> Poll> { + loop { + if self.socket.poll_ready_unpin(cx).is_ready() { + if let Some(frame) = self.pending_frames.pop_front() { + self.socket.start_send_unpin(frame)?; + continue; + } + } + + match self.socket.poll_flush_unpin(cx)? { + Poll::Ready(()) => {} + Poll::Pending => {} + } + + match self.stream_receivers.poll_next_unpin(cx) { + Poll::Ready(Some((_, Some(StreamCommand::SendFrame(frame))))) => { + self.on_send_frame(frame.into()); + continue; + } + Poll::Ready(Some((id, Some(StreamCommand::CloseStream { ack })))) => { + self.on_close_stream(id, ack); + continue; + } + Poll::Ready(Some((id, None))) => { + self.on_drop_stream(id); + continue; + } + Poll::Ready(None) => { + self.no_streams_waker = Some(cx.waker().clone()); + } + Poll::Pending => {} + } + + match self.socket.poll_next_unpin(cx) { + Poll::Ready(Some(frame)) => { + if let Some(stream) = self.on_frame(frame?)? { + return Poll::Ready(Ok(stream)); + } + continue; + } + Poll::Ready(None) => { + return Poll::Ready(Err(ConnectionError::Closed)); + } + Poll::Pending => {} + } + + // If we make it this far, at least one of the above must have registered a waker. + return Poll::Pending; + } + } + + fn poll_new_outbound(&mut self, cx: &mut Context<'_>) -> Poll> { + if self.streams.len() >= self.config.max_num_streams { + tracing::error!(target: LOG_TARGET, "{}: maximum number of streams reached", self.id); + return Poll::Ready(Err(ConnectionError::TooManyStreams)); + } + + if self.ack_backlog() >= MAX_ACK_BACKLOG { + tracing::debug!(target: LOG_TARGET, "{MAX_ACK_BACKLOG} streams waiting for ACK, registering task for wake-up until remote acknowledges at least one stream"); + self.new_outbound_stream_waker = Some(cx.waker().clone()); + return Poll::Pending; + } + + tracing::trace!(target: LOG_TARGET, "{}: creating new outbound stream", self.id); + + let id = self.next_stream_id()?; + let extra_credit = self.config.receive_window - DEFAULT_CREDIT; + + if extra_credit > 0 { + let mut frame = Frame::window_update(id, extra_credit); + frame.header_mut().syn(); + tracing::trace!(target: LOG_TARGET, "{}/{}: sending initial {}", self.id, id, frame.header()); + self.pending_frames.push_back(frame.into()); + } + + let mut stream = self.make_new_outbound_stream(id, self.config.receive_window); + + if extra_credit == 0 { + stream.set_flag(stream::Flag::Syn) + } + + tracing::debug!(target: LOG_TARGET, "{}: new outbound {} of {}", self.id, stream, self); + self.streams.insert(id, stream.clone_shared()); + + Poll::Ready(Ok(stream)) + } + + fn on_send_frame(&mut self, frame: Frame>) { + tracing::trace!(target: LOG_TARGET, + "{}/{}: sending: {}", + self.id, + frame.header().stream_id(), + frame.header() + ); + self.pending_frames.push_back(frame.into()); + } + + fn on_close_stream(&mut self, id: StreamId, ack: bool) { + tracing::trace!(target: LOG_TARGET, "{}/{}: sending close", self.id, id); + self.pending_frames.push_back(Frame::close_stream(id, ack).into()); + } + + fn on_drop_stream(&mut self, stream_id: StreamId) { + let s = self.streams.remove(&stream_id).expect("stream not found"); + + tracing::trace!(target: LOG_TARGET, "{}: removing dropped stream {}", self.id, stream_id); + let frame = { + let mut shared = s.lock(); + let frame = match shared.update_state(self.id, stream_id, State::Closed) { + // The stream was dropped without calling `poll_close`. + // We reset the stream to inform the remote of the closure. + State::Open { .. } => { + let mut header = Header::data(stream_id, 0); + header.rst(); + Some(Frame::new(header)) + } + // The stream was dropped without calling `poll_close`. + // We have already received a FIN from remote and send one + // back which closes the stream for good. + State::RecvClosed => { + let mut header = Header::data(stream_id, 0); + header.fin(); + Some(Frame::new(header)) + } + // The stream was properly closed. We already sent our FIN frame. + // The remote may be out of credit though and blocked on + // writing more data. We may need to reset the stream. + State::SendClosed => { + if self.config.window_update_mode == WindowUpdateMode::OnRead + && shared.window == 0 + { + // The remote may be waiting for a window update + // which we will never send, so reset the stream now. + let mut header = Header::data(stream_id, 0); + header.rst(); + Some(Frame::new(header)) + } else { + // The remote has either still credit or will be given more + // (due to an enqueued window update or because the update + // mode is `OnReceive`) or we already have inbound frames in + // the socket buffer which will be processed later. In any + // case we will reply with an RST in `Connection::on_data` + // because the stream will no longer be known. + None + } + } + // The stream was properly closed. We already have sent our FIN frame. The + // remote end has already done so in the past. + State::Closed => None, + }; + if let Some(w) = shared.reader.take() { + w.wake() + } + if let Some(w) = shared.writer.take() { + w.wake() + } + frame + }; + if let Some(f) = frame { + tracing::trace!(target: LOG_TARGET, "{}/{}: sending: {}", self.id, stream_id, f.header()); + self.pending_frames.push_back(f.into()); + } + } + + /// Process the result of reading from the socket. + /// + /// Unless `frame` is `Ok(Some(_))` we will assume the connection got closed + /// and return a corresponding error, which terminates the connection. + /// Otherwise we process the frame and potentially return a new `Stream` + /// if one was opened by the remote. + fn on_frame(&mut self, frame: Frame<()>) -> Result> { + tracing::trace!(target: LOG_TARGET, "{}: received: {}", self.id, frame.header()); + + if frame.header().flags().contains(header::ACK) { + let id = frame.header().stream_id(); + if let Some(stream) = self.streams.get(&id) { + stream.lock().update_state(self.id, id, State::Open { acknowledged: true }); + } + if let Some(waker) = self.new_outbound_stream_waker.take() { + waker.wake(); + } + } + + let action = match frame.header().tag() { + Tag::Data => self.on_data(frame.into_data()), + Tag::WindowUpdate => self.on_window_update(&frame.into_window_update()), + Tag::Ping => self.on_ping(&frame.into_ping()), + Tag::GoAway => return Err(ConnectionError::Closed), + }; + match action { + Action::None => {} + Action::New(stream, update) => { + tracing::trace!(target: LOG_TARGET, "{}: new inbound {} of {}", self.id, stream, self); + if let Some(f) = update { + tracing::trace!(target: LOG_TARGET, "{}/{}: sending update", self.id, f.header().stream_id()); + self.pending_frames.push_back(f.into()); + } + return Ok(Some(stream)); + } + Action::Update(f) => { + tracing::trace!(target: LOG_TARGET, "{}: sending update: {:?}", self.id, f.header()); + self.pending_frames.push_back(f.into()); + } + Action::Ping(f) => { + tracing::trace!(target: LOG_TARGET, "{}/{}: pong", self.id, f.header().stream_id()); + self.pending_frames.push_back(f.into()); + } + Action::Reset(f) => { + tracing::trace!(target: LOG_TARGET, "{}/{}: sending reset", self.id, f.header().stream_id()); + self.pending_frames.push_back(f.into()); + } + Action::Terminate(f) => { + tracing::trace!(target: LOG_TARGET, "{}: sending term", self.id); + self.pending_frames.push_back(f.into()); + } + } + + Ok(None) + } + + fn on_data(&mut self, frame: Frame) -> Action { + let stream_id = frame.header().stream_id(); + + if frame.header().flags().contains(header::RST) { + // stream reset + if let Some(s) = self.streams.get_mut(&stream_id) { + let mut shared = s.lock(); + shared.update_state(self.id, stream_id, State::Closed); + if let Some(w) = shared.reader.take() { + w.wake() + } + if let Some(w) = shared.writer.take() { + w.wake() + } + } + return Action::None; + } + + let is_finish = frame.header().flags().contains(header::FIN); // half-close + + if frame.header().flags().contains(header::SYN) { + // new stream + if !self.is_valid_remote_id(stream_id, Tag::Data) { + tracing::error!(target: LOG_TARGET, "{}: invalid stream id {}", self.id, stream_id); + return Action::Terminate(Frame::protocol_error()); + } + if frame.body().len() > DEFAULT_CREDIT as usize { + tracing::error!(target: LOG_TARGET, + "{}/{}: 1st body of stream exceeds default credit", + self.id, + stream_id + ); + return Action::Terminate(Frame::protocol_error()); + } + if self.streams.contains_key(&stream_id) { + tracing::error!(target: LOG_TARGET, "{}/{}: stream already exists", self.id, stream_id); + return Action::Terminate(Frame::protocol_error()); + } + if self.streams.len() == self.config.max_num_streams { + tracing::error!(target: LOG_TARGET, "{}: maximum number of streams reached", self.id); + return Action::Terminate(Frame::internal_error()); + } + let mut stream = self.make_new_inbound_stream(stream_id, DEFAULT_CREDIT); + let mut window_update = None; + { + let mut shared = stream.shared(); + if is_finish { + shared.update_state(self.id, stream_id, State::RecvClosed); + } + shared.window = shared.window.saturating_sub(frame.body_len()); + shared.buffer.push(frame.into_body()); + + if matches!(self.config.window_update_mode, WindowUpdateMode::OnReceive) { + if let Some(credit) = shared.next_window_update() { + shared.window += credit; + let mut frame = Frame::window_update(stream_id, credit); + frame.header_mut().ack(); + window_update = Some(frame) + } + } + } + if window_update.is_none() { + stream.set_flag(stream::Flag::Ack) + } + self.streams.insert(stream_id, stream.clone_shared()); + return Action::New(stream, window_update); + } + + if let Some(s) = self.streams.get_mut(&stream_id) { + let mut shared = s.lock(); + if frame.body().len() > shared.window as usize { + tracing::error!(target: LOG_TARGET, + "{}/{}: frame body larger than window of stream", + self.id, + stream_id + ); + return Action::Terminate(Frame::protocol_error()); + } + if is_finish { + shared.update_state(self.id, stream_id, State::RecvClosed); + } + let max_buffer_size = self.config.max_buffer_size; + if shared.buffer.len() >= max_buffer_size { + tracing::error!(target: LOG_TARGET, + "{}/{}: buffer of stream grows beyond limit", + self.id, + stream_id + ); + let mut header = Header::data(stream_id, 0); + header.rst(); + return Action::Reset(Frame::new(header)); + } + shared.window = shared.window.saturating_sub(frame.body_len()); + shared.buffer.push(frame.into_body()); + if let Some(w) = shared.reader.take() { + w.wake() + } + if matches!(self.config.window_update_mode, WindowUpdateMode::OnReceive) { + if let Some(credit) = shared.next_window_update() { + shared.window += credit; + let frame = Frame::window_update(stream_id, credit); + return Action::Update(frame); + } + } + } else { + tracing::trace!(target: LOG_TARGET, + "{}/{}: data frame for unknown stream, possibly dropped earlier: {:?}", + self.id, + stream_id, + frame + ); + // We do not consider this a protocol violation and thus do not send a stream reset + // because we may still be processing pending `StreamCommand`s of this stream that were + // sent before it has been dropped and "garbage collected". Such a stream reset would + // interfere with the frames that still need to be sent, causing premature stream + // termination for the remote. + // + // See https://github.com/paritytech/yamux/issues/110 for details. + } + + Action::None + } + + fn on_window_update(&mut self, frame: &Frame) -> Action { + let stream_id = frame.header().stream_id(); + + if frame.header().flags().contains(header::RST) { + // stream reset + if let Some(s) = self.streams.get_mut(&stream_id) { + let mut shared = s.lock(); + shared.update_state(self.id, stream_id, State::Closed); + if let Some(w) = shared.reader.take() { + w.wake() + } + if let Some(w) = shared.writer.take() { + w.wake() + } + } + return Action::None; + } + + let is_finish = frame.header().flags().contains(header::FIN); // half-close + + if frame.header().flags().contains(header::SYN) { + // new stream + if !self.is_valid_remote_id(stream_id, Tag::WindowUpdate) { + tracing::error!(target: LOG_TARGET, "{}: invalid stream id {}", self.id, stream_id); + return Action::Terminate(Frame::protocol_error()); + } + if self.streams.contains_key(&stream_id) { + tracing::error!(target: LOG_TARGET, "{}/{}: stream already exists", self.id, stream_id); + return Action::Terminate(Frame::protocol_error()); + } + if self.streams.len() == self.config.max_num_streams { + tracing::error!(target: LOG_TARGET, "{}: maximum number of streams reached", self.id); + return Action::Terminate(Frame::protocol_error()); + } + + let credit = frame.header().credit() + DEFAULT_CREDIT; + let mut stream = self.make_new_inbound_stream(stream_id, credit); + stream.set_flag(stream::Flag::Ack); + + if is_finish { + stream.shared().update_state(self.id, stream_id, State::RecvClosed); + } + self.streams.insert(stream_id, stream.clone_shared()); + return Action::New(stream, None); + } + + if let Some(s) = self.streams.get_mut(&stream_id) { + let mut shared = s.lock(); + shared.credit += frame.header().credit(); + if is_finish { + shared.update_state(self.id, stream_id, State::RecvClosed); + } + if let Some(w) = shared.writer.take() { + w.wake() + } + } else { + tracing::trace!(target: LOG_TARGET, + "{}/{}: window update for unknown stream, possibly dropped earlier: {:?}", + self.id, + stream_id, + frame + ); + // We do not consider this a protocol violation and thus do not send a stream reset + // because we may still be processing pending `StreamCommand`s of this stream that were + // sent before it has been dropped and "garbage collected". Such a stream reset would + // interfere with the frames that still need to be sent, causing premature stream + // termination for the remote. + // + // See https://github.com/paritytech/yamux/issues/110 for details. + } + + Action::None + } + + fn on_ping(&mut self, frame: &Frame) -> Action { + let stream_id = frame.header().stream_id(); + if frame.header().flags().contains(header::ACK) { + // pong + return Action::None; + } + if stream_id == CONNECTION_ID || self.streams.contains_key(&stream_id) { + let mut hdr = Header::ping(frame.header().nonce()); + hdr.ack(); + return Action::Ping(Frame::new(hdr)); + } + tracing::trace!(target: LOG_TARGET, + "{}/{}: ping for unknown stream, possibly dropped earlier: {:?}", + self.id, + stream_id, + frame + ); + // We do not consider this a protocol violation and thus do not send a stream reset because + // we may still be processing pending `StreamCommand`s of this stream that were sent before + // it has been dropped and "garbage collected". Such a stream reset would interfere with the + // frames that still need to be sent, causing premature stream termination for the remote. + // + // See https://github.com/paritytech/yamux/issues/110 for details. + + Action::None + } + + fn make_new_inbound_stream(&mut self, id: StreamId, credit: u32) -> Stream { + let config = self.config.clone(); + + let (sender, receiver) = mpsc::channel(10); // 10 is an arbitrary number. + self.stream_receivers.push(TaggedStream::new(id, receiver)); + if let Some(waker) = self.no_streams_waker.take() { + waker.wake(); + } + + Stream::new_inbound(id, self.id, config, credit, sender) + } + + fn make_new_outbound_stream(&mut self, id: StreamId, window: u32) -> Stream { + let config = self.config.clone(); + + let (sender, receiver) = mpsc::channel(10); // 10 is an arbitrary number. + self.stream_receivers.push(TaggedStream::new(id, receiver)); + if let Some(waker) = self.no_streams_waker.take() { + waker.wake(); + } + + Stream::new_outbound(id, self.id, config, window, sender) + } + + fn next_stream_id(&mut self) -> Result { + let proposed = StreamId::new(self.next_id); + self.next_id = self.next_id.checked_add(2).ok_or(ConnectionError::NoMoreStreamIds)?; + match self.mode { + Mode::Client => assert!(proposed.is_client()), + Mode::Server => assert!(proposed.is_server()), + } + Ok(proposed) + } + + /// The ACK backlog is defined as the number of outbound streams that have not yet been + /// acknowledged. + fn ack_backlog(&mut self) -> usize { + self.streams + .iter() + // Whether this is an outbound stream. + // + // Clients use odd IDs and servers use even IDs. + // A stream is outbound if: + // + // - Its ID is odd and we are the client. + // - Its ID is even and we are the server. + .filter(|(id, _)| match self.mode { + Mode::Client => id.is_client(), + Mode::Server => id.is_server(), + }) + .filter(|(_, s)| s.lock().is_pending_ack()) + .count() + } + + // Check if the given stream ID is valid w.r.t. the provided tag and our connection mode. + fn is_valid_remote_id(&self, id: StreamId, tag: Tag) -> bool { + if tag == Tag::Ping || tag == Tag::GoAway { + return id.is_session(); + } + match self.mode { + Mode::Client => id.is_server(), + Mode::Server => id.is_client(), + } + } } impl Active { - /// Close and drop all `Stream`s and wake any pending `Waker`s. - fn drop_all_streams(&mut self) { - for (id, s) in self.streams.drain() { - let mut shared = s.lock(); - shared.update_state(self.id, id, State::Closed); - if let Some(w) = shared.reader.take() { - w.wake() - } - if let Some(w) = shared.writer.take() { - w.wake() - } - } - } + /// Close and drop all `Stream`s and wake any pending `Waker`s. + fn drop_all_streams(&mut self) { + for (id, s) in self.streams.drain() { + let mut shared = s.lock(); + shared.update_state(self.id, id, State::Closed); + if let Some(w) = shared.reader.take() { + w.wake() + } + if let Some(w) = shared.writer.take() { + w.wake() + } + } + } } diff --git a/src/yamux/connection/cleanup.rs b/src/yamux/connection/cleanup.rs index 0f39bb8a..d1be682c 100644 --- a/src/yamux/connection/cleanup.rs +++ b/src/yamux/connection/cleanup.rs @@ -1,60 +1,64 @@ use crate::yamux::{ - connection::StreamCommand, tagged_stream::TaggedStream, ConnectionError, StreamId, + connection::StreamCommand, tagged_stream::TaggedStream, ConnectionError, StreamId, }; use futures::{channel::mpsc, stream::SelectAll, StreamExt}; use std::{ - future::Future, - pin::Pin, - task::{Context, Poll}, + future::Future, + pin::Pin, + task::{Context, Poll}, }; /// A [`Future`] that cleans up resources in case of an error. #[must_use] pub struct Cleanup { - state: State, - stream_receivers: SelectAll>>, - error: Option, + state: State, + stream_receivers: SelectAll>>, + error: Option, } impl Cleanup { - pub(crate) fn new( - stream_receivers: SelectAll>>, - error: ConnectionError, - ) -> Self { - Self { state: State::ClosingStreamReceiver, stream_receivers, error: Some(error) } - } + pub(crate) fn new( + stream_receivers: SelectAll>>, + error: ConnectionError, + ) -> Self { + Self { + state: State::ClosingStreamReceiver, + stream_receivers, + error: Some(error), + } + } } impl Future for Cleanup { - type Output = ConnectionError; + type Output = ConnectionError; - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.get_mut(); + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.get_mut(); - loop { - match this.state { - State::ClosingStreamReceiver => { - for stream in this.stream_receivers.iter_mut() { - stream.inner_mut().close(); - } - this.state = State::DrainingStreamReceiver; - }, - State::DrainingStreamReceiver => match this.stream_receivers.poll_next_unpin(cx) { - Poll::Ready(Some(cmd)) => { - drop(cmd); - }, - Poll::Ready(None) | Poll::Pending => - return Poll::Ready( - this.error.take().expect("to not be called after completion"), - ), - }, - } - } - } + loop { + match this.state { + State::ClosingStreamReceiver => { + for stream in this.stream_receivers.iter_mut() { + stream.inner_mut().close(); + } + this.state = State::DrainingStreamReceiver; + } + State::DrainingStreamReceiver => match this.stream_receivers.poll_next_unpin(cx) { + Poll::Ready(Some(cmd)) => { + drop(cmd); + } + Poll::Ready(None) | Poll::Pending => + return Poll::Ready( + this.error.take().expect("to not be called after completion"), + ), + }, + } + } + } } #[allow(clippy::enum_variant_names)] enum State { - ClosingStreamReceiver, - DrainingStreamReceiver, + ClosingStreamReceiver, + DrainingStreamReceiver, } diff --git a/src/yamux/connection/closing.rs b/src/yamux/connection/closing.rs index 6c8f34df..6ae541e4 100644 --- a/src/yamux/connection/closing.rs +++ b/src/yamux/connection/closing.rs @@ -1,96 +1,101 @@ use crate::yamux::{ - connection::StreamCommand, frame, frame::Frame, tagged_stream::TaggedStream, Result, StreamId, + connection::StreamCommand, frame, frame::Frame, tagged_stream::TaggedStream, Result, StreamId, }; use futures::{ - channel::mpsc, - ready, - stream::{Fuse, SelectAll}, - AsyncRead, AsyncWrite, SinkExt, StreamExt, + channel::mpsc, + ready, + stream::{Fuse, SelectAll}, + AsyncRead, AsyncWrite, SinkExt, StreamExt, }; use std::{ - collections::VecDeque, - future::Future, - pin::Pin, - task::{Context, Poll}, + collections::VecDeque, + future::Future, + pin::Pin, + task::{Context, Poll}, }; /// A [`Future`] that gracefully closes the yamux connection. #[must_use] pub struct Closing { - state: State, - stream_receivers: SelectAll>>, - pending_frames: VecDeque>, - socket: Fuse>, + state: State, + stream_receivers: SelectAll>>, + pending_frames: VecDeque>, + socket: Fuse>, } impl Closing where - T: AsyncRead + AsyncWrite + Unpin, + T: AsyncRead + AsyncWrite + Unpin, { - pub(crate) fn new( - stream_receivers: SelectAll>>, - pending_frames: VecDeque>, - socket: Fuse>, - ) -> Self { - Self { state: State::ClosingStreamReceiver, stream_receivers, pending_frames, socket } - } + pub(crate) fn new( + stream_receivers: SelectAll>>, + pending_frames: VecDeque>, + socket: Fuse>, + ) -> Self { + Self { + state: State::ClosingStreamReceiver, + stream_receivers, + pending_frames, + socket, + } + } } impl Future for Closing where - T: AsyncRead + AsyncWrite + Unpin, + T: AsyncRead + AsyncWrite + Unpin, { - type Output = Result<()>; + type Output = Result<()>; - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.get_mut(); + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.get_mut(); - loop { - match this.state { - State::ClosingStreamReceiver => { - for stream in this.stream_receivers.iter_mut() { - stream.inner_mut().close(); - } - this.state = State::DrainingStreamReceiver; - }, + loop { + match this.state { + State::ClosingStreamReceiver => { + for stream in this.stream_receivers.iter_mut() { + stream.inner_mut().close(); + } + this.state = State::DrainingStreamReceiver; + } - State::DrainingStreamReceiver => { - match this.stream_receivers.poll_next_unpin(cx) { - Poll::Ready(Some((_, Some(StreamCommand::SendFrame(frame))))) => - this.pending_frames.push_back(frame.into()), - Poll::Ready(Some((id, Some(StreamCommand::CloseStream { ack })))) => { - this.pending_frames.push_back(Frame::close_stream(id, ack).into()); - }, - Poll::Ready(Some((_, None))) => {}, - Poll::Pending | Poll::Ready(None) => { - // No more frames from streams, append `Term` frame and flush them all. - this.pending_frames.push_back(Frame::term().into()); - this.state = State::FlushingPendingFrames; - continue; - }, - } - }, - State::FlushingPendingFrames => { - ready!(this.socket.poll_ready_unpin(cx))?; + State::DrainingStreamReceiver => { + match this.stream_receivers.poll_next_unpin(cx) { + Poll::Ready(Some((_, Some(StreamCommand::SendFrame(frame))))) => + this.pending_frames.push_back(frame.into()), + Poll::Ready(Some((id, Some(StreamCommand::CloseStream { ack })))) => { + this.pending_frames.push_back(Frame::close_stream(id, ack).into()); + } + Poll::Ready(Some((_, None))) => {} + Poll::Pending | Poll::Ready(None) => { + // No more frames from streams, append `Term` frame and flush them all. + this.pending_frames.push_back(Frame::term().into()); + this.state = State::FlushingPendingFrames; + continue; + } + } + } + State::FlushingPendingFrames => { + ready!(this.socket.poll_ready_unpin(cx))?; - match this.pending_frames.pop_front() { - Some(frame) => this.socket.start_send_unpin(frame)?, - None => this.state = State::ClosingSocket, - } - }, - State::ClosingSocket => { - ready!(this.socket.poll_close_unpin(cx))?; + match this.pending_frames.pop_front() { + Some(frame) => this.socket.start_send_unpin(frame)?, + None => this.state = State::ClosingSocket, + } + } + State::ClosingSocket => { + ready!(this.socket.poll_close_unpin(cx))?; - return Poll::Ready(Ok(())); - }, - } - } - } + return Poll::Ready(Ok(())); + } + } + } + } } enum State { - ClosingStreamReceiver, - DrainingStreamReceiver, - FlushingPendingFrames, - ClosingSocket, + ClosingStreamReceiver, + DrainingStreamReceiver, + FlushingPendingFrames, + ClosingSocket, } diff --git a/src/yamux/connection/stream.rs b/src/yamux/connection/stream.rs index 8aad3946..03e3fa39 100644 --- a/src/yamux/connection/stream.rs +++ b/src/yamux/connection/stream.rs @@ -9,27 +9,27 @@ // at https://opensource.org/licenses/MIT. use crate::yamux::{ - chunks::Chunks, - connection::{self, StreamCommand}, - frame::{ - header::{Data, Header, StreamId, WindowUpdate, ACK}, - Frame, - }, - Config, WindowUpdateMode, DEFAULT_CREDIT, + chunks::Chunks, + connection::{self, StreamCommand}, + frame::{ + header::{Data, Header, StreamId, WindowUpdate, ACK}, + Frame, + }, + Config, WindowUpdateMode, DEFAULT_CREDIT, }; use futures::{ - channel::mpsc, - future::Either, - io::{AsyncRead, AsyncWrite}, - ready, SinkExt, + channel::mpsc, + future::Either, + io::{AsyncRead, AsyncWrite}, + ready, SinkExt, }; use parking_lot::{Mutex, MutexGuard}; use std::{ - convert::TryInto, - fmt, io, - pin::Pin, - sync::Arc, - task::{Context, Poll, Waker}, + convert::TryInto, + fmt, io, + pin::Pin, + sync::Arc, + task::{Context, Poll, Waker}, }; /// Logging target for the file. @@ -38,48 +38,48 @@ const LOG_TARGET: &str = "litep2p::yamux"; /// The state of a Yamux stream. #[derive(Copy, Clone, Debug, PartialEq, Eq)] pub enum State { - /// Open bidirectionally. - Open { - /// Whether the stream is acknowledged. - /// - /// For outbound streams, this tracks whether the remote has acknowledged our stream. - /// For inbound streams, this tracks whether we have acknowledged the stream to the remote. - /// - /// This starts out with `false` and is set to `true` when we receive or send an `ACK` flag - /// for this stream. We may also directly transition: - /// - from `Open` to `RecvClosed` if the remote immediately sends `FIN`. - /// - from `Open` to `Closed` if the remote immediately sends `RST`. - acknowledged: bool, - }, - /// Open for incoming messages. - SendClosed, - /// Open for outgoing messages. - RecvClosed, - /// Closed (terminal state). - Closed, + /// Open bidirectionally. + Open { + /// Whether the stream is acknowledged. + /// + /// For outbound streams, this tracks whether the remote has acknowledged our stream. + /// For inbound streams, this tracks whether we have acknowledged the stream to the remote. + /// + /// This starts out with `false` and is set to `true` when we receive or send an `ACK` flag + /// for this stream. We may also directly transition: + /// - from `Open` to `RecvClosed` if the remote immediately sends `FIN`. + /// - from `Open` to `Closed` if the remote immediately sends `RST`. + acknowledged: bool, + }, + /// Open for incoming messages. + SendClosed, + /// Open for outgoing messages. + RecvClosed, + /// Closed (terminal state). + Closed, } impl State { - /// Can we receive messages over this stream? - pub fn can_read(self) -> bool { - !matches!(self, State::RecvClosed | State::Closed) - } - - /// Can we send messages over this stream? - pub fn can_write(self) -> bool { - !matches!(self, State::SendClosed | State::Closed) - } + /// Can we receive messages over this stream? + pub fn can_read(self) -> bool { + !matches!(self, State::RecvClosed | State::Closed) + } + + /// Can we send messages over this stream? + pub fn can_write(self) -> bool { + !matches!(self, State::SendClosed | State::Closed) + } } /// Indicate if a flag still needs to be set on an outbound header. #[derive(Copy, Clone, Debug, PartialEq, Eq)] pub(crate) enum Flag { - /// No flag needs to be set. - None, - /// The stream was opened lazily, so set the initial SYN flag. - Syn, - /// The stream still needs acknowledgement, so set the ACK flag. - Ack, + /// No flag needs to be set. + None, + /// The stream was opened lazily, so set the initial SYN flag. + Syn, + /// The stream still needs acknowledgement, so set the ACK flag. + Ack, } /// A multiplexed Yamux stream. @@ -90,140 +90,140 @@ pub(crate) enum Flag { /// [`Stream`] implements [`AsyncRead`] and [`AsyncWrite`] and also /// [`futures::stream::Stream`]. pub struct Stream { - id: StreamId, - conn: connection::Id, - config: Arc, - sender: mpsc::Sender, - flag: Flag, - shared: Arc>, + id: StreamId, + conn: connection::Id, + config: Arc, + sender: mpsc::Sender, + flag: Flag, + shared: Arc>, } impl fmt::Debug for Stream { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.debug_struct("Stream") - .field("id", &self.id.val()) - .field("connection", &self.conn) - .finish() - } + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_struct("Stream") + .field("id", &self.id.val()) + .field("connection", &self.conn) + .finish() + } } impl fmt::Display for Stream { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "(Stream {}/{})", self.conn, self.id.val()) - } + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "(Stream {}/{})", self.conn, self.id.val()) + } } impl Stream { - pub(crate) fn new_inbound( - id: StreamId, - conn: connection::Id, - config: Arc, - credit: u32, - sender: mpsc::Sender, - ) -> Self { - Self { - id, - conn, - config: config.clone(), - sender, - flag: Flag::None, - shared: Arc::new(Mutex::new(Shared::new(DEFAULT_CREDIT, credit, config))), - } - } - - pub(crate) fn new_outbound( - id: StreamId, - conn: connection::Id, - config: Arc, - window: u32, - sender: mpsc::Sender, - ) -> Self { - Self { - id, - conn, - config: config.clone(), - sender, - flag: Flag::None, - shared: Arc::new(Mutex::new(Shared::new(window, DEFAULT_CREDIT, config))), - } - } - - /// Get this stream's identifier. - pub fn id(&self) -> StreamId { - self.id - } - - pub fn is_write_closed(&self) -> bool { - matches!(self.shared().state(), State::SendClosed) - } - - pub fn is_closed(&self) -> bool { - matches!(self.shared().state(), State::Closed) - } - - /// Whether we are still waiting for the remote to acknowledge this stream. - pub fn is_pending_ack(&self) -> bool { - self.shared().is_pending_ack() - } - - /// Set the flag that should be set on the next outbound frame header. - pub(crate) fn set_flag(&mut self, flag: Flag) { - self.flag = flag - } - - pub(crate) fn shared(&self) -> MutexGuard<'_, Shared> { - self.shared.lock() - } - - pub(crate) fn clone_shared(&self) -> Arc> { - self.shared.clone() - } - - fn write_zero_err(&self) -> io::Error { - let msg = format!("{}/{}: connection is closed", self.conn, self.id); - io::Error::new(io::ErrorKind::WriteZero, msg) - } - - /// Set ACK or SYN flag if necessary. - fn add_flag(&mut self, header: &mut Header>) { - match self.flag { - Flag::None => (), - Flag::Syn => { - header.syn(); - self.flag = Flag::None - }, - Flag::Ack => { - header.ack(); - self.flag = Flag::None - }, - } - } - - /// Send new credit to the sending side via a window update message if - /// permitted. - fn send_window_update(&mut self, cx: &mut Context) -> Poll> { - // When using [`WindowUpdateMode::OnReceive`] window update messages are - // send early on data receival (see [`crate::Connection::on_frame`]). - if matches!(self.config.window_update_mode, WindowUpdateMode::OnReceive) { - return Poll::Ready(Ok(())); - } - - let mut shared = self.shared.lock(); - - if let Some(credit) = shared.next_window_update() { - ready!(self.sender.poll_ready(cx).map_err(|_| self.write_zero_err())?); - - shared.window += credit; - drop(shared); - - let mut frame = Frame::window_update(self.id, credit).right(); - self.add_flag(frame.header_mut()); - let cmd = StreamCommand::SendFrame(frame); - self.sender.start_send(cmd).map_err(|_| self.write_zero_err())?; - } - - Poll::Ready(Ok(())) - } + pub(crate) fn new_inbound( + id: StreamId, + conn: connection::Id, + config: Arc, + credit: u32, + sender: mpsc::Sender, + ) -> Self { + Self { + id, + conn, + config: config.clone(), + sender, + flag: Flag::None, + shared: Arc::new(Mutex::new(Shared::new(DEFAULT_CREDIT, credit, config))), + } + } + + pub(crate) fn new_outbound( + id: StreamId, + conn: connection::Id, + config: Arc, + window: u32, + sender: mpsc::Sender, + ) -> Self { + Self { + id, + conn, + config: config.clone(), + sender, + flag: Flag::None, + shared: Arc::new(Mutex::new(Shared::new(window, DEFAULT_CREDIT, config))), + } + } + + /// Get this stream's identifier. + pub fn id(&self) -> StreamId { + self.id + } + + pub fn is_write_closed(&self) -> bool { + matches!(self.shared().state(), State::SendClosed) + } + + pub fn is_closed(&self) -> bool { + matches!(self.shared().state(), State::Closed) + } + + /// Whether we are still waiting for the remote to acknowledge this stream. + pub fn is_pending_ack(&self) -> bool { + self.shared().is_pending_ack() + } + + /// Set the flag that should be set on the next outbound frame header. + pub(crate) fn set_flag(&mut self, flag: Flag) { + self.flag = flag + } + + pub(crate) fn shared(&self) -> MutexGuard<'_, Shared> { + self.shared.lock() + } + + pub(crate) fn clone_shared(&self) -> Arc> { + self.shared.clone() + } + + fn write_zero_err(&self) -> io::Error { + let msg = format!("{}/{}: connection is closed", self.conn, self.id); + io::Error::new(io::ErrorKind::WriteZero, msg) + } + + /// Set ACK or SYN flag if necessary. + fn add_flag(&mut self, header: &mut Header>) { + match self.flag { + Flag::None => (), + Flag::Syn => { + header.syn(); + self.flag = Flag::None + } + Flag::Ack => { + header.ack(); + self.flag = Flag::None + } + } + } + + /// Send new credit to the sending side via a window update message if + /// permitted. + fn send_window_update(&mut self, cx: &mut Context) -> Poll> { + // When using [`WindowUpdateMode::OnReceive`] window update messages are + // send early on data receival (see [`crate::Connection::on_frame`]). + if matches!(self.config.window_update_mode, WindowUpdateMode::OnReceive) { + return Poll::Ready(Ok(())); + } + + let mut shared = self.shared.lock(); + + if let Some(credit) = shared.next_window_update() { + ready!(self.sender.poll_ready(cx).map_err(|_| self.write_zero_err())?); + + shared.window += credit; + drop(shared); + + let mut frame = Frame::window_update(self.id, credit).right(); + self.add_flag(frame.header_mut()); + let cmd = StreamCommand::SendFrame(frame); + self.sender.start_send(cmd).map_err(|_| self.write_zero_err())?; + } + + Poll::Ready(Ok(())) + } } /// Byte data produced by the [`futures::stream::Stream`] impl of [`Stream`]. @@ -231,287 +231,294 @@ impl Stream { pub struct Packet(Vec); impl AsRef<[u8]> for Packet { - fn as_ref(&self) -> &[u8] { - self.0.as_ref() - } + fn as_ref(&self) -> &[u8] { + self.0.as_ref() + } } impl futures::stream::Stream for Stream { - type Item = io::Result; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - if !self.config.read_after_close && self.sender.is_closed() { - return Poll::Ready(None); - } - - match self.send_window_update(cx) { - Poll::Ready(Ok(())) => {}, - Poll::Ready(Err(e)) => return Poll::Ready(Some(Err(e))), - // Continue reading buffered data even though sending a window update blocked. - Poll::Pending => {}, - } - - let mut shared = self.shared(); - - if let Some(bytes) = shared.buffer.pop() { - let off = bytes.offset(); - let mut vec = bytes.into_vec(); - if off != 0 { - // This should generally not happen when the stream is used only as - // a `futures::stream::Stream` since the whole point of this impl is - // to consume chunks atomically. It may perhaps happen when mixing - // this impl and the `AsyncRead` one. - tracing::debug!( - target: LOG_TARGET, - "{}/{}: chunk has been partially consumed", - self.conn, - self.id - ); - vec = vec.split_off(off) - } - return Poll::Ready(Some(Ok(Packet(vec)))); - } - - // Buffer is empty, let's check if we can expect to read more data. - if !shared.state().can_read() { - tracing::debug!(target: LOG_TARGET, "{}/{}: eof", self.conn, self.id); - return Poll::Ready(None); // stream has been reset - } - - // Since we have no more data at this point, we want to be woken up - // by the connection when more becomes available for us. - shared.reader = Some(cx.waker().clone()); - - Poll::Pending - } + type Item = io::Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + if !self.config.read_after_close && self.sender.is_closed() { + return Poll::Ready(None); + } + + match self.send_window_update(cx) { + Poll::Ready(Ok(())) => {} + Poll::Ready(Err(e)) => return Poll::Ready(Some(Err(e))), + // Continue reading buffered data even though sending a window update blocked. + Poll::Pending => {} + } + + let mut shared = self.shared(); + + if let Some(bytes) = shared.buffer.pop() { + let off = bytes.offset(); + let mut vec = bytes.into_vec(); + if off != 0 { + // This should generally not happen when the stream is used only as + // a `futures::stream::Stream` since the whole point of this impl is + // to consume chunks atomically. It may perhaps happen when mixing + // this impl and the `AsyncRead` one. + tracing::debug!( + target: LOG_TARGET, + "{}/{}: chunk has been partially consumed", + self.conn, + self.id + ); + vec = vec.split_off(off) + } + return Poll::Ready(Some(Ok(Packet(vec)))); + } + + // Buffer is empty, let's check if we can expect to read more data. + if !shared.state().can_read() { + tracing::debug!(target: LOG_TARGET, "{}/{}: eof", self.conn, self.id); + return Poll::Ready(None); // stream has been reset + } + + // Since we have no more data at this point, we want to be woken up + // by the connection when more becomes available for us. + shared.reader = Some(cx.waker().clone()); + + Poll::Pending + } } // Like the `futures::stream::Stream` impl above, but copies bytes into the // provided mutable slice. impl AsyncRead for Stream { - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut Context, - buf: &mut [u8], - ) -> Poll> { - if !self.config.read_after_close && self.sender.is_closed() { - return Poll::Ready(Ok(0)); - } - - match self.send_window_update(cx) { - Poll::Ready(Ok(())) => {}, - Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), - // Continue reading buffered data even though sending a window update blocked. - Poll::Pending => {}, - } - - // Copy data from stream buffer. - let mut shared = self.shared(); - let mut n = 0; - while let Some(chunk) = shared.buffer.front_mut() { - if chunk.is_empty() { - shared.buffer.pop(); - continue; - } - let k = std::cmp::min(chunk.len(), buf.len() - n); - buf[n..n + k].copy_from_slice(&chunk.as_ref()[..k]); - n += k; - chunk.advance(k); - if n == buf.len() { - break; - } - } - - if n > 0 { - tracing::trace!(target: LOG_TARGET,"{}/{}: read {} bytes", self.conn, self.id, n); - return Poll::Ready(Ok(n)); - } - - // Buffer is empty, let's check if we can expect to read more data. - if !shared.state().can_read() { - tracing::debug!(target: LOG_TARGET,"{}/{}: eof", self.conn, self.id); - return Poll::Ready(Ok(0)); // stream has been reset - } - - // Since we have no more data at this point, we want to be woken up - // by the connection when more becomes available for us. - shared.reader = Some(cx.waker().clone()); - - Poll::Pending - } + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context, + buf: &mut [u8], + ) -> Poll> { + if !self.config.read_after_close && self.sender.is_closed() { + return Poll::Ready(Ok(0)); + } + + match self.send_window_update(cx) { + Poll::Ready(Ok(())) => {} + Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), + // Continue reading buffered data even though sending a window update blocked. + Poll::Pending => {} + } + + // Copy data from stream buffer. + let mut shared = self.shared(); + let mut n = 0; + while let Some(chunk) = shared.buffer.front_mut() { + if chunk.is_empty() { + shared.buffer.pop(); + continue; + } + let k = std::cmp::min(chunk.len(), buf.len() - n); + buf[n..n + k].copy_from_slice(&chunk.as_ref()[..k]); + n += k; + chunk.advance(k); + if n == buf.len() { + break; + } + } + + if n > 0 { + tracing::trace!(target: LOG_TARGET,"{}/{}: read {} bytes", self.conn, self.id, n); + return Poll::Ready(Ok(n)); + } + + // Buffer is empty, let's check if we can expect to read more data. + if !shared.state().can_read() { + tracing::debug!(target: LOG_TARGET,"{}/{}: eof", self.conn, self.id); + return Poll::Ready(Ok(0)); // stream has been reset + } + + // Since we have no more data at this point, we want to be woken up + // by the connection when more becomes available for us. + shared.reader = Some(cx.waker().clone()); + + Poll::Pending + } } impl AsyncWrite for Stream { - fn poll_write( - mut self: Pin<&mut Self>, - cx: &mut Context, - buf: &[u8], - ) -> Poll> { - ready!(self.sender.poll_ready(cx).map_err(|_| self.write_zero_err())?); - let body = { - let mut shared = self.shared(); - if !shared.state().can_write() { - tracing::debug!(target: LOG_TARGET,"{}/{}: can no longer write", self.conn, self.id); - return Poll::Ready(Err(self.write_zero_err())); - } - if shared.credit == 0 { - tracing::trace!(target: LOG_TARGET,"{}/{}: no more credit left", self.conn, self.id); - shared.writer = Some(cx.waker().clone()); - return Poll::Pending; - } - let k = std::cmp::min(shared.credit as usize, buf.len()); - let k = std::cmp::min(k, self.config.split_send_size); - shared.credit = shared.credit.saturating_sub(k as u32); - Vec::from(&buf[..k]) - }; - let n = body.len(); - let mut frame = Frame::data(self.id, body).expect("body <= u32::MAX").left(); - self.add_flag(frame.header_mut()); - tracing::trace!(target: LOG_TARGET,"{}/{}: write {} bytes", self.conn, self.id, n); - - // technically, the frame hasn't been sent yet on the wire but from the perspective of this - // data structure, we've queued the frame for sending We are tracking this - // information: a) to be consistent with outbound streams - // b) to correctly test our behaviour around timing of when ACKs are sent. See - // `ack_timing.rs` test. - if frame.header().flags().contains(ACK) { - self.shared() - .update_state(self.conn, self.id, State::Open { acknowledged: true }); - } - - let cmd = StreamCommand::SendFrame(frame); - self.sender.start_send(cmd).map_err(|_| self.write_zero_err())?; - Poll::Ready(Ok(n)) - } - - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - self.sender.poll_flush_unpin(cx).map_err(|_| self.write_zero_err()) - } - - fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - if self.is_closed() { - return Poll::Ready(Ok(())); - } - ready!(self.sender.poll_ready(cx).map_err(|_| self.write_zero_err())?); - let ack = if self.flag == Flag::Ack { - self.flag = Flag::None; - true - } else { - false - }; - tracing::trace!(target: LOG_TARGET,"{}/{}: close", self.conn, self.id); - let cmd = StreamCommand::CloseStream { ack }; - self.sender.start_send(cmd).map_err(|_| self.write_zero_err())?; - self.shared().update_state(self.conn, self.id, State::SendClosed); - Poll::Ready(Ok(())) - } + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context, + buf: &[u8], + ) -> Poll> { + ready!(self.sender.poll_ready(cx).map_err(|_| self.write_zero_err())?); + let body = { + let mut shared = self.shared(); + if !shared.state().can_write() { + tracing::debug!(target: LOG_TARGET,"{}/{}: can no longer write", self.conn, self.id); + return Poll::Ready(Err(self.write_zero_err())); + } + if shared.credit == 0 { + tracing::trace!(target: LOG_TARGET,"{}/{}: no more credit left", self.conn, self.id); + shared.writer = Some(cx.waker().clone()); + return Poll::Pending; + } + let k = std::cmp::min(shared.credit as usize, buf.len()); + let k = std::cmp::min(k, self.config.split_send_size); + shared.credit = shared.credit.saturating_sub(k as u32); + Vec::from(&buf[..k]) + }; + let n = body.len(); + let mut frame = Frame::data(self.id, body).expect("body <= u32::MAX").left(); + self.add_flag(frame.header_mut()); + tracing::trace!(target: LOG_TARGET,"{}/{}: write {} bytes", self.conn, self.id, n); + + // technically, the frame hasn't been sent yet on the wire but from the perspective of this + // data structure, we've queued the frame for sending We are tracking this + // information: a) to be consistent with outbound streams + // b) to correctly test our behaviour around timing of when ACKs are sent. See + // `ack_timing.rs` test. + if frame.header().flags().contains(ACK) { + self.shared() + .update_state(self.conn, self.id, State::Open { acknowledged: true }); + } + + let cmd = StreamCommand::SendFrame(frame); + self.sender.start_send(cmd).map_err(|_| self.write_zero_err())?; + Poll::Ready(Ok(n)) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + self.sender.poll_flush_unpin(cx).map_err(|_| self.write_zero_err()) + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + if self.is_closed() { + return Poll::Ready(Ok(())); + } + ready!(self.sender.poll_ready(cx).map_err(|_| self.write_zero_err())?); + let ack = if self.flag == Flag::Ack { + self.flag = Flag::None; + true + } else { + false + }; + tracing::trace!(target: LOG_TARGET,"{}/{}: close", self.conn, self.id); + let cmd = StreamCommand::CloseStream { ack }; + self.sender.start_send(cmd).map_err(|_| self.write_zero_err())?; + self.shared().update_state(self.conn, self.id, State::SendClosed); + Poll::Ready(Ok(())) + } } #[derive(Debug)] pub(crate) struct Shared { - state: State, - pub(crate) window: u32, - pub(crate) credit: u32, - pub(crate) buffer: Chunks, - pub(crate) reader: Option, - pub(crate) writer: Option, - config: Arc, + state: State, + pub(crate) window: u32, + pub(crate) credit: u32, + pub(crate) buffer: Chunks, + pub(crate) reader: Option, + pub(crate) writer: Option, + config: Arc, } impl Shared { - fn new(window: u32, credit: u32, config: Arc) -> Self { - Shared { - state: State::Open { acknowledged: false }, - window, - credit, - buffer: Chunks::new(), - reader: None, - writer: None, - config, - } - } - - pub(crate) fn state(&self) -> State { - self.state - } - - /// Update the stream state and return the state before it was updated. - pub(crate) fn update_state( - &mut self, - cid: connection::Id, - sid: StreamId, - next: State, - ) -> State { - use self::State::*; - - let current = self.state; - - match (current, next) { - (Closed, _) => {}, - (Open { .. }, _) => self.state = next, - (RecvClosed, Closed) => self.state = Closed, - (RecvClosed, Open { .. }) => {}, - (RecvClosed, RecvClosed) => {}, - (RecvClosed, SendClosed) => self.state = Closed, - (SendClosed, Closed) => self.state = Closed, - (SendClosed, Open { .. }) => {}, - (SendClosed, RecvClosed) => self.state = Closed, - (SendClosed, SendClosed) => {}, - } - - tracing::trace!(target: LOG_TARGET, - "{}/{}: update state: (from {:?} to {:?} -> {:?})", - cid, - sid, - current, - next, - self.state - ); - - current // Return the previous stream state for informational purposes. - } - - /// Calculate the number of additional window bytes the receiving side - /// should grant the sending side via a window update message. - /// - /// Returns `None` if too small to justify a window update message. - /// - /// Note: Once a caller successfully sent a window update message, the - /// locally tracked window size needs to be updated manually by the caller. - pub(crate) fn next_window_update(&mut self) -> Option { - if !self.state.can_read() { - return None; - } - - let new_credit = match self.config.window_update_mode { - WindowUpdateMode::OnReceive => { - debug_assert!(self.config.receive_window >= self.window); - - self.config.receive_window.saturating_sub(self.window) - }, - WindowUpdateMode::OnRead => { - debug_assert!(self.config.receive_window >= self.window); - let bytes_received = self.config.receive_window.saturating_sub(self.window); - let buffer_len: u32 = self.buffer.len().try_into().unwrap_or(std::u32::MAX); - - bytes_received.saturating_sub(buffer_len) - }, - }; - - // Send WindowUpdate message when half or more of the configured receive - // window can be granted as additional credit to the sender. - // - // See https://github.com/paritytech/yamux/issues/100 for a detailed - // discussion. - if new_credit >= self.config.receive_window / 2 { - Some(new_credit) - } else { - None - } - } - - /// Whether we are still waiting for the remote to acknowledge this stream. - pub fn is_pending_ack(&self) -> bool { - matches!(self.state(), State::Open { acknowledged: false }) - } + fn new(window: u32, credit: u32, config: Arc) -> Self { + Shared { + state: State::Open { + acknowledged: false, + }, + window, + credit, + buffer: Chunks::new(), + reader: None, + writer: None, + config, + } + } + + pub(crate) fn state(&self) -> State { + self.state + } + + /// Update the stream state and return the state before it was updated. + pub(crate) fn update_state( + &mut self, + cid: connection::Id, + sid: StreamId, + next: State, + ) -> State { + use self::State::*; + + let current = self.state; + + match (current, next) { + (Closed, _) => {} + (Open { .. }, _) => self.state = next, + (RecvClosed, Closed) => self.state = Closed, + (RecvClosed, Open { .. }) => {} + (RecvClosed, RecvClosed) => {} + (RecvClosed, SendClosed) => self.state = Closed, + (SendClosed, Closed) => self.state = Closed, + (SendClosed, Open { .. }) => {} + (SendClosed, RecvClosed) => self.state = Closed, + (SendClosed, SendClosed) => {} + } + + tracing::trace!(target: LOG_TARGET, + "{}/{}: update state: (from {:?} to {:?} -> {:?})", + cid, + sid, + current, + next, + self.state + ); + + current // Return the previous stream state for informational purposes. + } + + /// Calculate the number of additional window bytes the receiving side + /// should grant the sending side via a window update message. + /// + /// Returns `None` if too small to justify a window update message. + /// + /// Note: Once a caller successfully sent a window update message, the + /// locally tracked window size needs to be updated manually by the caller. + pub(crate) fn next_window_update(&mut self) -> Option { + if !self.state.can_read() { + return None; + } + + let new_credit = match self.config.window_update_mode { + WindowUpdateMode::OnReceive => { + debug_assert!(self.config.receive_window >= self.window); + + self.config.receive_window.saturating_sub(self.window) + } + WindowUpdateMode::OnRead => { + debug_assert!(self.config.receive_window >= self.window); + let bytes_received = self.config.receive_window.saturating_sub(self.window); + let buffer_len: u32 = self.buffer.len().try_into().unwrap_or(std::u32::MAX); + + bytes_received.saturating_sub(buffer_len) + } + }; + + // Send WindowUpdate message when half or more of the configured receive + // window can be granted as additional credit to the sender. + // + // See https://github.com/paritytech/yamux/issues/100 for a detailed + // discussion. + if new_credit >= self.config.receive_window / 2 { + Some(new_credit) + } else { + None + } + } + + /// Whether we are still waiting for the remote to acknowledge this stream. + pub fn is_pending_ack(&self) -> bool { + matches!( + self.state(), + State::Open { + acknowledged: false + } + ) + } } diff --git a/src/yamux/control.rs b/src/yamux/control.rs index bb6ba9c5..89e3c350 100644 --- a/src/yamux/control.rs +++ b/src/yamux/control.rs @@ -10,12 +10,12 @@ use crate::yamux::{error::ConnectionError, Connection, Result, Stream, MAX_ACK_BACKLOG}; use futures::{ - channel::{mpsc, oneshot}, - prelude::*, + channel::{mpsc, oneshot}, + prelude::*, }; use std::{ - pin::Pin, - task::{Context, Poll}, + pin::Pin, + task::{Context, Poll}, }; /// A Yamux [`Connection`] controller. @@ -26,192 +26,198 @@ use std::{ /// a [`Control`] to be cloned and shared between tasks and threads. #[derive(Clone, Debug)] pub struct Control { - /// Command channel to [`ControlledConnection`]. - sender: mpsc::Sender, + /// Command channel to [`ControlledConnection`]. + sender: mpsc::Sender, } impl Control { - pub fn new(connection: Connection) -> (Self, ControlledConnection) { - let (sender, receiver) = mpsc::channel(MAX_ACK_BACKLOG); - - let control = Control { sender }; - let connection = - ControlledConnection { state: State::Idle(connection), commands: receiver }; - - (control, connection) - } - - /// Open a new stream to the remote. - pub async fn open_stream(&mut self) -> Result { - let (tx, rx) = oneshot::channel(); - self.sender.send(ControlCommand::OpenStream(tx)).await?; - rx.await? - } - - /// Close the connection. - pub async fn close(&mut self) -> Result<()> { - let (tx, rx) = oneshot::channel(); - if self.sender.send(ControlCommand::CloseConnection(tx)).await.is_err() { - // The receiver is closed which means the connection is already closed. - return Ok(()); - } - // A dropped `oneshot::Sender` means the `Connection` is gone, - // so we do not treat receive errors differently here. - let _ = rx.await; - Ok(()) - } + pub fn new(connection: Connection) -> (Self, ControlledConnection) { + let (sender, receiver) = mpsc::channel(MAX_ACK_BACKLOG); + + let control = Control { sender }; + let connection = ControlledConnection { + state: State::Idle(connection), + commands: receiver, + }; + + (control, connection) + } + + /// Open a new stream to the remote. + pub async fn open_stream(&mut self) -> Result { + let (tx, rx) = oneshot::channel(); + self.sender.send(ControlCommand::OpenStream(tx)).await?; + rx.await? + } + + /// Close the connection. + pub async fn close(&mut self) -> Result<()> { + let (tx, rx) = oneshot::channel(); + if self.sender.send(ControlCommand::CloseConnection(tx)).await.is_err() { + // The receiver is closed which means the connection is already closed. + return Ok(()); + } + // A dropped `oneshot::Sender` means the `Connection` is gone, + // so we do not treat receive errors differently here. + let _ = rx.await; + Ok(()) + } } /// Wraps a [`Connection`] which can be controlled with a [`Control`]. pub struct ControlledConnection { - state: State, - commands: mpsc::Receiver, + state: State, + commands: mpsc::Receiver, } impl ControlledConnection where - T: AsyncRead + AsyncWrite + Unpin + Send + 'static, + T: AsyncRead + AsyncWrite + Unpin + Send + 'static, { - fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll>> { - loop { - match std::mem::replace(&mut self.state, State::Poisoned) { - State::Idle(mut connection) => { - match connection.poll_next_inbound(cx) { - Poll::Ready(maybe_stream) => { - self.state = State::Idle(connection); - return Poll::Ready(maybe_stream); - }, - Poll::Pending => {}, - } - - match self.commands.poll_next_unpin(cx) { - Poll::Ready(Some(ControlCommand::OpenStream(reply))) => { - self.state = State::OpeningNewStream { reply, connection }; - continue; - }, - Poll::Ready(Some(ControlCommand::CloseConnection(reply))) => { - self.commands.close(); - - self.state = State::Closing { - reply: Some(reply), - inner: Closing::DrainingControlCommands { connection }, - }; - continue; - }, - Poll::Ready(None) => { - // Last `Control` sender was dropped, close te connection. - self.state = State::Closing { - reply: None, - inner: Closing::ClosingConnection { connection }, - }; - continue; - }, - Poll::Pending => {}, - } - - self.state = State::Idle(connection); - return Poll::Pending; - }, - State::OpeningNewStream { reply, mut connection } => - match connection.poll_new_outbound(cx) { - Poll::Ready(stream) => { - let _ = reply.send(stream); - - self.state = State::Idle(connection); - continue; - }, - Poll::Pending => { - self.state = State::OpeningNewStream { reply, connection }; - return Poll::Pending; - }, - }, - State::Closing { - reply, - inner: Closing::DrainingControlCommands { connection }, - } => match self.commands.poll_next_unpin(cx) { - Poll::Ready(Some(ControlCommand::OpenStream(new_reply))) => { - let _ = new_reply.send(Err(ConnectionError::Closed)); - - self.state = State::Closing { - reply, - inner: Closing::DrainingControlCommands { connection }, - }; - continue; - }, - Poll::Ready(Some(ControlCommand::CloseConnection(new_reply))) => { - let _ = new_reply.send(()); - - self.state = State::Closing { - reply, - inner: Closing::DrainingControlCommands { connection }, - }; - continue; - }, - Poll::Ready(None) => { - self.state = State::Closing { - reply, - inner: Closing::ClosingConnection { connection }, - }; - continue; - }, - Poll::Pending => { - self.state = State::Closing { - reply, - inner: Closing::DrainingControlCommands { connection }, - }; - return Poll::Pending; - }, - }, - State::Closing { reply, inner: Closing::ClosingConnection { mut connection } } => - match connection.poll_close(cx) { - Poll::Ready(Ok(())) | Poll::Ready(Err(ConnectionError::Closed)) => { - if let Some(reply) = reply { - let _ = reply.send(()); - } - return Poll::Ready(None); - }, - Poll::Ready(Err(other)) => { - if let Some(reply) = reply { - let _ = reply.send(()); - } - return Poll::Ready(Some(Err(other))); - }, - Poll::Pending => { - self.state = State::Closing { - reply, - inner: Closing::ClosingConnection { connection }, - }; - return Poll::Pending; - }, - }, - State::Poisoned => unreachable!(), - } - } - } + fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll>> { + loop { + match std::mem::replace(&mut self.state, State::Poisoned) { + State::Idle(mut connection) => { + match connection.poll_next_inbound(cx) { + Poll::Ready(maybe_stream) => { + self.state = State::Idle(connection); + return Poll::Ready(maybe_stream); + } + Poll::Pending => {} + } + + match self.commands.poll_next_unpin(cx) { + Poll::Ready(Some(ControlCommand::OpenStream(reply))) => { + self.state = State::OpeningNewStream { reply, connection }; + continue; + } + Poll::Ready(Some(ControlCommand::CloseConnection(reply))) => { + self.commands.close(); + + self.state = State::Closing { + reply: Some(reply), + inner: Closing::DrainingControlCommands { connection }, + }; + continue; + } + Poll::Ready(None) => { + // Last `Control` sender was dropped, close te connection. + self.state = State::Closing { + reply: None, + inner: Closing::ClosingConnection { connection }, + }; + continue; + } + Poll::Pending => {} + } + + self.state = State::Idle(connection); + return Poll::Pending; + } + State::OpeningNewStream { + reply, + mut connection, + } => match connection.poll_new_outbound(cx) { + Poll::Ready(stream) => { + let _ = reply.send(stream); + + self.state = State::Idle(connection); + continue; + } + Poll::Pending => { + self.state = State::OpeningNewStream { reply, connection }; + return Poll::Pending; + } + }, + State::Closing { + reply, + inner: Closing::DrainingControlCommands { connection }, + } => match self.commands.poll_next_unpin(cx) { + Poll::Ready(Some(ControlCommand::OpenStream(new_reply))) => { + let _ = new_reply.send(Err(ConnectionError::Closed)); + + self.state = State::Closing { + reply, + inner: Closing::DrainingControlCommands { connection }, + }; + continue; + } + Poll::Ready(Some(ControlCommand::CloseConnection(new_reply))) => { + let _ = new_reply.send(()); + + self.state = State::Closing { + reply, + inner: Closing::DrainingControlCommands { connection }, + }; + continue; + } + Poll::Ready(None) => { + self.state = State::Closing { + reply, + inner: Closing::ClosingConnection { connection }, + }; + continue; + } + Poll::Pending => { + self.state = State::Closing { + reply, + inner: Closing::DrainingControlCommands { connection }, + }; + return Poll::Pending; + } + }, + State::Closing { + reply, + inner: Closing::ClosingConnection { mut connection }, + } => match connection.poll_close(cx) { + Poll::Ready(Ok(())) | Poll::Ready(Err(ConnectionError::Closed)) => { + if let Some(reply) = reply { + let _ = reply.send(()); + } + return Poll::Ready(None); + } + Poll::Ready(Err(other)) => { + if let Some(reply) = reply { + let _ = reply.send(()); + } + return Poll::Ready(Some(Err(other))); + } + Poll::Pending => { + self.state = State::Closing { + reply, + inner: Closing::ClosingConnection { connection }, + }; + return Poll::Pending; + } + }, + State::Poisoned => unreachable!(), + } + } + } } #[derive(Debug)] enum ControlCommand { - /// Open a new stream to the remote end. - OpenStream(oneshot::Sender>), - /// Close the whole connection. - CloseConnection(oneshot::Sender<()>), + /// Open a new stream to the remote end. + OpenStream(oneshot::Sender>), + /// Close the whole connection. + CloseConnection(oneshot::Sender<()>), } /// The state of a [`ControlledConnection`]. enum State { - Idle(Connection), - OpeningNewStream { - reply: oneshot::Sender>, - connection: Connection, - }, - Closing { - /// A channel to the [`Control`] in case the close was requested. `None` if we are closing - /// because the last [`Control`] was dropped. - reply: Option>, - inner: Closing, - }, - Poisoned, + Idle(Connection), + OpeningNewStream { + reply: oneshot::Sender>, + connection: Connection, + }, + Closing { + /// A channel to the [`Control`] in case the close was requested. `None` if we are closing + /// because the last [`Control`] was dropped. + reply: Option>, + inner: Closing, + }, + Poisoned, } /// A sub-state of our larger state machine for a [`ControlledConnection`]. @@ -221,17 +227,17 @@ enum State { /// 1. Draining and answered all remaining [`Closing::DrainingControlCommands`]. /// 1. Closing the underlying [`Connection`]. enum Closing { - DrainingControlCommands { connection: Connection }, - ClosingConnection { connection: Connection }, + DrainingControlCommands { connection: Connection }, + ClosingConnection { connection: Connection }, } impl futures::Stream for ControlledConnection where - T: AsyncRead + AsyncWrite + Unpin + Send + 'static, + T: AsyncRead + AsyncWrite + Unpin + Send + 'static, { - type Item = Result; + type Item = Result; - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.get_mut().poll_next(cx) - } + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.get_mut().poll_next(cx) + } } diff --git a/src/yamux/error.rs b/src/yamux/error.rs index 275d9350..fb729457 100644 --- a/src/yamux/error.rs +++ b/src/yamux/error.rs @@ -14,63 +14,63 @@ use crate::yamux::frame::FrameDecodeError; #[non_exhaustive] #[derive(Debug)] pub enum ConnectionError { - /// An underlying I/O error occured. - Io(std::io::Error), - /// Decoding a Yamux message frame failed. - Decode(FrameDecodeError), - /// The whole range of stream IDs has been used up. - NoMoreStreamIds, - /// An operation fails because the connection is closed. - Closed, - /// Too many streams are open, so no further ones can be opened at this time. - TooManyStreams, + /// An underlying I/O error occured. + Io(std::io::Error), + /// Decoding a Yamux message frame failed. + Decode(FrameDecodeError), + /// The whole range of stream IDs has been used up. + NoMoreStreamIds, + /// An operation fails because the connection is closed. + Closed, + /// Too many streams are open, so no further ones can be opened at this time. + TooManyStreams, } impl std::fmt::Display for ConnectionError { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - match self { - ConnectionError::Io(e) => write!(f, "i/o error: {}", e), - ConnectionError::Decode(e) => write!(f, "decode error: {}", e), - ConnectionError::NoMoreStreamIds => - f.write_str("number of stream ids has been exhausted"), - ConnectionError::Closed => f.write_str("connection is closed"), - ConnectionError::TooManyStreams => f.write_str("maximum number of streams reached"), - } - } + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + ConnectionError::Io(e) => write!(f, "i/o error: {}", e), + ConnectionError::Decode(e) => write!(f, "decode error: {}", e), + ConnectionError::NoMoreStreamIds => + f.write_str("number of stream ids has been exhausted"), + ConnectionError::Closed => f.write_str("connection is closed"), + ConnectionError::TooManyStreams => f.write_str("maximum number of streams reached"), + } + } } impl std::error::Error for ConnectionError { - fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { - match self { - ConnectionError::Io(e) => Some(e), - ConnectionError::Decode(e) => Some(e), - ConnectionError::NoMoreStreamIds | - ConnectionError::Closed | - ConnectionError::TooManyStreams => None, - } - } + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + ConnectionError::Io(e) => Some(e), + ConnectionError::Decode(e) => Some(e), + ConnectionError::NoMoreStreamIds + | ConnectionError::Closed + | ConnectionError::TooManyStreams => None, + } + } } impl From for ConnectionError { - fn from(e: std::io::Error) -> Self { - ConnectionError::Io(e) - } + fn from(e: std::io::Error) -> Self { + ConnectionError::Io(e) + } } impl From for ConnectionError { - fn from(e: FrameDecodeError) -> Self { - ConnectionError::Decode(e) - } + fn from(e: FrameDecodeError) -> Self { + ConnectionError::Decode(e) + } } impl From for ConnectionError { - fn from(_: futures::channel::mpsc::SendError) -> Self { - ConnectionError::Closed - } + fn from(_: futures::channel::mpsc::SendError) -> Self { + ConnectionError::Closed + } } impl From for ConnectionError { - fn from(_: futures::channel::oneshot::Canceled) -> Self { - ConnectionError::Closed - } + fn from(_: futures::channel::oneshot::Canceled) -> Self { + ConnectionError::Closed + } } diff --git a/src/yamux/frame.rs b/src/yamux/frame.rs index 3421e881..692840a4 100644 --- a/src/yamux/frame.rs +++ b/src/yamux/frame.rs @@ -21,100 +21,136 @@ pub(crate) use io::Io; /// A Yamux message frame consisting of header and body. #[derive(Clone, Debug, PartialEq, Eq)] pub struct Frame { - header: Header, - body: Vec, + header: Header, + body: Vec, } impl Frame { - pub fn new(header: Header) -> Self { - Frame { header, body: Vec::new() } - } - - pub fn header(&self) -> &Header { - &self.header - } - - pub fn header_mut(&mut self) -> &mut Header { - &mut self.header - } - - /// Introduce this frame to the right of a binary frame type. - pub(crate) fn right(self) -> Frame> { - Frame { header: self.header.right(), body: self.body } - } - - /// Introduce this frame to the left of a binary frame type. - pub(crate) fn left(self) -> Frame> { - Frame { header: self.header.left(), body: self.body } - } + pub fn new(header: Header) -> Self { + Frame { + header, + body: Vec::new(), + } + } + + pub fn header(&self) -> &Header { + &self.header + } + + pub fn header_mut(&mut self) -> &mut Header { + &mut self.header + } + + /// Introduce this frame to the right of a binary frame type. + pub(crate) fn right(self) -> Frame> { + Frame { + header: self.header.right(), + body: self.body, + } + } + + /// Introduce this frame to the left of a binary frame type. + pub(crate) fn left(self) -> Frame> { + Frame { + header: self.header.left(), + body: self.body, + } + } } impl From> for Frame<()> { - fn from(f: Frame) -> Frame<()> { - Frame { header: f.header.into(), body: f.body } - } + fn from(f: Frame) -> Frame<()> { + Frame { + header: f.header.into(), + body: f.body, + } + } } impl Frame<()> { - pub(crate) fn into_data(self) -> Frame { - Frame { header: self.header.into_data(), body: self.body } - } - - pub(crate) fn into_window_update(self) -> Frame { - Frame { header: self.header.into_window_update(), body: self.body } - } - - pub(crate) fn into_ping(self) -> Frame { - Frame { header: self.header.into_ping(), body: self.body } - } + pub(crate) fn into_data(self) -> Frame { + Frame { + header: self.header.into_data(), + body: self.body, + } + } + + pub(crate) fn into_window_update(self) -> Frame { + Frame { + header: self.header.into_window_update(), + body: self.body, + } + } + + pub(crate) fn into_ping(self) -> Frame { + Frame { + header: self.header.into_ping(), + body: self.body, + } + } } impl Frame { - pub fn data(id: StreamId, b: Vec) -> Result { - Ok(Frame { header: Header::data(id, b.len().try_into()?), body: b }) - } - - pub fn close_stream(id: StreamId, ack: bool) -> Self { - let mut header = Header::data(id, 0); - header.fin(); - if ack { - header.ack() - } - - Frame::new(header) - } - - pub fn body(&self) -> &[u8] { - &self.body - } - - pub fn body_len(&self) -> u32 { - // Safe cast since we construct `Frame::`s only with - // `Vec` of length [0, u32::MAX] in `Frame::data` above. - self.body().len() as u32 - } - - pub fn into_body(self) -> Vec { - self.body - } + pub fn data(id: StreamId, b: Vec) -> Result { + Ok(Frame { + header: Header::data(id, b.len().try_into()?), + body: b, + }) + } + + pub fn close_stream(id: StreamId, ack: bool) -> Self { + let mut header = Header::data(id, 0); + header.fin(); + if ack { + header.ack() + } + + Frame::new(header) + } + + pub fn body(&self) -> &[u8] { + &self.body + } + + pub fn body_len(&self) -> u32 { + // Safe cast since we construct `Frame::`s only with + // `Vec` of length [0, u32::MAX] in `Frame::data` above. + self.body().len() as u32 + } + + pub fn into_body(self) -> Vec { + self.body + } } impl Frame { - pub fn window_update(id: StreamId, credit: u32) -> Self { - Frame { header: Header::window_update(id, credit), body: Vec::new() } - } + pub fn window_update(id: StreamId, credit: u32) -> Self { + Frame { + header: Header::window_update(id, credit), + body: Vec::new(), + } + } } impl Frame { - pub fn term() -> Self { - Frame { header: Header::term(), body: Vec::new() } - } - - pub fn protocol_error() -> Self { - Frame { header: Header::protocol_error(), body: Vec::new() } - } - - pub fn internal_error() -> Self { - Frame { header: Header::internal_error(), body: Vec::new() } - } + pub fn term() -> Self { + Frame { + header: Header::term(), + body: Vec::new(), + } + } + + pub fn protocol_error() -> Self { + Frame { + header: Header::protocol_error(), + body: Vec::new(), + } + } + + pub fn internal_error() -> Self { + Frame { + header: Header::internal_error(), + body: Vec::new(), + } + } } diff --git a/src/yamux/frame/header.rs b/src/yamux/frame/header.rs index 89792a8d..aaf825a7 100644 --- a/src/yamux/frame/header.rs +++ b/src/yamux/frame/header.rs @@ -14,201 +14,201 @@ use std::fmt; /// The message frame header. #[derive(Clone, Debug, PartialEq, Eq)] pub struct Header { - version: Version, - tag: Tag, - flags: Flags, - stream_id: StreamId, - length: Len, - _marker: std::marker::PhantomData, + version: Version, + tag: Tag, + flags: Flags, + stream_id: StreamId, + length: Len, + _marker: std::marker::PhantomData, } impl fmt::Display for Header { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!( - f, - "(Header {:?} {} (len {}) (flags {:?}))", - self.tag, - self.stream_id, - self.length.val(), - self.flags.val() - ) - } + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "(Header {:?} {} (len {}) (flags {:?}))", + self.tag, + self.stream_id, + self.length.val(), + self.flags.val() + ) + } } impl Header { - pub fn tag(&self) -> Tag { - self.tag - } - - pub fn flags(&self) -> Flags { - self.flags - } - - pub fn stream_id(&self) -> StreamId { - self.stream_id - } - - pub fn len(&self) -> Len { - self.length - } - - #[cfg(test)] - pub fn set_len(&mut self, len: u32) { - self.length = Len(len) - } - - /// Arbitrary type cast, use with caution. - fn cast(self) -> Header { - Header { - version: self.version, - tag: self.tag, - flags: self.flags, - stream_id: self.stream_id, - length: self.length, - _marker: std::marker::PhantomData, - } - } - - /// Introduce this header to the right of a binary header type. - pub(crate) fn right(self) -> Header> { - self.cast() - } - - /// Introduce this header to the left of a binary header type. - pub(crate) fn left(self) -> Header> { - self.cast() - } + pub fn tag(&self) -> Tag { + self.tag + } + + pub fn flags(&self) -> Flags { + self.flags + } + + pub fn stream_id(&self) -> StreamId { + self.stream_id + } + + pub fn len(&self) -> Len { + self.length + } + + #[cfg(test)] + pub fn set_len(&mut self, len: u32) { + self.length = Len(len) + } + + /// Arbitrary type cast, use with caution. + fn cast(self) -> Header { + Header { + version: self.version, + tag: self.tag, + flags: self.flags, + stream_id: self.stream_id, + length: self.length, + _marker: std::marker::PhantomData, + } + } + + /// Introduce this header to the right of a binary header type. + pub(crate) fn right(self) -> Header> { + self.cast() + } + + /// Introduce this header to the left of a binary header type. + pub(crate) fn left(self) -> Header> { + self.cast() + } } impl From> for Header<()> { - fn from(h: Header) -> Header<()> { - h.cast() - } + fn from(h: Header) -> Header<()> { + h.cast() + } } impl Header<()> { - pub(crate) fn into_data(self) -> Header { - debug_assert_eq!(self.tag, Tag::Data); - self.cast() - } - - pub(crate) fn into_window_update(self) -> Header { - debug_assert_eq!(self.tag, Tag::WindowUpdate); - self.cast() - } - - pub(crate) fn into_ping(self) -> Header { - debug_assert_eq!(self.tag, Tag::Ping); - self.cast() - } + pub(crate) fn into_data(self) -> Header { + debug_assert_eq!(self.tag, Tag::Data); + self.cast() + } + + pub(crate) fn into_window_update(self) -> Header { + debug_assert_eq!(self.tag, Tag::WindowUpdate); + self.cast() + } + + pub(crate) fn into_ping(self) -> Header { + debug_assert_eq!(self.tag, Tag::Ping); + self.cast() + } } impl Header { - /// Set the [`SYN`] flag. - pub fn syn(&mut self) { - self.flags.0 |= SYN.0 - } + /// Set the [`SYN`] flag. + pub fn syn(&mut self) { + self.flags.0 |= SYN.0 + } } impl Header { - /// Set the [`ACK`] flag. - pub fn ack(&mut self) { - self.flags.0 |= ACK.0 - } + /// Set the [`ACK`] flag. + pub fn ack(&mut self) { + self.flags.0 |= ACK.0 + } } impl Header { - /// Set the [`FIN`] flag. - pub fn fin(&mut self) { - self.flags.0 |= FIN.0 - } + /// Set the [`FIN`] flag. + pub fn fin(&mut self) { + self.flags.0 |= FIN.0 + } } impl Header { - /// Set the [`RST`] flag. - pub fn rst(&mut self) { - self.flags.0 |= RST.0 - } + /// Set the [`RST`] flag. + pub fn rst(&mut self) { + self.flags.0 |= RST.0 + } } impl Header { - /// Create a new data frame header. - pub fn data(id: StreamId, len: u32) -> Self { - Header { - version: Version(0), - tag: Tag::Data, - flags: Flags(0), - stream_id: id, - length: Len(len), - _marker: std::marker::PhantomData, - } - } + /// Create a new data frame header. + pub fn data(id: StreamId, len: u32) -> Self { + Header { + version: Version(0), + tag: Tag::Data, + flags: Flags(0), + stream_id: id, + length: Len(len), + _marker: std::marker::PhantomData, + } + } } impl Header { - /// Create a new window update frame header. - pub fn window_update(id: StreamId, credit: u32) -> Self { - Header { - version: Version(0), - tag: Tag::WindowUpdate, - flags: Flags(0), - stream_id: id, - length: Len(credit), - _marker: std::marker::PhantomData, - } - } - - /// The credit this window update grants to the remote. - pub fn credit(&self) -> u32 { - self.length.0 - } + /// Create a new window update frame header. + pub fn window_update(id: StreamId, credit: u32) -> Self { + Header { + version: Version(0), + tag: Tag::WindowUpdate, + flags: Flags(0), + stream_id: id, + length: Len(credit), + _marker: std::marker::PhantomData, + } + } + + /// The credit this window update grants to the remote. + pub fn credit(&self) -> u32 { + self.length.0 + } } impl Header { - /// Create a new ping frame header. - pub fn ping(nonce: u32) -> Self { - Header { - version: Version(0), - tag: Tag::Ping, - flags: Flags(0), - stream_id: StreamId(0), - length: Len(nonce), - _marker: std::marker::PhantomData, - } - } - - /// The nonce of this ping. - pub fn nonce(&self) -> u32 { - self.length.0 - } + /// Create a new ping frame header. + pub fn ping(nonce: u32) -> Self { + Header { + version: Version(0), + tag: Tag::Ping, + flags: Flags(0), + stream_id: StreamId(0), + length: Len(nonce), + _marker: std::marker::PhantomData, + } + } + + /// The nonce of this ping. + pub fn nonce(&self) -> u32 { + self.length.0 + } } impl Header { - /// Terminate the session without indicating an error to the remote. - pub fn term() -> Self { - Self::go_away(0) - } - - /// Terminate the session indicating a protocol error to the remote. - pub fn protocol_error() -> Self { - Self::go_away(1) - } - - /// Terminate the session indicating an internal error to the remote. - pub fn internal_error() -> Self { - Self::go_away(2) - } - - fn go_away(code: u32) -> Self { - Header { - version: Version(0), - tag: Tag::GoAway, - flags: Flags(0), - stream_id: StreamId(0), - length: Len(code), - _marker: std::marker::PhantomData, - } - } + /// Terminate the session without indicating an error to the remote. + pub fn term() -> Self { + Self::go_away(0) + } + + /// Terminate the session indicating a protocol error to the remote. + pub fn protocol_error() -> Self { + Self::go_away(1) + } + + /// Terminate the session indicating an internal error to the remote. + pub fn internal_error() -> Self { + Self::go_away(2) + } + + fn go_away(code: u32) -> Self { + Header { + version: Version(0), + tag: Tag::GoAway, + flags: Flags(0), + stream_id: StreamId(0), + length: Len(code), + _marker: std::marker::PhantomData, + } + } } /// Data message type. @@ -252,22 +252,22 @@ impl HasRst for Data {} impl HasRst for WindowUpdate {} pub(super) mod private { - pub trait Sealed {} + pub trait Sealed {} - impl Sealed for super::Data {} - impl Sealed for super::WindowUpdate {} - impl Sealed for super::Ping {} - impl Sealed for super::GoAway {} - impl Sealed for super::Either {} + impl Sealed for super::Data {} + impl Sealed for super::WindowUpdate {} + impl Sealed for super::Ping {} + impl Sealed for super::GoAway {} + impl Sealed for super::Either {} } /// A tag is the runtime representation of a message type. #[derive(Copy, Clone, Debug, PartialEq, Eq)] pub enum Tag { - Data, - WindowUpdate, - Ping, - GoAway, + Data, + WindowUpdate, + Ping, + GoAway, } /// The protocol version a message corresponds to. @@ -279,9 +279,9 @@ pub struct Version(u8); pub struct Len(u32); impl Len { - pub fn val(self) -> u32 { - self.0 - } + pub fn val(self) -> u32 { + self.0 + } } pub const CONNECTION_ID: StreamId = StreamId(0); @@ -293,31 +293,31 @@ pub const CONNECTION_ID: StreamId = StreamId(0); pub struct StreamId(u32); impl StreamId { - pub(crate) fn new(val: u32) -> Self { - StreamId(val) - } + pub(crate) fn new(val: u32) -> Self { + StreamId(val) + } - pub fn is_server(self) -> bool { - self.0 % 2 == 0 - } + pub fn is_server(self) -> bool { + self.0 % 2 == 0 + } - pub fn is_client(self) -> bool { - !self.is_server() - } + pub fn is_client(self) -> bool { + !self.is_server() + } - pub fn is_session(self) -> bool { - self == CONNECTION_ID - } + pub fn is_session(self) -> bool { + self == CONNECTION_ID + } - pub fn val(self) -> u32 { - self.0 - } + pub fn val(self) -> u32 { + self.0 + } } impl fmt::Display for StreamId { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}", self.0) - } + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self.0) + } } impl nohash_hasher::IsEnabled for StreamId {} @@ -327,13 +327,13 @@ impl nohash_hasher::IsEnabled for StreamId {} pub struct Flags(u16); impl Flags { - pub fn contains(self, other: Flags) -> bool { - self.0 & other.0 == other.0 - } + pub fn contains(self, other: Flags) -> bool { + self.0 & other.0 == other.0 + } - pub fn val(self) -> u16 { - self.0 - } + pub fn val(self) -> u16 { + self.0 + } } /// Indicates the start of a new stream. @@ -353,91 +353,91 @@ pub const HEADER_SIZE: usize = 12; /// Encode a [`Header`] value. pub fn encode(hdr: &Header) -> [u8; HEADER_SIZE] { - let mut buf = [0; HEADER_SIZE]; - buf[0] = hdr.version.0; - buf[1] = hdr.tag as u8; - buf[2..4].copy_from_slice(&hdr.flags.0.to_be_bytes()); - buf[4..8].copy_from_slice(&hdr.stream_id.0.to_be_bytes()); - buf[8..HEADER_SIZE].copy_from_slice(&hdr.length.0.to_be_bytes()); - buf + let mut buf = [0; HEADER_SIZE]; + buf[0] = hdr.version.0; + buf[1] = hdr.tag as u8; + buf[2..4].copy_from_slice(&hdr.flags.0.to_be_bytes()); + buf[4..8].copy_from_slice(&hdr.stream_id.0.to_be_bytes()); + buf[8..HEADER_SIZE].copy_from_slice(&hdr.length.0.to_be_bytes()); + buf } /// Decode a [`Header`] value. pub fn decode(buf: &[u8; HEADER_SIZE]) -> Result, HeaderDecodeError> { - if buf[0] != 0 { - return Err(HeaderDecodeError::Version(buf[0])); - } - - let hdr = Header { - version: Version(buf[0]), - tag: match buf[1] { - 0 => Tag::Data, - 1 => Tag::WindowUpdate, - 2 => Tag::Ping, - 3 => Tag::GoAway, - t => return Err(HeaderDecodeError::Type(t)), - }, - flags: Flags(u16::from_be_bytes([buf[2], buf[3]])), - stream_id: StreamId(u32::from_be_bytes([buf[4], buf[5], buf[6], buf[7]])), - length: Len(u32::from_be_bytes([buf[8], buf[9], buf[10], buf[11]])), - _marker: std::marker::PhantomData, - }; - - Ok(hdr) + if buf[0] != 0 { + return Err(HeaderDecodeError::Version(buf[0])); + } + + let hdr = Header { + version: Version(buf[0]), + tag: match buf[1] { + 0 => Tag::Data, + 1 => Tag::WindowUpdate, + 2 => Tag::Ping, + 3 => Tag::GoAway, + t => return Err(HeaderDecodeError::Type(t)), + }, + flags: Flags(u16::from_be_bytes([buf[2], buf[3]])), + stream_id: StreamId(u32::from_be_bytes([buf[4], buf[5], buf[6], buf[7]])), + length: Len(u32::from_be_bytes([buf[8], buf[9], buf[10], buf[11]])), + _marker: std::marker::PhantomData, + }; + + Ok(hdr) } /// Possible errors while decoding a message frame header. #[non_exhaustive] #[derive(Debug)] pub enum HeaderDecodeError { - /// Unknown version. - Version(u8), - /// An unknown frame type. - Type(u8), + /// Unknown version. + Version(u8), + /// An unknown frame type. + Type(u8), } impl std::fmt::Display for HeaderDecodeError { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - match self { - HeaderDecodeError::Version(v) => write!(f, "unknown version: {}", v), - HeaderDecodeError::Type(t) => write!(f, "unknown frame type: {}", t), - } - } + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + HeaderDecodeError::Version(v) => write!(f, "unknown version: {}", v), + HeaderDecodeError::Type(t) => write!(f, "unknown frame type: {}", t), + } + } } impl std::error::Error for HeaderDecodeError {} #[cfg(test)] mod tests { - use super::*; - use quickcheck::{Arbitrary, Gen, QuickCheck}; - - impl Arbitrary for Header<()> { - fn arbitrary(g: &mut Gen) -> Self { - let tag = *g.choose(&[Tag::Data, Tag::WindowUpdate, Tag::Ping, Tag::GoAway]).unwrap(); - - Header { - version: Version(0), - tag, - flags: Flags(Arbitrary::arbitrary(g)), - stream_id: StreamId(Arbitrary::arbitrary(g)), - length: Len(Arbitrary::arbitrary(g)), - _marker: std::marker::PhantomData, - } - } - } - - #[test] - fn encode_decode_identity() { - fn property(hdr: Header<()>) -> bool { - match decode(&encode(&hdr)) { - Ok(x) => x == hdr, - Err(e) => { - eprintln!("decode error: {}", e); - false - }, - } - } - QuickCheck::new().tests(10_000).quickcheck(property as fn(Header<()>) -> bool) - } + use super::*; + use quickcheck::{Arbitrary, Gen, QuickCheck}; + + impl Arbitrary for Header<()> { + fn arbitrary(g: &mut Gen) -> Self { + let tag = *g.choose(&[Tag::Data, Tag::WindowUpdate, Tag::Ping, Tag::GoAway]).unwrap(); + + Header { + version: Version(0), + tag, + flags: Flags(Arbitrary::arbitrary(g)), + stream_id: StreamId(Arbitrary::arbitrary(g)), + length: Len(Arbitrary::arbitrary(g)), + _marker: std::marker::PhantomData, + } + } + } + + #[test] + fn encode_decode_identity() { + fn property(hdr: Header<()>) -> bool { + match decode(&encode(&hdr)) { + Ok(x) => x == hdr, + Err(e) => { + eprintln!("decode error: {}", e); + false + } + } + } + QuickCheck::new().tests(10_000).quickcheck(property as fn(Header<()>) -> bool) + } } diff --git a/src/yamux/frame/io.rs b/src/yamux/frame/io.rs index 80222d78..a0c67445 100644 --- a/src/yamux/frame/io.rs +++ b/src/yamux/frame/io.rs @@ -9,15 +9,15 @@ // at https://opensource.org/licenses/MIT. use super::{ - header::{self, HeaderDecodeError}, - Frame, + header::{self, HeaderDecodeError}, + Frame, }; use crate::yamux::connection::Id; use futures::{prelude::*, ready}; use std::{ - fmt, io, - pin::Pin, - task::{Context, Poll}, + fmt, io, + pin::Pin, + task::{Context, Poll}, }; /// Logging target for the file. @@ -26,304 +26,348 @@ const LOG_TARGET: &str = "litep2p::yamux"; /// A [`Stream`] and writer of [`Frame`] values. #[derive(Debug)] pub(crate) struct Io { - id: Id, - io: T, - read_state: ReadState, - write_state: WriteState, - max_body_len: usize, + id: Id, + io: T, + read_state: ReadState, + write_state: WriteState, + max_body_len: usize, } impl Io { - pub(crate) fn new(id: Id, io: T, max_frame_body_len: usize) -> Self { - Io { - id, - io, - read_state: ReadState::Init, - write_state: WriteState::Init, - max_body_len: max_frame_body_len, - } - } + pub(crate) fn new(id: Id, io: T, max_frame_body_len: usize) -> Self { + Io { + id, + io, + read_state: ReadState::Init, + write_state: WriteState::Init, + max_body_len: max_frame_body_len, + } + } } /// The stages of writing a new `Frame`. enum WriteState { - Init, - Header { header: [u8; header::HEADER_SIZE], buffer: Vec, offset: usize }, - Body { buffer: Vec, offset: usize }, + Init, + Header { + header: [u8; header::HEADER_SIZE], + buffer: Vec, + offset: usize, + }, + Body { + buffer: Vec, + offset: usize, + }, } impl fmt::Debug for WriteState { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - WriteState::Init => f.write_str("(WriteState::Init)"), - WriteState::Header { offset, .. } => { - write!(f, "(WriteState::Header (offset {}))", offset) - }, - WriteState::Body { offset, buffer } => { - write!(f, "(WriteState::Body (offset {}) (buffer-len {}))", offset, buffer.len()) - }, - } - } + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + WriteState::Init => f.write_str("(WriteState::Init)"), + WriteState::Header { offset, .. } => { + write!(f, "(WriteState::Header (offset {}))", offset) + } + WriteState::Body { offset, buffer } => { + write!( + f, + "(WriteState::Body (offset {}) (buffer-len {}))", + offset, + buffer.len() + ) + } + } + } } impl Sink> for Io { - type Error = io::Error; - - fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = Pin::into_inner(self); - loop { - tracing::trace!(target: LOG_TARGET, "{}: write: {:?}", this.id, this.write_state); - match &mut this.write_state { - WriteState::Init => return Poll::Ready(Ok(())), - WriteState::Header { header, buffer, ref mut offset } => - match Pin::new(&mut this.io).poll_write(cx, &header[*offset..]) { - Poll::Pending => return Poll::Pending, - Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), - Poll::Ready(Ok(n)) => { - if n == 0 { - return Poll::Ready(Err(io::ErrorKind::WriteZero.into())); - } - *offset += n; - if *offset == header.len() { - if !buffer.is_empty() { - let buffer = std::mem::take(buffer); - this.write_state = WriteState::Body { buffer, offset: 0 }; - } else { - this.write_state = WriteState::Init; - } - } - }, - }, - WriteState::Body { buffer, ref mut offset } => - match Pin::new(&mut this.io).poll_write(cx, &buffer[*offset..]) { - Poll::Pending => return Poll::Pending, - Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), - Poll::Ready(Ok(n)) => { - if n == 0 { - return Poll::Ready(Err(io::ErrorKind::WriteZero.into())); - } - *offset += n; - if *offset == buffer.len() { - this.write_state = WriteState::Init; - } - }, - }, - } - } - } - - fn start_send(self: Pin<&mut Self>, f: Frame<()>) -> Result<(), Self::Error> { - let header = header::encode(&f.header); - let buffer = f.body; - self.get_mut().write_state = WriteState::Header { header, buffer, offset: 0 }; - Ok(()) - } - - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = Pin::into_inner(self); - ready!(this.poll_ready_unpin(cx))?; - Pin::new(&mut this.io).poll_flush(cx) - } - - fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = Pin::into_inner(self); - ready!(this.poll_ready_unpin(cx))?; - Pin::new(&mut this.io).poll_close(cx) - } + type Error = io::Error; + + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = Pin::into_inner(self); + loop { + tracing::trace!(target: LOG_TARGET, "{}: write: {:?}", this.id, this.write_state); + match &mut this.write_state { + WriteState::Init => return Poll::Ready(Ok(())), + WriteState::Header { + header, + buffer, + ref mut offset, + } => match Pin::new(&mut this.io).poll_write(cx, &header[*offset..]) { + Poll::Pending => return Poll::Pending, + Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), + Poll::Ready(Ok(n)) => { + if n == 0 { + return Poll::Ready(Err(io::ErrorKind::WriteZero.into())); + } + *offset += n; + if *offset == header.len() { + if !buffer.is_empty() { + let buffer = std::mem::take(buffer); + this.write_state = WriteState::Body { buffer, offset: 0 }; + } else { + this.write_state = WriteState::Init; + } + } + } + }, + WriteState::Body { + buffer, + ref mut offset, + } => match Pin::new(&mut this.io).poll_write(cx, &buffer[*offset..]) { + Poll::Pending => return Poll::Pending, + Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), + Poll::Ready(Ok(n)) => { + if n == 0 { + return Poll::Ready(Err(io::ErrorKind::WriteZero.into())); + } + *offset += n; + if *offset == buffer.len() { + this.write_state = WriteState::Init; + } + } + }, + } + } + } + + fn start_send(self: Pin<&mut Self>, f: Frame<()>) -> Result<(), Self::Error> { + let header = header::encode(&f.header); + let buffer = f.body; + self.get_mut().write_state = WriteState::Header { + header, + buffer, + offset: 0, + }; + Ok(()) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = Pin::into_inner(self); + ready!(this.poll_ready_unpin(cx))?; + Pin::new(&mut this.io).poll_flush(cx) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = Pin::into_inner(self); + ready!(this.poll_ready_unpin(cx))?; + Pin::new(&mut this.io).poll_close(cx) + } } /// The stages of reading a new `Frame`. enum ReadState { - /// Initial reading state. - Init, - /// Reading the frame header. - Header { offset: usize, buffer: [u8; header::HEADER_SIZE] }, - /// Reading the frame body. - Body { header: header::Header<()>, offset: usize, buffer: Vec }, + /// Initial reading state. + Init, + /// Reading the frame header. + Header { + offset: usize, + buffer: [u8; header::HEADER_SIZE], + }, + /// Reading the frame body. + Body { + header: header::Header<()>, + offset: usize, + buffer: Vec, + }, } impl Stream for Io { - type Item = Result, FrameDecodeError>; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - let this = &mut *self; - loop { - tracing::trace!(target: LOG_TARGET, "{}: read: {:?}", this.id, this.read_state); - match this.read_state { - ReadState::Init => { - this.read_state = - ReadState::Header { offset: 0, buffer: [0; header::HEADER_SIZE] }; - }, - ReadState::Header { ref mut offset, ref mut buffer } => { - if *offset == header::HEADER_SIZE { - let header = match header::decode(buffer) { - Ok(hd) => hd, - Err(e) => return Poll::Ready(Some(Err(e.into()))), - }; - - tracing::trace!(target: LOG_TARGET, "{}: read: {}", this.id, header); - - if header.tag() != header::Tag::Data { - this.read_state = ReadState::Init; - return Poll::Ready(Some(Ok(Frame::new(header)))); - } - - let body_len = header.len().val() as usize; - - if body_len > this.max_body_len { - return Poll::Ready(Some(Err(FrameDecodeError::FrameTooLarge( - body_len, - )))); - } - - this.read_state = - ReadState::Body { header, offset: 0, buffer: vec![0; body_len] }; - - continue; - } - - let buf = &mut buffer[*offset..header::HEADER_SIZE]; - match ready!(Pin::new(&mut this.io).poll_read(cx, buf))? { - 0 => { - if *offset == 0 { - return Poll::Ready(None); - } - let e = FrameDecodeError::Io(io::ErrorKind::UnexpectedEof.into()); - return Poll::Ready(Some(Err(e))); - }, - n => *offset += n, - } - }, - ReadState::Body { ref header, ref mut offset, ref mut buffer } => { - let body_len = header.len().val() as usize; - - if *offset == body_len { - let h = header.clone(); - let v = std::mem::take(buffer); - this.read_state = ReadState::Init; - return Poll::Ready(Some(Ok(Frame { header: h, body: v }))); - } - - let buf = &mut buffer[*offset..body_len]; - match ready!(Pin::new(&mut this.io).poll_read(cx, buf))? { - 0 => { - let e = FrameDecodeError::Io(io::ErrorKind::UnexpectedEof.into()); - return Poll::Ready(Some(Err(e))); - }, - n => *offset += n, - } - }, - } - } - } + type Item = Result, FrameDecodeError>; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + let this = &mut *self; + loop { + tracing::trace!(target: LOG_TARGET, "{}: read: {:?}", this.id, this.read_state); + match this.read_state { + ReadState::Init => { + this.read_state = ReadState::Header { + offset: 0, + buffer: [0; header::HEADER_SIZE], + }; + } + ReadState::Header { + ref mut offset, + ref mut buffer, + } => { + if *offset == header::HEADER_SIZE { + let header = match header::decode(buffer) { + Ok(hd) => hd, + Err(e) => return Poll::Ready(Some(Err(e.into()))), + }; + + tracing::trace!(target: LOG_TARGET, "{}: read: {}", this.id, header); + + if header.tag() != header::Tag::Data { + this.read_state = ReadState::Init; + return Poll::Ready(Some(Ok(Frame::new(header)))); + } + + let body_len = header.len().val() as usize; + + if body_len > this.max_body_len { + return Poll::Ready(Some(Err(FrameDecodeError::FrameTooLarge( + body_len, + )))); + } + + this.read_state = ReadState::Body { + header, + offset: 0, + buffer: vec![0; body_len], + }; + + continue; + } + + let buf = &mut buffer[*offset..header::HEADER_SIZE]; + match ready!(Pin::new(&mut this.io).poll_read(cx, buf))? { + 0 => { + if *offset == 0 { + return Poll::Ready(None); + } + let e = FrameDecodeError::Io(io::ErrorKind::UnexpectedEof.into()); + return Poll::Ready(Some(Err(e))); + } + n => *offset += n, + } + } + ReadState::Body { + ref header, + ref mut offset, + ref mut buffer, + } => { + let body_len = header.len().val() as usize; + + if *offset == body_len { + let h = header.clone(); + let v = std::mem::take(buffer); + this.read_state = ReadState::Init; + return Poll::Ready(Some(Ok(Frame { header: h, body: v }))); + } + + let buf = &mut buffer[*offset..body_len]; + match ready!(Pin::new(&mut this.io).poll_read(cx, buf))? { + 0 => { + let e = FrameDecodeError::Io(io::ErrorKind::UnexpectedEof.into()); + return Poll::Ready(Some(Err(e))); + } + n => *offset += n, + } + } + } + } + } } impl fmt::Debug for ReadState { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - ReadState::Init => f.write_str("(ReadState::Init)"), - ReadState::Header { offset, .. } => { - write!(f, "(ReadState::Header (offset {}))", offset) - }, - ReadState::Body { header, offset, buffer } => { - write!( - f, - "(ReadState::Body (header {}) (offset {}) (buffer-len {}))", - header, - offset, - buffer.len() - ) - }, - } - } + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + ReadState::Init => f.write_str("(ReadState::Init)"), + ReadState::Header { offset, .. } => { + write!(f, "(ReadState::Header (offset {}))", offset) + } + ReadState::Body { + header, + offset, + buffer, + } => { + write!( + f, + "(ReadState::Body (header {}) (offset {}) (buffer-len {}))", + header, + offset, + buffer.len() + ) + } + } + } } /// Possible errors while decoding a message frame. #[non_exhaustive] #[derive(Debug)] pub enum FrameDecodeError { - /// An I/O error. - Io(io::Error), - /// Decoding the frame header failed. - Header(HeaderDecodeError), - /// A data frame body length is larger than the configured maximum. - FrameTooLarge(usize), + /// An I/O error. + Io(io::Error), + /// Decoding the frame header failed. + Header(HeaderDecodeError), + /// A data frame body length is larger than the configured maximum. + FrameTooLarge(usize), } impl std::fmt::Display for FrameDecodeError { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - match self { - FrameDecodeError::Io(e) => write!(f, "i/o error: {}", e), - FrameDecodeError::Header(e) => write!(f, "decode error: {}", e), - FrameDecodeError::FrameTooLarge(n) => write!(f, "frame body is too large ({})", n), - } - } + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + FrameDecodeError::Io(e) => write!(f, "i/o error: {}", e), + FrameDecodeError::Header(e) => write!(f, "decode error: {}", e), + FrameDecodeError::FrameTooLarge(n) => write!(f, "frame body is too large ({})", n), + } + } } impl std::error::Error for FrameDecodeError { - fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { - match self { - FrameDecodeError::Io(e) => Some(e), - FrameDecodeError::Header(e) => Some(e), - FrameDecodeError::FrameTooLarge(_) => None, - } - } + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + FrameDecodeError::Io(e) => Some(e), + FrameDecodeError::Header(e) => Some(e), + FrameDecodeError::FrameTooLarge(_) => None, + } + } } impl From for FrameDecodeError { - fn from(e: std::io::Error) -> Self { - FrameDecodeError::Io(e) - } + fn from(e: std::io::Error) -> Self { + FrameDecodeError::Io(e) + } } impl From for FrameDecodeError { - fn from(e: HeaderDecodeError) -> Self { - FrameDecodeError::Header(e) - } + fn from(e: HeaderDecodeError) -> Self { + FrameDecodeError::Header(e) + } } #[cfg(test)] mod tests { - use super::*; - use quickcheck::{Arbitrary, Gen, QuickCheck}; - use rand::RngCore; - - impl Arbitrary for Frame<()> { - fn arbitrary(g: &mut Gen) -> Self { - let mut header: header::Header<()> = Arbitrary::arbitrary(g); - let body = if header.tag() == header::Tag::Data { - header.set_len(header.len().val() % 4096); - let mut b = vec![0; header.len().val() as usize]; - rand::thread_rng().fill_bytes(&mut b); - b - } else { - Vec::new() - }; - Frame { header, body } - } - } - - #[test] - fn encode_decode_identity() { - fn property(f: Frame<()>) -> bool { - futures::executor::block_on(async move { - let id = crate::yamux::connection::Id::random(); - let mut io = Io::new(id, futures::io::Cursor::new(Vec::new()), f.body.len()); - if io.send(f.clone()).await.is_err() { - return false; - } - if io.flush().await.is_err() { - return false; - } - io.io.set_position(0); - if let Ok(Some(x)) = io.try_next().await { - x == f - } else { - false - } - }) - } - - QuickCheck::new().tests(10_000).quickcheck(property as fn(Frame<()>) -> bool) - } + use super::*; + use quickcheck::{Arbitrary, Gen, QuickCheck}; + use rand::RngCore; + + impl Arbitrary for Frame<()> { + fn arbitrary(g: &mut Gen) -> Self { + let mut header: header::Header<()> = Arbitrary::arbitrary(g); + let body = if header.tag() == header::Tag::Data { + header.set_len(header.len().val() % 4096); + let mut b = vec![0; header.len().val() as usize]; + rand::thread_rng().fill_bytes(&mut b); + b + } else { + Vec::new() + }; + Frame { header, body } + } + } + + #[test] + fn encode_decode_identity() { + fn property(f: Frame<()>) -> bool { + futures::executor::block_on(async move { + let id = crate::yamux::connection::Id::random(); + let mut io = Io::new(id, futures::io::Cursor::new(Vec::new()), f.body.len()); + if io.send(f.clone()).await.is_err() { + return false; + } + if io.flush().await.is_err() { + return false; + } + io.io.set_position(0); + if let Ok(Some(x)) = io.try_next().await { + x == f + } else { + false + } + }) + } + + QuickCheck::new().tests(10_000).quickcheck(property as fn(Frame<()>) -> bool) + } } diff --git a/src/yamux/mod.rs b/src/yamux/mod.rs index 4ea6f4fd..2671e937 100644 --- a/src/yamux/mod.rs +++ b/src/yamux/mod.rs @@ -32,13 +32,13 @@ pub(crate) mod connection; mod tagged_stream; pub use crate::yamux::{ - connection::{Connection, Mode, Packet, Stream}, - control::{Control, ControlledConnection}, - error::ConnectionError, - frame::{ - header::{HeaderDecodeError, StreamId}, - FrameDecodeError, - }, + connection::{Connection, Mode, Packet, Stream}, + control::{Control, ControlledConnection}, + error::ConnectionError, + frame::{ + header::{HeaderDecodeError, StreamId}, + FrameDecodeError, + }, }; pub const DEFAULT_CREDIT: u32 = 256 * 1024; // as per yamux specification @@ -68,27 +68,27 @@ const DEFAULT_SPLIT_SEND_SIZE: usize = 16 * 1024; /// Specifies when window update frames are sent. #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum WindowUpdateMode { - /// Send window updates as soon as a [`Stream`]'s receive window drops to 0. - /// - /// This ensures that the sender can resume sending more data as soon as possible - /// but a slow reader on the receiving side may be overwhelmed, i.e. it accumulates - /// data in its buffer which may reach its limit (see `set_max_buffer_size`). - /// In this mode, window updates merely prevent head of line blocking but do not - /// effectively exercise back pressure on senders. - OnReceive, - - /// Send window updates only when data is read on the receiving end. - /// - /// This ensures that senders do not overwhelm receivers and keeps buffer usage - /// low. However, depending on the protocol, there is a risk of deadlock, namely - /// if both endpoints want to send data larger than the receivers window and they - /// do not read before finishing their writes. Use this mode only if you are sure - /// that this will never happen, i.e. if - /// - /// - Endpoints *A* and *B* never write at the same time, *or* - /// - Endpoints *A* and *B* write at most *n* frames concurrently such that the sum of the - /// frame lengths is less or equal to the available credit of *A* and *B* respectively. - OnRead, + /// Send window updates as soon as a [`Stream`]'s receive window drops to 0. + /// + /// This ensures that the sender can resume sending more data as soon as possible + /// but a slow reader on the receiving side may be overwhelmed, i.e. it accumulates + /// data in its buffer which may reach its limit (see `set_max_buffer_size`). + /// In this mode, window updates merely prevent head of line blocking but do not + /// effectively exercise back pressure on senders. + OnReceive, + + /// Send window updates only when data is read on the receiving end. + /// + /// This ensures that senders do not overwhelm receivers and keeps buffer usage + /// low. However, depending on the protocol, there is a risk of deadlock, namely + /// if both endpoints want to send data larger than the receivers window and they + /// do not read before finishing their writes. Use this mode only if you are sure + /// that this will never happen, i.e. if + /// + /// - Endpoints *A* and *B* never write at the same time, *or* + /// - Endpoints *A* and *B* write at most *n* frames concurrently such that the sum of the + /// frame lengths is less or equal to the available credit of *A* and *B* respectively. + OnRead, } /// Yamux configuration. @@ -103,78 +103,78 @@ pub enum WindowUpdateMode { /// - split send size = 16 KiB #[derive(Debug, Clone)] pub struct Config { - receive_window: u32, - max_buffer_size: usize, - max_num_streams: usize, - window_update_mode: WindowUpdateMode, - read_after_close: bool, - split_send_size: usize, + receive_window: u32, + max_buffer_size: usize, + max_num_streams: usize, + window_update_mode: WindowUpdateMode, + read_after_close: bool, + split_send_size: usize, } impl Default for Config { - fn default() -> Self { - Config { - receive_window: DEFAULT_CREDIT, - max_buffer_size: 1024 * 1024, - max_num_streams: 8192, - window_update_mode: WindowUpdateMode::OnRead, - read_after_close: true, - split_send_size: DEFAULT_SPLIT_SEND_SIZE, - } - } + fn default() -> Self { + Config { + receive_window: DEFAULT_CREDIT, + max_buffer_size: 1024 * 1024, + max_num_streams: 8192, + window_update_mode: WindowUpdateMode::OnRead, + read_after_close: true, + split_send_size: DEFAULT_SPLIT_SEND_SIZE, + } + } } impl Config { - /// Set the receive window per stream (must be >= 256 KiB). - /// - /// # Panics - /// - /// If the given receive window is < 256 KiB. - pub fn set_receive_window(&mut self, n: u32) -> &mut Self { - assert!(n >= DEFAULT_CREDIT); - self.receive_window = n; - self - } - - /// Set the max. buffer size per stream. - pub fn set_max_buffer_size(&mut self, n: usize) -> &mut Self { - self.max_buffer_size = n; - self - } - - /// Set the max. number of streams. - pub fn set_max_num_streams(&mut self, n: usize) -> &mut Self { - self.max_num_streams = n; - self - } - - /// Set the window update mode to use. - pub fn set_window_update_mode(&mut self, m: WindowUpdateMode) -> &mut Self { - self.window_update_mode = m; - self - } - - /// Allow or disallow streams to read from buffered data after - /// the connection has been closed. - pub fn set_read_after_close(&mut self, b: bool) -> &mut Self { - self.read_after_close = b; - self - } - - /// Set the max. payload size used when sending data frames. Payloads larger - /// than the configured max. will be split. - pub fn set_split_send_size(&mut self, n: usize) -> &mut Self { - self.split_send_size = n; - self - } + /// Set the receive window per stream (must be >= 256 KiB). + /// + /// # Panics + /// + /// If the given receive window is < 256 KiB. + pub fn set_receive_window(&mut self, n: u32) -> &mut Self { + assert!(n >= DEFAULT_CREDIT); + self.receive_window = n; + self + } + + /// Set the max. buffer size per stream. + pub fn set_max_buffer_size(&mut self, n: usize) -> &mut Self { + self.max_buffer_size = n; + self + } + + /// Set the max. number of streams. + pub fn set_max_num_streams(&mut self, n: usize) -> &mut Self { + self.max_num_streams = n; + self + } + + /// Set the window update mode to use. + pub fn set_window_update_mode(&mut self, m: WindowUpdateMode) -> &mut Self { + self.window_update_mode = m; + self + } + + /// Allow or disallow streams to read from buffered data after + /// the connection has been closed. + pub fn set_read_after_close(&mut self, b: bool) -> &mut Self { + self.read_after_close = b; + self + } + + /// Set the max. payload size used when sending data frames. Payloads larger + /// than the configured max. will be split. + pub fn set_split_send_size(&mut self, n: usize) -> &mut Self { + self.split_send_size = n; + self + } } // Check that we can safely cast a `usize` to a `u64`. static_assertions::const_assert! { - std::mem::size_of::() <= std::mem::size_of::() + std::mem::size_of::() <= std::mem::size_of::() } // Check that we can safely cast a `u32` to a `usize`. static_assertions::const_assert! { - std::mem::size_of::() <= std::mem::size_of::() + std::mem::size_of::() <= std::mem::size_of::() } diff --git a/src/yamux/tagged_stream.rs b/src/yamux/tagged_stream.rs index 91b50135..5583a5b7 100644 --- a/src/yamux/tagged_stream.rs +++ b/src/yamux/tagged_stream.rs @@ -1,50 +1,54 @@ use futures::Stream; use std::{ - pin::Pin, - task::{Context, Poll}, + pin::Pin, + task::{Context, Poll}, }; /// A stream that yields its tag with every item. #[pin_project::pin_project] pub struct TaggedStream { - key: K, - #[pin] - inner: S, + key: K, + #[pin] + inner: S, - reported_none: bool, + reported_none: bool, } impl TaggedStream { - pub fn new(key: K, inner: S) -> Self { - Self { key, inner, reported_none: false } - } - - pub fn inner_mut(&mut self) -> &mut S { - &mut self.inner - } + pub fn new(key: K, inner: S) -> Self { + Self { + key, + inner, + reported_none: false, + } + } + + pub fn inner_mut(&mut self) -> &mut S { + &mut self.inner + } } impl Stream for TaggedStream where - K: Copy, - S: Stream, + K: Copy, + S: Stream, { - type Item = (K, Option); + type Item = (K, Option); - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = self.project(); + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.project(); - if *this.reported_none { - return Poll::Ready(None); - } + if *this.reported_none { + return Poll::Ready(None); + } - match futures::ready!(this.inner.poll_next(cx)) { - Some(item) => Poll::Ready(Some((*this.key, Some(item)))), - None => { - *this.reported_none = true; + match futures::ready!(this.inner.poll_next(cx)) { + Some(item) => Poll::Ready(Some((*this.key, Some(item)))), + None => { + *this.reported_none = true; - Poll::Ready(Some((*this.key, None))) - }, - } - } + Poll::Ready(Some((*this.key, None))) + } + } + } } diff --git a/tests/conformance/rust/identify.rs b/tests/conformance/rust/identify.rs index fedc7674..c29ca882 100644 --- a/tests/conformance/rust/identify.rs +++ b/tests/conformance/rust/identify.rs @@ -23,155 +23,155 @@ use futures::{Stream, StreamExt}; use libp2p::{ - identify, identity, ping, - swarm::{NetworkBehaviour, SwarmBuilder, SwarmEvent}, - PeerId, Swarm, + identify, identity, ping, + swarm::{NetworkBehaviour, SwarmBuilder, SwarmEvent}, + PeerId, Swarm, }; use litep2p::{ - config::ConfigBuilder, - crypto::ed25519::Keypair, - protocol::libp2p::{ - identify::{Config as IdentifyConfig, IdentifyEvent}, - ping::{Config as PingConfig, PingEvent}, - }, - transport::tcp::config::Config as TcpConfig, - Litep2p, + config::ConfigBuilder, + crypto::ed25519::Keypair, + protocol::libp2p::{ + identify::{Config as IdentifyConfig, IdentifyEvent}, + ping::{Config as PingConfig, PingEvent}, + }, + transport::tcp::config::Config as TcpConfig, + Litep2p, }; // We create a custom network behaviour that combines gossipsub, ping and identify. #[derive(NetworkBehaviour)] #[behaviour(out_event = "MyBehaviourEvent")] struct MyBehaviour { - identify: identify::Behaviour, - ping: ping::Behaviour, + identify: identify::Behaviour, + ping: ping::Behaviour, } enum MyBehaviourEvent { - Identify(identify::Event), - Ping(ping::Event), + Identify(identify::Event), + Ping(ping::Event), } impl From for MyBehaviourEvent { - fn from(event: identify::Event) -> Self { - MyBehaviourEvent::Identify(event) - } + fn from(event: identify::Event) -> Self { + MyBehaviourEvent::Identify(event) + } } impl From for MyBehaviourEvent { - fn from(event: ping::Event) -> Self { - MyBehaviourEvent::Ping(event) - } + fn from(event: ping::Event) -> Self { + MyBehaviourEvent::Ping(event) + } } // initialize litep2p with ping support fn initialize_litep2p() -> ( - Litep2p, - Box + Send + Unpin>, - Box + Send + Unpin>, + Litep2p, + Box + Send + Unpin>, + Box + Send + Unpin>, ) { - let keypair = Keypair::generate(); - let (ping_config, ping_event_stream) = PingConfig::default(); - let (identify_config, identify_event_stream) = - IdentifyConfig::new("proto v1".to_string(), None, Vec::new()); - - let litep2p = Litep2p::new( - ConfigBuilder::new() - .with_keypair(keypair) - .with_tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }) - .with_libp2p_ping(ping_config) - .with_libp2p_identify(identify_config) - .build(), - ) - .unwrap(); - - (litep2p, ping_event_stream, identify_event_stream) + let keypair = Keypair::generate(); + let (ping_config, ping_event_stream) = PingConfig::default(); + let (identify_config, identify_event_stream) = + IdentifyConfig::new("proto v1".to_string(), None, Vec::new()); + + let litep2p = Litep2p::new( + ConfigBuilder::new() + .with_keypair(keypair) + .with_tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }) + .with_libp2p_ping(ping_config) + .with_libp2p_identify(identify_config) + .build(), + ) + .unwrap(); + + (litep2p, ping_event_stream, identify_event_stream) } fn initialize_libp2p() -> Swarm { - let local_key = identity::Keypair::generate_ed25519(); - let local_peer_id = PeerId::from(local_key.public()); + let local_key = identity::Keypair::generate_ed25519(); + let local_peer_id = PeerId::from(local_key.public()); - tracing::debug!("Local peer id: {local_peer_id:?}"); + tracing::debug!("Local peer id: {local_peer_id:?}"); - let transport = libp2p::tokio_development_transport(local_key.clone()).unwrap(); - let behaviour = MyBehaviour { - identify: identify::Behaviour::new( - identify::Config::new("/ipfs/1.0.0".into(), local_key.public()) - .with_agent_version("libp2p agent".to_string()), - ), - ping: Default::default(), - }; - let mut swarm = SwarmBuilder::with_tokio_executor(transport, behaviour, local_peer_id).build(); + let transport = libp2p::tokio_development_transport(local_key.clone()).unwrap(); + let behaviour = MyBehaviour { + identify: identify::Behaviour::new( + identify::Config::new("/ipfs/1.0.0".into(), local_key.public()) + .with_agent_version("libp2p agent".to_string()), + ), + ping: Default::default(), + }; + let mut swarm = SwarmBuilder::with_tokio_executor(transport, behaviour, local_peer_id).build(); - swarm.listen_on("/ip6/::1/tcp/0".parse().unwrap()).unwrap(); + swarm.listen_on("/ip6/::1/tcp/0".parse().unwrap()).unwrap(); - swarm + swarm } #[tokio::test] async fn identify_works() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let mut libp2p = initialize_libp2p(); - let (mut litep2p, _ping_event_stream, mut identify_event_stream) = initialize_litep2p(); - let address = litep2p.listen_addresses().next().unwrap().clone(); - - libp2p.dial(address).unwrap(); - - tokio::spawn(async move { - loop { - let _ = litep2p.next_event().await; - } - }); - - let mut libp2p_done = false; - let mut litep2p_done = false; - - loop { - tokio::select! { - event = libp2p.select_next_some() => { - match event { - SwarmEvent::NewListenAddr { address, .. } => { - tracing::info!("Listening on {address:?}") - } - SwarmEvent::Behaviour(MyBehaviourEvent::Ping(_event)) => {}, - SwarmEvent::Behaviour(MyBehaviourEvent::Identify(event)) => match event { - identify::Event::Received { info, .. } => { - libp2p_done = true; - - assert_eq!(info.protocol_version, "proto v1"); - assert_eq!(info.agent_version, "litep2p/1.0.0"); - - if libp2p_done && litep2p_done { - break - } - } - _ => {} - } - _ => {} - } - }, - event = identify_event_stream.next() => match event { - Some(IdentifyEvent::PeerIdentified { protocol_version, user_agent, .. }) => { - litep2p_done = true; - - assert_eq!(protocol_version, Some("/ipfs/1.0.0".to_string())); - assert_eq!(user_agent, Some("libp2p agent".to_string())); - - if libp2p_done && litep2p_done { - break - } - } - None => panic!("identify exited"), - }, - _ = tokio::time::sleep(std::time::Duration::from_secs(5)) => { - panic!("failed to receive identify in time"); - } - } - } + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let mut libp2p = initialize_libp2p(); + let (mut litep2p, _ping_event_stream, mut identify_event_stream) = initialize_litep2p(); + let address = litep2p.listen_addresses().next().unwrap().clone(); + + libp2p.dial(address).unwrap(); + + tokio::spawn(async move { + loop { + let _ = litep2p.next_event().await; + } + }); + + let mut libp2p_done = false; + let mut litep2p_done = false; + + loop { + tokio::select! { + event = libp2p.select_next_some() => { + match event { + SwarmEvent::NewListenAddr { address, .. } => { + tracing::info!("Listening on {address:?}") + } + SwarmEvent::Behaviour(MyBehaviourEvent::Ping(_event)) => {}, + SwarmEvent::Behaviour(MyBehaviourEvent::Identify(event)) => match event { + identify::Event::Received { info, .. } => { + libp2p_done = true; + + assert_eq!(info.protocol_version, "proto v1"); + assert_eq!(info.agent_version, "litep2p/1.0.0"); + + if libp2p_done && litep2p_done { + break + } + } + _ => {} + } + _ => {} + } + }, + event = identify_event_stream.next() => match event { + Some(IdentifyEvent::PeerIdentified { protocol_version, user_agent, .. }) => { + litep2p_done = true; + + assert_eq!(protocol_version, Some("/ipfs/1.0.0".to_string())); + assert_eq!(user_agent, Some("libp2p agent".to_string())); + + if libp2p_done && litep2p_done { + break + } + } + None => panic!("identify exited"), + }, + _ = tokio::time::sleep(std::time::Duration::from_secs(5)) => { + panic!("failed to receive identify in time"); + } + } + } } diff --git a/tests/conformance/rust/kademlia.rs b/tests/conformance/rust/kademlia.rs index 64a0bcb8..a501b581 100644 --- a/tests/conformance/rust/kademlia.rs +++ b/tests/conformance/rust/kademlia.rs @@ -21,385 +21,389 @@ use futures::StreamExt; use libp2p::{ - identify, identity, - kad::{self, store::RecordStore}, - swarm::{keep_alive, NetworkBehaviour, SwarmBuilder, SwarmEvent}, - PeerId, Swarm, + identify, identity, + kad::{self, store::RecordStore}, + swarm::{keep_alive, NetworkBehaviour, SwarmBuilder, SwarmEvent}, + PeerId, Swarm, }; use litep2p::{ - config::ConfigBuilder as Litep2pConfigBuilder, - crypto::ed25519::Keypair, - protocol::libp2p::kademlia::{ - ConfigBuilder, KademliaEvent, KademliaHandle, Quorum, Record, RecordKey, - }, - transport::tcp::config::Config as TcpConfig, - Litep2p, + config::ConfigBuilder as Litep2pConfigBuilder, + crypto::ed25519::Keypair, + protocol::libp2p::kademlia::{ + ConfigBuilder, KademliaEvent, KademliaHandle, Quorum, Record, RecordKey, + }, + transport::tcp::config::Config as TcpConfig, + Litep2p, }; use multiaddr::Protocol; #[derive(NetworkBehaviour)] struct Behaviour { - keep_alive: keep_alive::Behaviour, - kad: kad::Kademlia, - identify: identify::Behaviour, + keep_alive: keep_alive::Behaviour, + kad: kad::Kademlia, + identify: identify::Behaviour, } // initialize litep2p with ping support fn initialize_litep2p() -> (Litep2p, KademliaHandle) { - let keypair = Keypair::generate(); - let (kad_config, kad_handle) = ConfigBuilder::new().build(); - - let litep2p = Litep2p::new( - Litep2pConfigBuilder::new() - .with_keypair(keypair) - .with_tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }) - .with_libp2p_kademlia(kad_config) - .build(), - ) - .unwrap(); - - (litep2p, kad_handle) + let keypair = Keypair::generate(); + let (kad_config, kad_handle) = ConfigBuilder::new().build(); + + let litep2p = Litep2p::new( + Litep2pConfigBuilder::new() + .with_keypair(keypair) + .with_tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }) + .with_libp2p_kademlia(kad_config) + .build(), + ) + .unwrap(); + + (litep2p, kad_handle) } fn initialize_libp2p() -> Swarm { - let local_key = identity::Keypair::generate_ed25519(); - let local_peer_id = PeerId::from(local_key.public()); - - tracing::debug!("Local peer id: {local_peer_id:?}"); - - let transport = libp2p::tokio_development_transport(local_key.clone()).unwrap(); - let behaviour = { - let config = kad::KademliaConfig::default(); - let store = kad::store::MemoryStore::new(local_peer_id); - - Behaviour { - kad: kad::Kademlia::with_config(local_peer_id, store, config), - keep_alive: Default::default(), - identify: identify::Behaviour::new(identify::Config::new( - "/ipfs/1.0.0".into(), - local_key.public(), - )), - } - }; - let mut swarm = SwarmBuilder::with_tokio_executor(transport, behaviour, local_peer_id).build(); - - swarm.listen_on("/ip6/::1/tcp/0".parse().unwrap()).unwrap(); - - swarm + let local_key = identity::Keypair::generate_ed25519(); + let local_peer_id = PeerId::from(local_key.public()); + + tracing::debug!("Local peer id: {local_peer_id:?}"); + + let transport = libp2p::tokio_development_transport(local_key.clone()).unwrap(); + let behaviour = { + let config = kad::KademliaConfig::default(); + let store = kad::store::MemoryStore::new(local_peer_id); + + Behaviour { + kad: kad::Kademlia::with_config(local_peer_id, store, config), + keep_alive: Default::default(), + identify: identify::Behaviour::new(identify::Config::new( + "/ipfs/1.0.0".into(), + local_key.public(), + )), + } + }; + let mut swarm = SwarmBuilder::with_tokio_executor(transport, behaviour, local_peer_id).build(); + + swarm.listen_on("/ip6/::1/tcp/0".parse().unwrap()).unwrap(); + + swarm } #[tokio::test] async fn find_node() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let mut addresses = vec![]; - let mut peer_ids = vec![]; - for _ in 0..3 { - let mut libp2p = initialize_libp2p(); - - loop { - if let SwarmEvent::NewListenAddr { address, .. } = libp2p.select_next_some().await { - addresses.push(address); - peer_ids.push(*libp2p.local_peer_id()); - break; - } - } - - tokio::spawn(async move { - loop { - let _ = libp2p.select_next_some().await; - } - }); - } - - let mut libp2p = initialize_libp2p(); - let (mut litep2p, mut kad_handle) = initialize_litep2p(); - let address = litep2p.listen_addresses().next().unwrap().clone(); - - for i in 0..addresses.len() { - libp2p.dial(addresses[i].clone()).unwrap(); - let _ = libp2p.behaviour_mut().kad.add_address(&peer_ids[i], addresses[i].clone()); - } - libp2p.dial(address).unwrap(); - - tokio::spawn(async move { - loop { - let _ = litep2p.next_event().await; - } - }); - - #[allow(unused)] - let mut listen_addr = None; - let peer_id = *libp2p.local_peer_id(); - - tracing::error!("local peer id: {peer_id}"); - - loop { - if let SwarmEvent::NewListenAddr { address, .. } = libp2p.select_next_some().await { - listen_addr = Some(address); - break; - } - } - - tokio::spawn(async move { - loop { - let _ = libp2p.select_next_some().await; - } - }); - - tokio::time::sleep(std::time::Duration::from_secs(3)).await; - let listen_addr = listen_addr.unwrap().with(Protocol::P2p(peer_id.into())); - - kad_handle - .add_known_peer( - litep2p::PeerId::from_bytes(&peer_id.to_bytes()).unwrap(), - vec![listen_addr], - ) - .await; - - let target = litep2p::PeerId::random(); - let _ = kad_handle.find_node(target).await; - - loop { - match kad_handle.next().await { - Some(KademliaEvent::FindNodeSuccess { target: query_target, peers, .. }) => { - assert_eq!(target, query_target); - assert!(!peers.is_empty()); - break; - }, - _ => {}, - } - } + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let mut addresses = vec![]; + let mut peer_ids = vec![]; + for _ in 0..3 { + let mut libp2p = initialize_libp2p(); + + loop { + if let SwarmEvent::NewListenAddr { address, .. } = libp2p.select_next_some().await { + addresses.push(address); + peer_ids.push(*libp2p.local_peer_id()); + break; + } + } + + tokio::spawn(async move { + loop { + let _ = libp2p.select_next_some().await; + } + }); + } + + let mut libp2p = initialize_libp2p(); + let (mut litep2p, mut kad_handle) = initialize_litep2p(); + let address = litep2p.listen_addresses().next().unwrap().clone(); + + for i in 0..addresses.len() { + libp2p.dial(addresses[i].clone()).unwrap(); + let _ = libp2p.behaviour_mut().kad.add_address(&peer_ids[i], addresses[i].clone()); + } + libp2p.dial(address).unwrap(); + + tokio::spawn(async move { + loop { + let _ = litep2p.next_event().await; + } + }); + + #[allow(unused)] + let mut listen_addr = None; + let peer_id = *libp2p.local_peer_id(); + + tracing::error!("local peer id: {peer_id}"); + + loop { + if let SwarmEvent::NewListenAddr { address, .. } = libp2p.select_next_some().await { + listen_addr = Some(address); + break; + } + } + + tokio::spawn(async move { + loop { + let _ = libp2p.select_next_some().await; + } + }); + + tokio::time::sleep(std::time::Duration::from_secs(3)).await; + let listen_addr = listen_addr.unwrap().with(Protocol::P2p(peer_id.into())); + + kad_handle + .add_known_peer( + litep2p::PeerId::from_bytes(&peer_id.to_bytes()).unwrap(), + vec![listen_addr], + ) + .await; + + let target = litep2p::PeerId::random(); + let _ = kad_handle.find_node(target).await; + + loop { + match kad_handle.next().await { + Some(KademliaEvent::FindNodeSuccess { + target: query_target, + peers, + .. + }) => { + assert_eq!(target, query_target); + assert!(!peers.is_empty()); + break; + } + _ => {} + } + } } #[tokio::test] async fn put_record() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let mut addresses = vec![]; - let mut peer_ids = vec![]; - let counter = std::sync::Arc::new(std::sync::atomic::AtomicUsize::new(0usize)); - - for _ in 0..3 { - let mut libp2p = initialize_libp2p(); - - loop { - if let SwarmEvent::NewListenAddr { address, .. } = libp2p.select_next_some().await { - addresses.push(address); - peer_ids.push(*libp2p.local_peer_id()); - break; - } - } - - let counter_copy = std::sync::Arc::clone(&counter); - tokio::spawn(async move { - let mut record_found = false; - - loop { - tokio::select! { - _ = libp2p.select_next_some() => {} - _ = tokio::time::sleep(std::time::Duration::from_secs(1)) => { - let store = libp2p.behaviour_mut().kad.store_mut(); - if store.get(&libp2p::kad::record::Key::new(&vec![1, 2, 3, 4])).is_some() && !record_found { - counter_copy.fetch_add(1usize, std::sync::atomic::Ordering::SeqCst); - record_found = true; - } - } - } - } - }); - } - - let mut libp2p = initialize_libp2p(); - let (mut litep2p, mut kad_handle) = initialize_litep2p(); - let address = litep2p.listen_addresses().next().unwrap().clone(); - - for i in 0..addresses.len() { - libp2p.dial(addresses[i].clone()).unwrap(); - let _ = libp2p.behaviour_mut().kad.add_address(&peer_ids[i], addresses[i].clone()); - } - libp2p.dial(address).unwrap(); - - tokio::spawn(async move { - loop { - let _ = litep2p.next_event().await; - } - }); - - #[allow(unused)] - let mut listen_addr = None; - let peer_id = *libp2p.local_peer_id(); - - tracing::error!("local peer id: {peer_id}"); - - loop { - if let SwarmEvent::NewListenAddr { address, .. } = libp2p.select_next_some().await { - listen_addr = Some(address); - break; - } - } - - let counter_copy = std::sync::Arc::clone(&counter); - tokio::spawn(async move { - let mut record_found = false; - - loop { - tokio::select! { - _ = libp2p.select_next_some() => {} - _ = tokio::time::sleep(std::time::Duration::from_secs(1)) => { - let store = libp2p.behaviour_mut().kad.store_mut(); - if store.get(&libp2p::kad::record::Key::new(&vec![1, 2, 3, 4])).is_some() && !record_found { - counter_copy.fetch_add(1usize, std::sync::atomic::Ordering::SeqCst); - record_found = true; - } - } - } - } - }); - - tokio::time::sleep(std::time::Duration::from_secs(3)).await; - - let listen_addr = listen_addr.unwrap().with(Protocol::P2p(peer_id.into())); - - kad_handle - .add_known_peer( - litep2p::PeerId::from_bytes(&peer_id.to_bytes()).unwrap(), - vec![listen_addr], - ) - .await; - - let record_key = RecordKey::new(&vec![1, 2, 3, 4]); - let record = Record::new(record_key, vec![1, 3, 3, 7, 1, 3, 3, 8]); - - let _ = kad_handle.put_record(record).await; - - loop { - tokio::time::sleep(std::time::Duration::from_secs(1)).await; - - if counter.load(std::sync::atomic::Ordering::SeqCst) == 4 { - break; - } - } + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let mut addresses = vec![]; + let mut peer_ids = vec![]; + let counter = std::sync::Arc::new(std::sync::atomic::AtomicUsize::new(0usize)); + + for _ in 0..3 { + let mut libp2p = initialize_libp2p(); + + loop { + if let SwarmEvent::NewListenAddr { address, .. } = libp2p.select_next_some().await { + addresses.push(address); + peer_ids.push(*libp2p.local_peer_id()); + break; + } + } + + let counter_copy = std::sync::Arc::clone(&counter); + tokio::spawn(async move { + let mut record_found = false; + + loop { + tokio::select! { + _ = libp2p.select_next_some() => {} + _ = tokio::time::sleep(std::time::Duration::from_secs(1)) => { + let store = libp2p.behaviour_mut().kad.store_mut(); + if store.get(&libp2p::kad::record::Key::new(&vec![1, 2, 3, 4])).is_some() && !record_found { + counter_copy.fetch_add(1usize, std::sync::atomic::Ordering::SeqCst); + record_found = true; + } + } + } + } + }); + } + + let mut libp2p = initialize_libp2p(); + let (mut litep2p, mut kad_handle) = initialize_litep2p(); + let address = litep2p.listen_addresses().next().unwrap().clone(); + + for i in 0..addresses.len() { + libp2p.dial(addresses[i].clone()).unwrap(); + let _ = libp2p.behaviour_mut().kad.add_address(&peer_ids[i], addresses[i].clone()); + } + libp2p.dial(address).unwrap(); + + tokio::spawn(async move { + loop { + let _ = litep2p.next_event().await; + } + }); + + #[allow(unused)] + let mut listen_addr = None; + let peer_id = *libp2p.local_peer_id(); + + tracing::error!("local peer id: {peer_id}"); + + loop { + if let SwarmEvent::NewListenAddr { address, .. } = libp2p.select_next_some().await { + listen_addr = Some(address); + break; + } + } + + let counter_copy = std::sync::Arc::clone(&counter); + tokio::spawn(async move { + let mut record_found = false; + + loop { + tokio::select! { + _ = libp2p.select_next_some() => {} + _ = tokio::time::sleep(std::time::Duration::from_secs(1)) => { + let store = libp2p.behaviour_mut().kad.store_mut(); + if store.get(&libp2p::kad::record::Key::new(&vec![1, 2, 3, 4])).is_some() && !record_found { + counter_copy.fetch_add(1usize, std::sync::atomic::Ordering::SeqCst); + record_found = true; + } + } + } + } + }); + + tokio::time::sleep(std::time::Duration::from_secs(3)).await; + + let listen_addr = listen_addr.unwrap().with(Protocol::P2p(peer_id.into())); + + kad_handle + .add_known_peer( + litep2p::PeerId::from_bytes(&peer_id.to_bytes()).unwrap(), + vec![listen_addr], + ) + .await; + + let record_key = RecordKey::new(&vec![1, 2, 3, 4]); + let record = Record::new(record_key, vec![1, 3, 3, 7, 1, 3, 3, 8]); + + let _ = kad_handle.put_record(record).await; + + loop { + tokio::time::sleep(std::time::Duration::from_secs(1)).await; + + if counter.load(std::sync::atomic::Ordering::SeqCst) == 4 { + break; + } + } } #[tokio::test] async fn get_record() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let mut addresses = vec![]; - let mut peer_ids = vec![]; - let counter = std::sync::Arc::new(std::sync::atomic::AtomicUsize::new(0usize)); - - for _ in 0..3 { - let mut libp2p = initialize_libp2p(); - - loop { - if let SwarmEvent::NewListenAddr { address, .. } = libp2p.select_next_some().await { - addresses.push(address); - peer_ids.push(*libp2p.local_peer_id()); - break; - } - } - - let counter_copy = std::sync::Arc::clone(&counter); - tokio::spawn(async move { - let mut record_found = false; - - loop { - tokio::select! { - _ = libp2p.select_next_some() => {} - _ = tokio::time::sleep(std::time::Duration::from_secs(1)) => { - let store = libp2p.behaviour_mut().kad.store_mut(); - if store.get(&libp2p::kad::record::Key::new(&vec![1, 2, 3, 4])).is_some() && !record_found { - counter_copy.fetch_add(1usize, std::sync::atomic::Ordering::SeqCst); - record_found = true; - } - } - } - } - }); - } - - let mut libp2p = initialize_libp2p(); - let (mut litep2p, mut kad_handle) = initialize_litep2p(); - let address = litep2p.listen_addresses().next().unwrap().clone(); - - for i in 0..addresses.len() { - libp2p.dial(addresses[i].clone()).unwrap(); - let _ = libp2p.behaviour_mut().kad.add_address(&peer_ids[i], addresses[i].clone()); - } - - // publish record on the network - let record = libp2p::kad::Record { - key: libp2p::kad::RecordKey::new(&vec![1, 2, 3, 4]), - value: vec![13, 37, 13, 38], - publisher: None, - expires: None, - }; - libp2p.behaviour_mut().kad.put_record(record, libp2p::kad::Quorum::All).unwrap(); - - #[allow(unused)] - let mut listen_addr = None; - - loop { - tokio::select! { - event = libp2p.select_next_some() => match event { - SwarmEvent::NewListenAddr { address, .. } => { - listen_addr = Some(address); - } - _ => {} - }, - _ = tokio::time::sleep(std::time::Duration::from_secs(1)) => { - if counter.load(std::sync::atomic::Ordering::SeqCst) == 3 { - break; - } - } - } - } - - libp2p.dial(address).unwrap(); - - tokio::spawn(async move { - loop { - let _ = litep2p.next_event().await; - } - }); - - let peer_id = *libp2p.local_peer_id(); - - tokio::spawn(async move { - loop { - let _ = libp2p.select_next_some().await; - } - }); - - tokio::time::sleep(std::time::Duration::from_secs(3)).await; - - let listen_addr = listen_addr.unwrap().with(Protocol::P2p(peer_id.into())); - - kad_handle - .add_known_peer( - litep2p::PeerId::from_bytes(&peer_id.to_bytes()).unwrap(), - vec![listen_addr], - ) - .await; - - let _ = kad_handle.get_record(RecordKey::new(&vec![1, 2, 3, 4]), Quorum::All).await; - - loop { - match kad_handle.next().await.unwrap() { - KademliaEvent::GetRecordSuccess { .. } => break, - KademliaEvent::RoutingTableUpdate { .. } => {}, - _ => panic!("invalid event received"), - } - } + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let mut addresses = vec![]; + let mut peer_ids = vec![]; + let counter = std::sync::Arc::new(std::sync::atomic::AtomicUsize::new(0usize)); + + for _ in 0..3 { + let mut libp2p = initialize_libp2p(); + + loop { + if let SwarmEvent::NewListenAddr { address, .. } = libp2p.select_next_some().await { + addresses.push(address); + peer_ids.push(*libp2p.local_peer_id()); + break; + } + } + + let counter_copy = std::sync::Arc::clone(&counter); + tokio::spawn(async move { + let mut record_found = false; + + loop { + tokio::select! { + _ = libp2p.select_next_some() => {} + _ = tokio::time::sleep(std::time::Duration::from_secs(1)) => { + let store = libp2p.behaviour_mut().kad.store_mut(); + if store.get(&libp2p::kad::record::Key::new(&vec![1, 2, 3, 4])).is_some() && !record_found { + counter_copy.fetch_add(1usize, std::sync::atomic::Ordering::SeqCst); + record_found = true; + } + } + } + } + }); + } + + let mut libp2p = initialize_libp2p(); + let (mut litep2p, mut kad_handle) = initialize_litep2p(); + let address = litep2p.listen_addresses().next().unwrap().clone(); + + for i in 0..addresses.len() { + libp2p.dial(addresses[i].clone()).unwrap(); + let _ = libp2p.behaviour_mut().kad.add_address(&peer_ids[i], addresses[i].clone()); + } + + // publish record on the network + let record = libp2p::kad::Record { + key: libp2p::kad::RecordKey::new(&vec![1, 2, 3, 4]), + value: vec![13, 37, 13, 38], + publisher: None, + expires: None, + }; + libp2p.behaviour_mut().kad.put_record(record, libp2p::kad::Quorum::All).unwrap(); + + #[allow(unused)] + let mut listen_addr = None; + + loop { + tokio::select! { + event = libp2p.select_next_some() => match event { + SwarmEvent::NewListenAddr { address, .. } => { + listen_addr = Some(address); + } + _ => {} + }, + _ = tokio::time::sleep(std::time::Duration::from_secs(1)) => { + if counter.load(std::sync::atomic::Ordering::SeqCst) == 3 { + break; + } + } + } + } + + libp2p.dial(address).unwrap(); + + tokio::spawn(async move { + loop { + let _ = litep2p.next_event().await; + } + }); + + let peer_id = *libp2p.local_peer_id(); + + tokio::spawn(async move { + loop { + let _ = libp2p.select_next_some().await; + } + }); + + tokio::time::sleep(std::time::Duration::from_secs(3)).await; + + let listen_addr = listen_addr.unwrap().with(Protocol::P2p(peer_id.into())); + + kad_handle + .add_known_peer( + litep2p::PeerId::from_bytes(&peer_id.to_bytes()).unwrap(), + vec![listen_addr], + ) + .await; + + let _ = kad_handle.get_record(RecordKey::new(&vec![1, 2, 3, 4]), Quorum::All).await; + + loop { + match kad_handle.next().await.unwrap() { + KademliaEvent::GetRecordSuccess { .. } => break, + KademliaEvent::RoutingTableUpdate { .. } => {} + _ => panic!("invalid event received"), + } + } } diff --git a/tests/conformance/rust/ping.rs b/tests/conformance/rust/ping.rs index f02966b8..bf398bdb 100644 --- a/tests/conformance/rust/ping.rs +++ b/tests/conformance/rust/ping.rs @@ -21,108 +21,108 @@ use futures::{Stream, StreamExt}; use libp2p::{ - identity, ping, - swarm::{keep_alive, NetworkBehaviour, SwarmBuilder, SwarmEvent}, - PeerId, Swarm, + identity, ping, + swarm::{keep_alive, NetworkBehaviour, SwarmBuilder, SwarmEvent}, + PeerId, Swarm, }; use litep2p::{ - config::ConfigBuilder, - crypto::ed25519::Keypair, - protocol::libp2p::ping::{Config as PingConfig, PingEvent}, - transport::tcp::config::Config as TcpConfig, - Litep2p, + config::ConfigBuilder, + crypto::ed25519::Keypair, + protocol::libp2p::ping::{Config as PingConfig, PingEvent}, + transport::tcp::config::Config as TcpConfig, + Litep2p, }; #[derive(NetworkBehaviour, Default)] struct Behaviour { - keep_alive: keep_alive::Behaviour, - ping: ping::Behaviour, + keep_alive: keep_alive::Behaviour, + ping: ping::Behaviour, } // initialize litep2p with ping support fn initialize_litep2p() -> (Litep2p, Box + Send + Unpin>) { - let keypair = Keypair::generate(); - let (ping_config, ping_event_stream) = PingConfig::default(); - let litep2p = Litep2p::new( - ConfigBuilder::new() - .with_keypair(keypair) - .with_tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }) - .with_libp2p_ping(ping_config) - .build(), - ) - .unwrap(); - - (litep2p, ping_event_stream) + let keypair = Keypair::generate(); + let (ping_config, ping_event_stream) = PingConfig::default(); + let litep2p = Litep2p::new( + ConfigBuilder::new() + .with_keypair(keypair) + .with_tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }) + .with_libp2p_ping(ping_config) + .build(), + ) + .unwrap(); + + (litep2p, ping_event_stream) } fn initialize_libp2p() -> Swarm { - let local_key = identity::Keypair::generate_ed25519(); - let local_peer_id = PeerId::from(local_key.public()); + let local_key = identity::Keypair::generate_ed25519(); + let local_peer_id = PeerId::from(local_key.public()); - tracing::debug!("Local peer id: {local_peer_id:?}"); + tracing::debug!("Local peer id: {local_peer_id:?}"); - let transport = libp2p::tokio_development_transport(local_key).unwrap(); - let mut swarm = - SwarmBuilder::with_tokio_executor(transport, Behaviour::default(), local_peer_id).build(); + let transport = libp2p::tokio_development_transport(local_key).unwrap(); + let mut swarm = + SwarmBuilder::with_tokio_executor(transport, Behaviour::default(), local_peer_id).build(); - swarm.listen_on("/ip6/::1/tcp/0".parse().unwrap()).unwrap(); + swarm.listen_on("/ip6/::1/tcp/0".parse().unwrap()).unwrap(); - swarm + swarm } #[tokio::test] async fn libp2p_dials() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let mut libp2p = initialize_libp2p(); - let (mut litep2p, mut ping_event_stream) = initialize_litep2p(); - let address = litep2p.listen_addresses().next().unwrap().clone(); - - libp2p.dial(address).unwrap(); - - tokio::spawn(async move { - loop { - let _ = litep2p.next_event().await; - } - }); - - let mut libp2p_done = false; - let mut litep2p_done = false; - - loop { - tokio::select! { - event = libp2p.select_next_some() => { - match event { - SwarmEvent::NewListenAddr { address, .. } => { - tracing::info!("Listening on {address:?}") - } - SwarmEvent::Behaviour(BehaviourEvent::Ping(_)) => { - libp2p_done = true; - - if libp2p_done && litep2p_done { - break - } - } - _ => {} - } - } - _event = ping_event_stream.next() => { - litep2p_done = true; - - if libp2p_done && litep2p_done { - break - } - } - _ = tokio::time::sleep(std::time::Duration::from_secs(5)) => { - panic!("failed to receive ping in time"); - } - } - } + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let mut libp2p = initialize_libp2p(); + let (mut litep2p, mut ping_event_stream) = initialize_litep2p(); + let address = litep2p.listen_addresses().next().unwrap().clone(); + + libp2p.dial(address).unwrap(); + + tokio::spawn(async move { + loop { + let _ = litep2p.next_event().await; + } + }); + + let mut libp2p_done = false; + let mut litep2p_done = false; + + loop { + tokio::select! { + event = libp2p.select_next_some() => { + match event { + SwarmEvent::NewListenAddr { address, .. } => { + tracing::info!("Listening on {address:?}") + } + SwarmEvent::Behaviour(BehaviourEvent::Ping(_)) => { + libp2p_done = true; + + if libp2p_done && litep2p_done { + break + } + } + _ => {} + } + } + _event = ping_event_stream.next() => { + litep2p_done = true; + + if libp2p_done && litep2p_done { + break + } + } + _ = tokio::time::sleep(std::time::Duration::from_secs(5)) => { + panic!("failed to receive ping in time"); + } + } + } } #[tokio::test] diff --git a/tests/conformance/rust/quic_ping.rs b/tests/conformance/rust/quic_ping.rs index 4d0bef4e..f7fee590 100644 --- a/tests/conformance/rust/quic_ping.rs +++ b/tests/conformance/rust/quic_ping.rs @@ -21,122 +21,124 @@ use futures::{future::Either, Stream, StreamExt}; use libp2p::{ - core::{muxing::StreamMuxerBox, transport::OrTransport}, - identity, ping, quic, - swarm::{keep_alive, NetworkBehaviour, SwarmBuilder, SwarmEvent}, - PeerId, Swarm, Transport, + core::{muxing::StreamMuxerBox, transport::OrTransport}, + identity, ping, quic, + swarm::{keep_alive, NetworkBehaviour, SwarmBuilder, SwarmEvent}, + PeerId, Swarm, Transport, }; use litep2p::{ - config::ConfigBuilder, - crypto::ed25519::Keypair, - protocol::libp2p::ping::{Config as PingConfig, PingEvent}, - transport::quic::config::Config as QuicConfig, - Litep2p, + config::ConfigBuilder, + crypto::ed25519::Keypair, + protocol::libp2p::ping::{Config as PingConfig, PingEvent}, + transport::quic::config::Config as QuicConfig, + Litep2p, }; #[derive(NetworkBehaviour, Default)] struct Behaviour { - keep_alive: keep_alive::Behaviour, - ping: ping::Behaviour, + keep_alive: keep_alive::Behaviour, + ping: ping::Behaviour, } // initialize litep2p with ping support fn initialize_litep2p() -> (Litep2p, Box + Send + Unpin>) { - let keypair = Keypair::generate(); - let (ping_config, ping_event_stream) = PingConfig::default(); - let litep2p = Litep2p::new( - ConfigBuilder::new() - .with_keypair(keypair) - .with_quic(QuicConfig { - listen_addresses: vec!["/ip4/127.0.0.1/udp/8888/quic-v1".parse().unwrap()], - ..Default::default() - }) - .with_libp2p_ping(ping_config) - .build(), - ) - .unwrap(); - - (litep2p, ping_event_stream) + let keypair = Keypair::generate(); + let (ping_config, ping_event_stream) = PingConfig::default(); + let litep2p = Litep2p::new( + ConfigBuilder::new() + .with_keypair(keypair) + .with_quic(QuicConfig { + listen_addresses: vec!["/ip4/127.0.0.1/udp/8888/quic-v1".parse().unwrap()], + ..Default::default() + }) + .with_libp2p_ping(ping_config) + .build(), + ) + .unwrap(); + + (litep2p, ping_event_stream) } fn initialize_libp2p() -> Swarm { - let local_key = identity::Keypair::generate_ed25519(); - let local_peer_id = PeerId::from(local_key.public()); + let local_key = identity::Keypair::generate_ed25519(); + let local_peer_id = PeerId::from(local_key.public()); - tracing::debug!("Local peer id: {local_peer_id:?}"); + tracing::debug!("Local peer id: {local_peer_id:?}"); - let tcp_transport = libp2p::tokio_development_transport(local_key.clone()).unwrap(); + let tcp_transport = libp2p::tokio_development_transport(local_key.clone()).unwrap(); - let quic_transport = quic::tokio::Transport::new(quic::Config::new(&local_key)); - let transport = OrTransport::new(quic_transport, tcp_transport) - .map(|either_output, _| match either_output { - Either::Left((peer_id, muxer)) => (peer_id, StreamMuxerBox::new(muxer)), - Either::Right((peer_id, muxer)) => (peer_id, StreamMuxerBox::new(muxer)), - }) - .boxed(); + let quic_transport = quic::tokio::Transport::new(quic::Config::new(&local_key)); + let transport = OrTransport::new(quic_transport, tcp_transport) + .map(|either_output, _| match either_output { + Either::Left((peer_id, muxer)) => (peer_id, StreamMuxerBox::new(muxer)), + Either::Right((peer_id, muxer)) => (peer_id, StreamMuxerBox::new(muxer)), + }) + .boxed(); - let mut swarm = - SwarmBuilder::with_tokio_executor(transport, Behaviour::default(), local_peer_id).build(); + let mut swarm = + SwarmBuilder::with_tokio_executor(transport, Behaviour::default(), local_peer_id).build(); - swarm.listen_on("/ip6/::1/tcp/0".parse().unwrap()).unwrap(); - swarm.listen_on("/ip4/127.0.0.1/udp/0/quic-v1".parse().unwrap()).unwrap(); + swarm.listen_on("/ip6/::1/tcp/0".parse().unwrap()).unwrap(); + swarm.listen_on("/ip4/127.0.0.1/udp/0/quic-v1".parse().unwrap()).unwrap(); - swarm + swarm } #[tokio::test] async fn libp2p_dials() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let mut libp2p = initialize_libp2p(); - let (mut litep2p, mut ping_event_stream) = initialize_litep2p(); - - let address: multiaddr::Multiaddr = - format!("/ip4/127.0.0.1/udp/8888/quic-v1/p2p/{}", *litep2p.local_peer_id()) - .parse() - .unwrap(); - libp2p.dial(address).unwrap(); - - tokio::spawn(async move { - loop { - let _ = litep2p.next_event().await; - } - }); - - let mut libp2p_done = false; - let mut litep2p_done = false; - - loop { - tokio::select! { - event = libp2p.select_next_some() => { - match event { - SwarmEvent::NewListenAddr { address, .. } => { - tracing::info!("Listening on {address:?}") - } - SwarmEvent::Behaviour(BehaviourEvent::Ping(_)) => { - libp2p_done = true; - - if libp2p_done && litep2p_done { - break - } - } - _ => {} - } - } - _event = ping_event_stream.next() => { - litep2p_done = true; - - if libp2p_done && litep2p_done { - break - } - } - _ = tokio::time::sleep(std::time::Duration::from_secs(5)) => { - panic!("failed to receive ping in time"); - } - } - } + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let mut libp2p = initialize_libp2p(); + let (mut litep2p, mut ping_event_stream) = initialize_litep2p(); + + let address: multiaddr::Multiaddr = format!( + "/ip4/127.0.0.1/udp/8888/quic-v1/p2p/{}", + *litep2p.local_peer_id() + ) + .parse() + .unwrap(); + libp2p.dial(address).unwrap(); + + tokio::spawn(async move { + loop { + let _ = litep2p.next_event().await; + } + }); + + let mut libp2p_done = false; + let mut litep2p_done = false; + + loop { + tokio::select! { + event = libp2p.select_next_some() => { + match event { + SwarmEvent::NewListenAddr { address, .. } => { + tracing::info!("Listening on {address:?}") + } + SwarmEvent::Behaviour(BehaviourEvent::Ping(_)) => { + libp2p_done = true; + + if libp2p_done && litep2p_done { + break + } + } + _ => {} + } + } + _event = ping_event_stream.next() => { + litep2p_done = true; + + if libp2p_done && litep2p_done { + break + } + } + _ = tokio::time::sleep(std::time::Duration::from_secs(5)) => { + panic!("failed to receive ping in time"); + } + } + } } #[tokio::test] diff --git a/tests/conformance/substrate/connection.rs b/tests/conformance/substrate/connection.rs index 900ed6be..fd8c28cf 100644 --- a/tests/conformance/substrate/connection.rs +++ b/tests/conformance/substrate/connection.rs @@ -19,132 +19,132 @@ // DEALINGS IN THE SOFTWARE. use litep2p::{ - config::ConfigBuilder, - crypto::ed25519::Keypair, - protocol::notification::{handle::NotificationHandle, types::Config as NotificationConfig}, - transport::tcp::config::Config as TcpConfig, - types::protocol::ProtocolName as Litep2pProtocol, - Litep2p, Litep2pEvent, + config::ConfigBuilder, + crypto::ed25519::Keypair, + protocol::notification::{handle::NotificationHandle, types::Config as NotificationConfig}, + transport::tcp::config::Config as TcpConfig, + types::protocol::ProtocolName as Litep2pProtocol, + Litep2p, Litep2pEvent, }; use futures::StreamExt; use libp2p::{ - identity, - swarm::{SwarmBuilder, SwarmEvent}, - PeerId, Swarm, + identity, + swarm::{SwarmBuilder, SwarmEvent}, + PeerId, Swarm, }; use sc_network::{ - peer_store::{PeerStore, PeerStoreHandle}, - protocol::notifications::behaviour::{Notifications, ProtocolConfig}, - protocol_controller::{ProtoSetConfig, ProtocolController, SetId}, - types::ProtocolName, + peer_store::{PeerStore, PeerStoreHandle}, + protocol::notifications::behaviour::{Notifications, ProtocolConfig}, + protocol_controller::{ProtoSetConfig, ProtocolController, SetId}, + types::ProtocolName, }; use sc_utils::mpsc::tracing_unbounded; use std::collections::HashSet; fn initialize_libp2p(in_peers: u32, out_peers: u32) -> (Swarm, PeerStoreHandle) { - let local_key = identity::Keypair::generate_ed25519(); - let local_peer_id = PeerId::from(local_key.public()); - let peer_store = PeerStore::new(vec![]); - - let (tx, rx) = tracing_unbounded("channel", 10_000); - let proto_set_config = ProtoSetConfig { - in_peers, - out_peers, - reserved_nodes: HashSet::new(), - reserved_only: false, - }; - - let (handle, controller) = ProtocolController::new( - SetId::from(0usize), - proto_set_config, - tx.clone(), - Box::new(peer_store.handle()), - ); - let peer_store_handle = peer_store.handle(); - tokio::spawn(controller.run()); - tokio::spawn(peer_store.run()); - - let proto_config = ProtocolConfig { - name: ProtocolName::from("/notif/1"), - fallback_names: vec![], - handshake: vec![1, 3, 3, 7], - max_notification_size: 1000u64, - }; - let behaviour = Notifications::new(vec![handle], rx, vec![proto_config].into_iter()); - let transport = libp2p::tokio_development_transport(local_key).unwrap(); - let mut swarm = SwarmBuilder::with_tokio_executor(transport, behaviour, local_peer_id).build(); - - swarm.listen_on("/ip6/::1/tcp/0".parse().unwrap()).unwrap(); - - (swarm, peer_store_handle) + let local_key = identity::Keypair::generate_ed25519(); + let local_peer_id = PeerId::from(local_key.public()); + let peer_store = PeerStore::new(vec![]); + + let (tx, rx) = tracing_unbounded("channel", 10_000); + let proto_set_config = ProtoSetConfig { + in_peers, + out_peers, + reserved_nodes: HashSet::new(), + reserved_only: false, + }; + + let (handle, controller) = ProtocolController::new( + SetId::from(0usize), + proto_set_config, + tx.clone(), + Box::new(peer_store.handle()), + ); + let peer_store_handle = peer_store.handle(); + tokio::spawn(controller.run()); + tokio::spawn(peer_store.run()); + + let proto_config = ProtocolConfig { + name: ProtocolName::from("/notif/1"), + fallback_names: vec![], + handshake: vec![1, 3, 3, 7], + max_notification_size: 1000u64, + }; + let behaviour = Notifications::new(vec![handle], rx, vec![proto_config].into_iter()); + let transport = libp2p::tokio_development_transport(local_key).unwrap(); + let mut swarm = SwarmBuilder::with_tokio_executor(transport, behaviour, local_peer_id).build(); + + swarm.listen_on("/ip6/::1/tcp/0".parse().unwrap()).unwrap(); + + (swarm, peer_store_handle) } async fn initialize_litep2p() -> (Litep2p, NotificationHandle) { - let (notif_config1, handle) = NotificationConfig::new( - Litep2pProtocol::from("/notif/1"), - 1024usize, - vec![1, 3, 3, 8], - Vec::new(), - ); - let config1 = ConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_tcp(TcpConfig { - listen_address: "/ip6/::1/tcp/0".parse().unwrap(), - yamux_config: Default::default(), - }) - .with_notification_protocol(notif_config1) - .build(); - let litep2p = Litep2p::new(config1).await.unwrap(); - - (litep2p, handle) + let (notif_config1, handle) = NotificationConfig::new( + Litep2pProtocol::from("/notif/1"), + 1024usize, + vec![1, 3, 3, 8], + Vec::new(), + ); + let config1 = ConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_tcp(TcpConfig { + listen_address: "/ip6/::1/tcp/0".parse().unwrap(), + yamux_config: Default::default(), + }) + .with_notification_protocol(notif_config1) + .build(); + let litep2p = Litep2p::new(config1).await.unwrap(); + + (litep2p, handle) } #[tokio::test] async fn substrate_keep_alive_timeout() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (mut libp2p, _peer_store_handle) = initialize_libp2p(1u32, 1u32); - let (mut litep2p, mut handle) = initialize_litep2p().await; - - let address = litep2p.listen_addresses().next().unwrap().clone(); - libp2p.dial(address).unwrap(); - - let mut libp2p_connection_open = false; - let mut libp2p_connection_closed = false; - let mut litep2p_connection_open = false; - let mut litep2p_connection_closed = false; - - while !libp2p_connection_open || - !libp2p_connection_closed || - !litep2p_connection_open || - !litep2p_connection_closed - { - tokio::select! { - event = libp2p.select_next_some() => match event { - SwarmEvent::ConnectionEstablished { .. } => { - libp2p_connection_open = true; - } - SwarmEvent::ConnectionClosed { .. } => { - libp2p_connection_closed = true; - } - event => tracing::info!("unhanled libp2p event: {event:?}"), - }, - event = litep2p.next_event() => match event.unwrap() { - Litep2pEvent::ConnectionEstablished { .. } => { - litep2p_connection_open = true; - } - Litep2pEvent::ConnectionClosed { .. } => { - litep2p_connection_closed = true; - } - _ => {} - }, - event = handle.next() => match event.unwrap() { - event => tracing::debug!("unhanled notification event: {event:?}"), - } - } - } + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (mut libp2p, _peer_store_handle) = initialize_libp2p(1u32, 1u32); + let (mut litep2p, mut handle) = initialize_litep2p().await; + + let address = litep2p.listen_addresses().next().unwrap().clone(); + libp2p.dial(address).unwrap(); + + let mut libp2p_connection_open = false; + let mut libp2p_connection_closed = false; + let mut litep2p_connection_open = false; + let mut litep2p_connection_closed = false; + + while !libp2p_connection_open + || !libp2p_connection_closed + || !litep2p_connection_open + || !litep2p_connection_closed + { + tokio::select! { + event = libp2p.select_next_some() => match event { + SwarmEvent::ConnectionEstablished { .. } => { + libp2p_connection_open = true; + } + SwarmEvent::ConnectionClosed { .. } => { + libp2p_connection_closed = true; + } + event => tracing::info!("unhanled libp2p event: {event:?}"), + }, + event = litep2p.next_event() => match event.unwrap() { + Litep2pEvent::ConnectionEstablished { .. } => { + litep2p_connection_open = true; + } + Litep2pEvent::ConnectionClosed { .. } => { + litep2p_connection_closed = true; + } + _ => {} + }, + event = handle.next() => match event.unwrap() { + event => tracing::debug!("unhanled notification event: {event:?}"), + } + } + } } diff --git a/tests/conformance/substrate/notifications.rs b/tests/conformance/substrate/notifications.rs index 10b7dfd0..9d5df646 100644 --- a/tests/conformance/substrate/notifications.rs +++ b/tests/conformance/substrate/notifications.rs @@ -19,325 +19,325 @@ // DEALINGS IN THE SOFTWARE. use litep2p::{ - config::ConfigBuilder, - crypto::ed25519::Keypair, - protocol::notification::{ - handle::NotificationHandle, - types::{ - Config as NotificationConfig, NotificationError, NotificationEvent, ValidationResult, - }, - }, - transport::tcp::config::Config as TcpConfig, - types::protocol::ProtocolName as Litep2pProtocol, - Litep2p, Litep2pEvent, PeerId as Litep2pPeerId, + config::ConfigBuilder, + crypto::ed25519::Keypair, + protocol::notification::{ + handle::NotificationHandle, + types::{ + Config as NotificationConfig, NotificationError, NotificationEvent, ValidationResult, + }, + }, + transport::tcp::config::Config as TcpConfig, + types::protocol::ProtocolName as Litep2pProtocol, + Litep2p, Litep2pEvent, PeerId as Litep2pPeerId, }; use futures::StreamExt; use libp2p::{ - identity, - swarm::{SwarmBuilder, SwarmEvent}, - PeerId, Swarm, + identity, + swarm::{SwarmBuilder, SwarmEvent}, + PeerId, Swarm, }; use sc_network::{ - peer_store::{PeerStore, PeerStoreHandle, PeerStoreProvider}, - protocol::notifications::behaviour::{Notifications, NotificationsOut, ProtocolConfig}, - protocol_controller::{ProtoSetConfig, ProtocolController, SetId}, - types::ProtocolName, - ReputationChange, + peer_store::{PeerStore, PeerStoreHandle, PeerStoreProvider}, + protocol::notifications::behaviour::{Notifications, NotificationsOut, ProtocolConfig}, + protocol_controller::{ProtoSetConfig, ProtocolController, SetId}, + types::ProtocolName, + ReputationChange, }; use sc_utils::mpsc::tracing_unbounded; use std::collections::HashSet; fn initialize_libp2p(in_peers: u32, out_peers: u32) -> (Swarm, PeerStoreHandle) { - let local_key = identity::Keypair::generate_ed25519(); - let local_peer_id = PeerId::from(local_key.public()); - let peer_store = PeerStore::new(vec![]); - - let (tx, rx) = tracing_unbounded("channel", 10_000); - let proto_set_config = ProtoSetConfig { - in_peers, - out_peers, - reserved_nodes: HashSet::new(), - reserved_only: false, - }; - - let (handle, controller) = ProtocolController::new( - SetId::from(0usize), - proto_set_config, - tx.clone(), - Box::new(peer_store.handle()), - ); - let peer_store_handle = peer_store.handle(); - tokio::spawn(controller.run()); - tokio::spawn(peer_store.run()); - - let proto_config = ProtocolConfig { - name: ProtocolName::from("/notif/1"), - fallback_names: vec![], - handshake: vec![1, 3, 3, 7], - max_notification_size: 1000u64, - }; - let behaviour = Notifications::new(vec![handle], rx, vec![proto_config].into_iter()); - let transport = libp2p::tokio_development_transport(local_key).unwrap(); - let mut swarm = SwarmBuilder::with_tokio_executor(transport, behaviour, local_peer_id).build(); - - swarm.listen_on("/ip6/::1/tcp/0".parse().unwrap()).unwrap(); - - (swarm, peer_store_handle) + let local_key = identity::Keypair::generate_ed25519(); + let local_peer_id = PeerId::from(local_key.public()); + let peer_store = PeerStore::new(vec![]); + + let (tx, rx) = tracing_unbounded("channel", 10_000); + let proto_set_config = ProtoSetConfig { + in_peers, + out_peers, + reserved_nodes: HashSet::new(), + reserved_only: false, + }; + + let (handle, controller) = ProtocolController::new( + SetId::from(0usize), + proto_set_config, + tx.clone(), + Box::new(peer_store.handle()), + ); + let peer_store_handle = peer_store.handle(); + tokio::spawn(controller.run()); + tokio::spawn(peer_store.run()); + + let proto_config = ProtocolConfig { + name: ProtocolName::from("/notif/1"), + fallback_names: vec![], + handshake: vec![1, 3, 3, 7], + max_notification_size: 1000u64, + }; + let behaviour = Notifications::new(vec![handle], rx, vec![proto_config].into_iter()); + let transport = libp2p::tokio_development_transport(local_key).unwrap(); + let mut swarm = SwarmBuilder::with_tokio_executor(transport, behaviour, local_peer_id).build(); + + swarm.listen_on("/ip6/::1/tcp/0".parse().unwrap()).unwrap(); + + (swarm, peer_store_handle) } async fn initialize_litep2p() -> (Litep2p, NotificationHandle) { - let (notif_config1, handle) = NotificationConfig::new( - Litep2pProtocol::from("/notif/1"), - 1024usize, - vec![1, 3, 3, 8], - Vec::new(), - ); - let config1 = ConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_tcp(TcpConfig { - listen_address: "/ip6/::1/tcp/0".parse().unwrap(), - yamux_config: Default::default(), - }) - .with_notification_protocol(notif_config1) - .build(); - let litep2p = Litep2p::new(config1).await.unwrap(); - - (litep2p, handle) + let (notif_config1, handle) = NotificationConfig::new( + Litep2pProtocol::from("/notif/1"), + 1024usize, + vec![1, 3, 3, 8], + Vec::new(), + ); + let config1 = ConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_tcp(TcpConfig { + listen_address: "/ip6/::1/tcp/0".parse().unwrap(), + yamux_config: Default::default(), + }) + .with_notification_protocol(notif_config1) + .build(); + let litep2p = Litep2p::new(config1).await.unwrap(); + + (litep2p, handle) } #[tokio::test] async fn substrate_open_substream() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (mut libp2p, mut peer_store_handle) = initialize_libp2p(1u32, 1u32); - let (mut litep2p, mut handle) = initialize_litep2p().await; - - let libp2p_peer = *libp2p.local_peer_id(); - let litep2p_peer = *litep2p.local_peer_id(); - - let address = litep2p.listen_addresses().next().unwrap().clone(); - libp2p.dial(address).unwrap(); - - let mut libp2p_ready = false; - let mut litep2p_ready = false; - let mut litep2p_3333_seen = false; - let mut litep2p_4444_seen = false; - let mut libp2p_1111_seen = false; - let mut libp2p_2222_seen = false; - - while !libp2p_ready || - !litep2p_ready || - !litep2p_3333_seen || - !litep2p_4444_seen || - !libp2p_1111_seen || - !libp2p_2222_seen - { - tokio::select! { - event = libp2p.select_next_some() => match event { - SwarmEvent::ConnectionEstablished { .. } => { - peer_store_handle.add_known_peer(PeerId::from_bytes(&litep2p_peer.to_bytes()).unwrap()); - } - SwarmEvent::Behaviour(NotificationsOut::CustomProtocolOpen { - peer_id, set_id, negotiated_fallback, received_handshake, notifications_sink, inbound, - }) => { - assert_eq!(peer_id.to_bytes(), litep2p_peer.to_bytes()); - assert_eq!(set_id, SetId::from(0usize)); - assert_eq!(received_handshake, vec![1, 3, 3, 8]); - assert!(negotiated_fallback.is_none()); - assert!(!inbound); - - notifications_sink.reserve_notification().await.unwrap().send(vec![3, 3, 3, 3]).unwrap(); - notifications_sink.send_sync_notification(vec![4, 4, 4, 4]); - - libp2p_ready = true; - } - SwarmEvent::Behaviour(NotificationsOut::Notification { peer_id, set_id, message }) => { - assert_eq!(peer_id.to_bytes(), litep2p_peer.to_bytes()); - assert_eq!(set_id, SetId::from(0usize)); - - if message == vec![1, 1, 1, 1] { - libp2p_1111_seen = true; - } else if message == vec![2, 2, 2, 2] { - libp2p_2222_seen = true; - } - } - event => tracing::info!("unhanled libp2p event: {event:?}"), - }, - event = litep2p.next_event() => match event { - event => tracing::info!("unhanled litep2p event: {event:?}"), - }, - event = handle.next() => match event.unwrap() { - NotificationEvent::ValidateSubstream { protocol, peer, handshake } => { - assert_eq!(protocol, Litep2pProtocol::from("/notif/1")); - assert_eq!(peer.to_bytes(), libp2p_peer.to_bytes()); - assert_eq!(handshake, vec![1, 3, 3, 7]); - - handle.send_validation_result(peer, ValidationResult::Accept).await; - litep2p_ready = true; - } - NotificationEvent::NotificationStreamOpened { protocol, peer, handshake } => { - assert_eq!(protocol, Litep2pProtocol::from("/notif/1")); - assert_eq!(peer.to_bytes(), libp2p_peer.to_bytes()); - assert_eq!(handshake, vec![1, 3, 3, 7]); - - handle.send_sync_notification(peer, vec![1, 1, 1, 1]).unwrap(); - handle.send_async_notification(peer, vec![2, 2, 2, 2]).await.unwrap(); - } - NotificationEvent::NotificationReceived { peer, notification } => { - assert_eq!(peer.to_bytes(), libp2p_peer.to_bytes()); - - if notification == vec![3, 3, 3, 3] { - litep2p_3333_seen = true; - } else if notification == vec![4, 4, 4, 4] { - litep2p_4444_seen = true; - } - } - event => tracing::error!("unhanled notification event: {event:?}"), - } - } - } + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (mut libp2p, mut peer_store_handle) = initialize_libp2p(1u32, 1u32); + let (mut litep2p, mut handle) = initialize_litep2p().await; + + let libp2p_peer = *libp2p.local_peer_id(); + let litep2p_peer = *litep2p.local_peer_id(); + + let address = litep2p.listen_addresses().next().unwrap().clone(); + libp2p.dial(address).unwrap(); + + let mut libp2p_ready = false; + let mut litep2p_ready = false; + let mut litep2p_3333_seen = false; + let mut litep2p_4444_seen = false; + let mut libp2p_1111_seen = false; + let mut libp2p_2222_seen = false; + + while !libp2p_ready + || !litep2p_ready + || !litep2p_3333_seen + || !litep2p_4444_seen + || !libp2p_1111_seen + || !libp2p_2222_seen + { + tokio::select! { + event = libp2p.select_next_some() => match event { + SwarmEvent::ConnectionEstablished { .. } => { + peer_store_handle.add_known_peer(PeerId::from_bytes(&litep2p_peer.to_bytes()).unwrap()); + } + SwarmEvent::Behaviour(NotificationsOut::CustomProtocolOpen { + peer_id, set_id, negotiated_fallback, received_handshake, notifications_sink, inbound, + }) => { + assert_eq!(peer_id.to_bytes(), litep2p_peer.to_bytes()); + assert_eq!(set_id, SetId::from(0usize)); + assert_eq!(received_handshake, vec![1, 3, 3, 8]); + assert!(negotiated_fallback.is_none()); + assert!(!inbound); + + notifications_sink.reserve_notification().await.unwrap().send(vec![3, 3, 3, 3]).unwrap(); + notifications_sink.send_sync_notification(vec![4, 4, 4, 4]); + + libp2p_ready = true; + } + SwarmEvent::Behaviour(NotificationsOut::Notification { peer_id, set_id, message }) => { + assert_eq!(peer_id.to_bytes(), litep2p_peer.to_bytes()); + assert_eq!(set_id, SetId::from(0usize)); + + if message == vec![1, 1, 1, 1] { + libp2p_1111_seen = true; + } else if message == vec![2, 2, 2, 2] { + libp2p_2222_seen = true; + } + } + event => tracing::info!("unhanled libp2p event: {event:?}"), + }, + event = litep2p.next_event() => match event { + event => tracing::info!("unhanled litep2p event: {event:?}"), + }, + event = handle.next() => match event.unwrap() { + NotificationEvent::ValidateSubstream { protocol, peer, handshake } => { + assert_eq!(protocol, Litep2pProtocol::from("/notif/1")); + assert_eq!(peer.to_bytes(), libp2p_peer.to_bytes()); + assert_eq!(handshake, vec![1, 3, 3, 7]); + + handle.send_validation_result(peer, ValidationResult::Accept).await; + litep2p_ready = true; + } + NotificationEvent::NotificationStreamOpened { protocol, peer, handshake } => { + assert_eq!(protocol, Litep2pProtocol::from("/notif/1")); + assert_eq!(peer.to_bytes(), libp2p_peer.to_bytes()); + assert_eq!(handshake, vec![1, 3, 3, 7]); + + handle.send_sync_notification(peer, vec![1, 1, 1, 1]).unwrap(); + handle.send_async_notification(peer, vec![2, 2, 2, 2]).await.unwrap(); + } + NotificationEvent::NotificationReceived { peer, notification } => { + assert_eq!(peer.to_bytes(), libp2p_peer.to_bytes()); + + if notification == vec![3, 3, 3, 3] { + litep2p_3333_seen = true; + } else if notification == vec![4, 4, 4, 4] { + litep2p_4444_seen = true; + } + } + event => tracing::error!("unhanled notification event: {event:?}"), + } + } + } } #[tokio::test] async fn litep2p_open_substream() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (mut libp2p, _peer_store_handle) = initialize_libp2p(1u32, 1u32); - let (mut litep2p, mut handle) = initialize_litep2p().await; - - let libp2p_peer = *libp2p.local_peer_id(); - let litep2p_peer = *litep2p.local_peer_id(); - - let address = litep2p.listen_addresses().next().unwrap().clone(); - libp2p.dial(address).unwrap(); - - let mut libp2p_ready = false; - let mut litep2p_ready = false; - let mut litep2p_3333_seen = false; - let mut litep2p_4444_seen = false; - let mut libp2p_1111_seen = false; - let mut libp2p_2222_seen = false; - - while !libp2p_ready || - !litep2p_ready || - !litep2p_3333_seen || - !litep2p_4444_seen || - !libp2p_1111_seen || - !libp2p_2222_seen - { - tokio::select! { - event = libp2p.select_next_some() => match event { - SwarmEvent::Behaviour(NotificationsOut::CustomProtocolOpen { - peer_id, set_id, negotiated_fallback, received_handshake, notifications_sink, inbound, - }) => { - assert_eq!(peer_id.to_bytes(), litep2p_peer.to_bytes()); - assert_eq!(set_id, SetId::from(0usize)); - assert_eq!(received_handshake, vec![1, 3, 3, 8]); - assert!(negotiated_fallback.is_none()); - assert!(inbound); - - notifications_sink.reserve_notification().await.unwrap().send(vec![3, 3, 3, 3]).unwrap(); - notifications_sink.send_sync_notification(vec![4, 4, 4, 4]); - - libp2p_ready = true; - } - SwarmEvent::Behaviour(NotificationsOut::Notification { peer_id, set_id, message }) => { - assert_eq!(peer_id.to_bytes(), litep2p_peer.to_bytes()); - assert_eq!(set_id, SetId::from(0usize)); - - if message == vec![1, 1, 1, 1] { - libp2p_1111_seen = true; - } else if message == vec![2, 2, 2, 2] { - libp2p_2222_seen = true; - } - } - event => tracing::info!("unhanled libp2p event: {event:?}"), - }, - event = litep2p.next_event() => match event.unwrap() { - Litep2pEvent::ConnectionEstablished { peer, .. } => { - // TODO: zzz - tokio::time::sleep(std::time::Duration::from_millis(200)).await; - handle.open_substream(peer).await.unwrap(); - } - event => tracing::info!("unhanled litep2p event: {event:?}"), - }, - event = handle.next() => match event.unwrap() { - NotificationEvent::ValidateSubstream { protocol, peer, handshake } => { - assert_eq!(protocol, Litep2pProtocol::from("/notif/1")); - assert_eq!(peer.to_bytes(), libp2p_peer.to_bytes()); - assert_eq!(handshake, vec![1, 3, 3, 7]); - - handle.send_validation_result(peer, ValidationResult::Accept).await; - litep2p_ready = true; - } - NotificationEvent::NotificationStreamOpened { protocol, peer, handshake } => { - assert_eq!(protocol, Litep2pProtocol::from("/notif/1")); - assert_eq!(peer.to_bytes(), libp2p_peer.to_bytes()); - assert_eq!(handshake, vec![1, 3, 3, 7]); - - handle.send_sync_notification(peer, vec![1, 1, 1, 1]).unwrap(); - handle.send_async_notification(peer, vec![2, 2, 2, 2]).await.unwrap(); - } - NotificationEvent::NotificationReceived { peer, notification } => { - assert_eq!(peer.to_bytes(), libp2p_peer.to_bytes()); - - if notification == vec![3, 3, 3, 3] { - litep2p_3333_seen = true; - } else if notification == vec![4, 4, 4, 4] { - litep2p_4444_seen = true; - } - } - event => tracing::error!("unhanled notification event: {event:?}"), - } - } - } + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (mut libp2p, _peer_store_handle) = initialize_libp2p(1u32, 1u32); + let (mut litep2p, mut handle) = initialize_litep2p().await; + + let libp2p_peer = *libp2p.local_peer_id(); + let litep2p_peer = *litep2p.local_peer_id(); + + let address = litep2p.listen_addresses().next().unwrap().clone(); + libp2p.dial(address).unwrap(); + + let mut libp2p_ready = false; + let mut litep2p_ready = false; + let mut litep2p_3333_seen = false; + let mut litep2p_4444_seen = false; + let mut libp2p_1111_seen = false; + let mut libp2p_2222_seen = false; + + while !libp2p_ready + || !litep2p_ready + || !litep2p_3333_seen + || !litep2p_4444_seen + || !libp2p_1111_seen + || !libp2p_2222_seen + { + tokio::select! { + event = libp2p.select_next_some() => match event { + SwarmEvent::Behaviour(NotificationsOut::CustomProtocolOpen { + peer_id, set_id, negotiated_fallback, received_handshake, notifications_sink, inbound, + }) => { + assert_eq!(peer_id.to_bytes(), litep2p_peer.to_bytes()); + assert_eq!(set_id, SetId::from(0usize)); + assert_eq!(received_handshake, vec![1, 3, 3, 8]); + assert!(negotiated_fallback.is_none()); + assert!(inbound); + + notifications_sink.reserve_notification().await.unwrap().send(vec![3, 3, 3, 3]).unwrap(); + notifications_sink.send_sync_notification(vec![4, 4, 4, 4]); + + libp2p_ready = true; + } + SwarmEvent::Behaviour(NotificationsOut::Notification { peer_id, set_id, message }) => { + assert_eq!(peer_id.to_bytes(), litep2p_peer.to_bytes()); + assert_eq!(set_id, SetId::from(0usize)); + + if message == vec![1, 1, 1, 1] { + libp2p_1111_seen = true; + } else if message == vec![2, 2, 2, 2] { + libp2p_2222_seen = true; + } + } + event => tracing::info!("unhanled libp2p event: {event:?}"), + }, + event = litep2p.next_event() => match event.unwrap() { + Litep2pEvent::ConnectionEstablished { peer, .. } => { + // TODO: zzz + tokio::time::sleep(std::time::Duration::from_millis(200)).await; + handle.open_substream(peer).await.unwrap(); + } + event => tracing::info!("unhanled litep2p event: {event:?}"), + }, + event = handle.next() => match event.unwrap() { + NotificationEvent::ValidateSubstream { protocol, peer, handshake } => { + assert_eq!(protocol, Litep2pProtocol::from("/notif/1")); + assert_eq!(peer.to_bytes(), libp2p_peer.to_bytes()); + assert_eq!(handshake, vec![1, 3, 3, 7]); + + handle.send_validation_result(peer, ValidationResult::Accept).await; + litep2p_ready = true; + } + NotificationEvent::NotificationStreamOpened { protocol, peer, handshake } => { + assert_eq!(protocol, Litep2pProtocol::from("/notif/1")); + assert_eq!(peer.to_bytes(), libp2p_peer.to_bytes()); + assert_eq!(handshake, vec![1, 3, 3, 7]); + + handle.send_sync_notification(peer, vec![1, 1, 1, 1]).unwrap(); + handle.send_async_notification(peer, vec![2, 2, 2, 2]).await.unwrap(); + } + NotificationEvent::NotificationReceived { peer, notification } => { + assert_eq!(peer.to_bytes(), libp2p_peer.to_bytes()); + + if notification == vec![3, 3, 3, 3] { + litep2p_3333_seen = true; + } else if notification == vec![4, 4, 4, 4] { + litep2p_4444_seen = true; + } + } + event => tracing::error!("unhanled notification event: {event:?}"), + } + } + } } #[tokio::test] async fn substrate_reject_substream() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - // set inbound peer to count 0 so `ProtocolController` will reject the peer - let (mut libp2p, _peer_store_handle) = initialize_libp2p(0u32, 1u32); - let (mut litep2p, mut handle) = initialize_litep2p().await; - - let libp2p_peer = *libp2p.local_peer_id(); - - let address = litep2p.listen_addresses().next().unwrap().clone(); - libp2p.dial(address).unwrap(); - - loop { - tokio::select! { - event = libp2p.select_next_some() => match event { - event => tracing::info!("unhanled libp2p event: {event:?}"), - }, - event = litep2p.next_event() => match event.unwrap() { - Litep2pEvent::ConnectionEstablished { peer, .. } => { - // TODO: zzz - tokio::time::sleep(std::time::Duration::from_millis(200)).await; - handle.open_substream(peer).await.unwrap(); - } - event => tracing::info!("unhanled litep2p event: {event:?}"), - }, - event = handle.next() => match event.unwrap() { - NotificationEvent::NotificationStreamOpenFailure { peer, error } => { - assert_eq!(peer.to_bytes(), libp2p_peer.to_bytes()); - assert_eq!(error, NotificationError::Rejected); - break; - } - NotificationEvent::NotificationStreamClosed { .. } => break, - event => tracing::error!("unhanled notification event: {event:?}"), - } - } - } + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + // set inbound peer to count 0 so `ProtocolController` will reject the peer + let (mut libp2p, _peer_store_handle) = initialize_libp2p(0u32, 1u32); + let (mut litep2p, mut handle) = initialize_litep2p().await; + + let libp2p_peer = *libp2p.local_peer_id(); + + let address = litep2p.listen_addresses().next().unwrap().clone(); + libp2p.dial(address).unwrap(); + + loop { + tokio::select! { + event = libp2p.select_next_some() => match event { + event => tracing::info!("unhanled libp2p event: {event:?}"), + }, + event = litep2p.next_event() => match event.unwrap() { + Litep2pEvent::ConnectionEstablished { peer, .. } => { + // TODO: zzz + tokio::time::sleep(std::time::Duration::from_millis(200)).await; + handle.open_substream(peer).await.unwrap(); + } + event => tracing::info!("unhanled litep2p event: {event:?}"), + }, + event = handle.next() => match event.unwrap() { + NotificationEvent::NotificationStreamOpenFailure { peer, error } => { + assert_eq!(peer.to_bytes(), libp2p_peer.to_bytes()); + assert_eq!(error, NotificationError::Rejected); + break; + } + NotificationEvent::NotificationStreamClosed { .. } => break, + event => tracing::error!("unhanled notification event: {event:?}"), + } + } + } } // NOTE: there is a known bug in Substrate where `ProtocolController` opens a connection to the peer @@ -350,45 +350,45 @@ async fn substrate_reject_substream() { #[tokio::test] #[ignore] async fn litep2p_reject_substream() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (mut libp2p, mut peer_store_handle) = initialize_libp2p(1u32, 1u32); - let (mut litep2p, mut handle) = initialize_litep2p().await; - - let libp2p_peer = *libp2p.local_peer_id(); - let litep2p_peer = *litep2p.local_peer_id(); - - let address = litep2p.listen_addresses().next().unwrap().clone(); - libp2p.dial(address).unwrap(); - - loop { - tokio::select! { - event = libp2p.select_next_some() => match event { - SwarmEvent::ConnectionEstablished { .. } => { - peer_store_handle.add_known_peer(PeerId::from_bytes(&litep2p_peer.to_bytes()).unwrap()); - } - event => tracing::info!("unhanled libp2p event: {event:?}"), - }, - event = litep2p.next_event() => match event { - Some(Litep2pEvent::ConnectionClosed { .. }) => break, - event => tracing::info!("unhanled litep2p event: {event:?}"), - }, - event = handle.next() => match event.unwrap() { - NotificationEvent::ValidateSubstream { protocol, peer, handshake } => { - assert_eq!(protocol, Litep2pProtocol::from("/notif/1")); - assert_eq!(peer.to_bytes(), libp2p_peer.to_bytes()); - assert_eq!(handshake, vec![1, 3, 3, 7]); - - handle.send_validation_result(peer, ValidationResult::Reject).await; - - tracing::info!("reject substream"); - } - event => tracing::error!("unhanled notification event: {event:?}"), - } - } - } + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (mut libp2p, mut peer_store_handle) = initialize_libp2p(1u32, 1u32); + let (mut litep2p, mut handle) = initialize_litep2p().await; + + let libp2p_peer = *libp2p.local_peer_id(); + let litep2p_peer = *litep2p.local_peer_id(); + + let address = litep2p.listen_addresses().next().unwrap().clone(); + libp2p.dial(address).unwrap(); + + loop { + tokio::select! { + event = libp2p.select_next_some() => match event { + SwarmEvent::ConnectionEstablished { .. } => { + peer_store_handle.add_known_peer(PeerId::from_bytes(&litep2p_peer.to_bytes()).unwrap()); + } + event => tracing::info!("unhanled libp2p event: {event:?}"), + }, + event = litep2p.next_event() => match event { + Some(Litep2pEvent::ConnectionClosed { .. }) => break, + event => tracing::info!("unhanled litep2p event: {event:?}"), + }, + event = handle.next() => match event.unwrap() { + NotificationEvent::ValidateSubstream { protocol, peer, handshake } => { + assert_eq!(protocol, Litep2pProtocol::from("/notif/1")); + assert_eq!(peer.to_bytes(), libp2p_peer.to_bytes()); + assert_eq!(handshake, vec![1, 3, 3, 7]); + + handle.send_validation_result(peer, ValidationResult::Reject).await; + + tracing::info!("reject substream"); + } + event => tracing::error!("unhanled notification event: {event:?}"), + } + } + } } // NOTE: there is a known bug in Substrate where `ProtocolController` opens a connection to the peer @@ -402,88 +402,88 @@ async fn litep2p_reject_substream() { #[tokio::test] #[ignore] async fn substrate_close_substream() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (mut libp2p, mut peer_store_handle) = initialize_libp2p(1u32, 1u32); - let (mut litep2p, mut handle) = initialize_litep2p().await; - - let libp2p_peer = *libp2p.local_peer_id(); - let litep2p_peer = *litep2p.local_peer_id(); - - let address = litep2p.listen_addresses().next().unwrap().clone(); - libp2p.dial(address).unwrap(); - - let mut libp2p_notification_count = 0; - - loop { - tokio::select! { - event = libp2p.select_next_some() => match event { - SwarmEvent::ConnectionEstablished { .. } => { - peer_store_handle.add_known_peer(PeerId::from_bytes(&litep2p_peer.to_bytes()).unwrap()); - } - SwarmEvent::Behaviour(NotificationsOut::CustomProtocolOpen { - peer_id, set_id, negotiated_fallback, received_handshake, notifications_sink, inbound, - }) => { - assert_eq!(peer_id.to_bytes(), litep2p_peer.to_bytes()); - assert_eq!(set_id, SetId::from(0usize)); - assert_eq!(received_handshake, vec![1, 3, 3, 8]); - assert!(negotiated_fallback.is_none()); - assert!(!inbound); - - notifications_sink.reserve_notification().await.unwrap().send(vec![3, 3, 3, 3]).unwrap(); - notifications_sink.send_sync_notification(vec![4, 4, 4, 4]); - } - SwarmEvent::Behaviour(NotificationsOut::Notification { peer_id, set_id, .. }) => { - assert_eq!(peer_id.to_bytes(), litep2p_peer.to_bytes()); - assert_eq!(set_id, SetId::from(0usize)); - - if libp2p_notification_count == 0 { - libp2p_notification_count += 1; - } else { - libp2p_notification_count += 1; - libp2p.behaviour_mut().disconnect_peer(&peer_id, set_id); - } - } - SwarmEvent::Behaviour(NotificationsOut::CustomProtocolClosed { .. }) => { - handle.send_sync_notification( - Litep2pPeerId::from_bytes(&libp2p_peer.to_bytes()).unwrap(), - vec![1 ,2 , 3, 4] - ).unwrap(); - } - event => tracing::info!("unhanled libp2p event: {event:?}"), - }, - event = litep2p.next_event() => match event.unwrap() { - event => tracing::info!("unhanled litep2p event: {event:?}"), - }, - event = handle.next() => match event.unwrap() { - NotificationEvent::ValidateSubstream { protocol, peer, handshake } => { - assert_eq!(protocol, Litep2pProtocol::from("/notif/1")); - assert_eq!(peer.to_bytes(), libp2p_peer.to_bytes()); - assert_eq!(handshake, vec![1, 3, 3, 7]); - - handle.send_validation_result(peer, ValidationResult::Accept).await; - } - NotificationEvent::NotificationStreamOpened { protocol, peer, handshake } => { - assert_eq!(protocol, Litep2pProtocol::from("/notif/1")); - assert_eq!(peer.to_bytes(), libp2p_peer.to_bytes()); - assert_eq!(handshake, vec![1, 3, 3, 7]); - - handle.send_sync_notification(peer, vec![1, 1, 1, 1]).unwrap(); - handle.send_async_notification(peer, vec![2, 2, 2, 2]).await.unwrap(); - } - NotificationEvent::NotificationReceived { peer, .. } => { - assert_eq!(peer.to_bytes(), libp2p_peer.to_bytes()); - } - NotificationEvent::NotificationStreamClosed { peer } => { - assert_eq!(peer.to_bytes(), libp2p_peer.to_bytes()); - break; - } - event => tracing::error!("unhanled notification event: {event:?}"), - } - } - } + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (mut libp2p, mut peer_store_handle) = initialize_libp2p(1u32, 1u32); + let (mut litep2p, mut handle) = initialize_litep2p().await; + + let libp2p_peer = *libp2p.local_peer_id(); + let litep2p_peer = *litep2p.local_peer_id(); + + let address = litep2p.listen_addresses().next().unwrap().clone(); + libp2p.dial(address).unwrap(); + + let mut libp2p_notification_count = 0; + + loop { + tokio::select! { + event = libp2p.select_next_some() => match event { + SwarmEvent::ConnectionEstablished { .. } => { + peer_store_handle.add_known_peer(PeerId::from_bytes(&litep2p_peer.to_bytes()).unwrap()); + } + SwarmEvent::Behaviour(NotificationsOut::CustomProtocolOpen { + peer_id, set_id, negotiated_fallback, received_handshake, notifications_sink, inbound, + }) => { + assert_eq!(peer_id.to_bytes(), litep2p_peer.to_bytes()); + assert_eq!(set_id, SetId::from(0usize)); + assert_eq!(received_handshake, vec![1, 3, 3, 8]); + assert!(negotiated_fallback.is_none()); + assert!(!inbound); + + notifications_sink.reserve_notification().await.unwrap().send(vec![3, 3, 3, 3]).unwrap(); + notifications_sink.send_sync_notification(vec![4, 4, 4, 4]); + } + SwarmEvent::Behaviour(NotificationsOut::Notification { peer_id, set_id, .. }) => { + assert_eq!(peer_id.to_bytes(), litep2p_peer.to_bytes()); + assert_eq!(set_id, SetId::from(0usize)); + + if libp2p_notification_count == 0 { + libp2p_notification_count += 1; + } else { + libp2p_notification_count += 1; + libp2p.behaviour_mut().disconnect_peer(&peer_id, set_id); + } + } + SwarmEvent::Behaviour(NotificationsOut::CustomProtocolClosed { .. }) => { + handle.send_sync_notification( + Litep2pPeerId::from_bytes(&libp2p_peer.to_bytes()).unwrap(), + vec![1 ,2 , 3, 4] + ).unwrap(); + } + event => tracing::info!("unhanled libp2p event: {event:?}"), + }, + event = litep2p.next_event() => match event.unwrap() { + event => tracing::info!("unhanled litep2p event: {event:?}"), + }, + event = handle.next() => match event.unwrap() { + NotificationEvent::ValidateSubstream { protocol, peer, handshake } => { + assert_eq!(protocol, Litep2pProtocol::from("/notif/1")); + assert_eq!(peer.to_bytes(), libp2p_peer.to_bytes()); + assert_eq!(handshake, vec![1, 3, 3, 7]); + + handle.send_validation_result(peer, ValidationResult::Accept).await; + } + NotificationEvent::NotificationStreamOpened { protocol, peer, handshake } => { + assert_eq!(protocol, Litep2pProtocol::from("/notif/1")); + assert_eq!(peer.to_bytes(), libp2p_peer.to_bytes()); + assert_eq!(handshake, vec![1, 3, 3, 7]); + + handle.send_sync_notification(peer, vec![1, 1, 1, 1]).unwrap(); + handle.send_async_notification(peer, vec![2, 2, 2, 2]).await.unwrap(); + } + NotificationEvent::NotificationReceived { peer, .. } => { + assert_eq!(peer.to_bytes(), libp2p_peer.to_bytes()); + } + NotificationEvent::NotificationStreamClosed { peer } => { + assert_eq!(peer.to_bytes(), libp2p_peer.to_bytes()); + break; + } + event => tracing::error!("unhanled notification event: {event:?}"), + } + } + } } // NOTE: Substrate doesn't consider the inbound substream closed as error which would disconnect @@ -494,143 +494,143 @@ async fn substrate_close_substream() { // only when the protocol tries to write something to the substream. #[tokio::test] async fn litep2p_close_substream() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (mut libp2p, mut peer_store_handle) = initialize_libp2p(1u32, 1u32); - let (mut litep2p, mut handle) = initialize_litep2p().await; - - let libp2p_peer = *libp2p.local_peer_id(); - let litep2p_peer = *litep2p.local_peer_id(); - - let address = litep2p.listen_addresses().next().unwrap().clone(); - libp2p.dial(address).unwrap(); - - let mut notif_count = 0; - let mut peerse = None; - - loop { - tokio::select! { - event = libp2p.select_next_some() => match event { - SwarmEvent::ConnectionEstablished { .. } => { - peer_store_handle.add_known_peer(PeerId::from_bytes(&litep2p_peer.to_bytes()).unwrap()); - } - SwarmEvent::Behaviour(NotificationsOut::CustomProtocolOpen { - peer_id, set_id, negotiated_fallback, received_handshake, inbound, .. - }) => { - assert_eq!(peer_id.to_bytes(), litep2p_peer.to_bytes()); - assert_eq!(set_id, SetId::from(0usize)); - assert_eq!(received_handshake, vec![1, 3, 3, 8]); - assert!(negotiated_fallback.is_none()); - assert!(!inbound); - } - SwarmEvent::Behaviour(NotificationsOut::Notification { peer_id, set_id, .. }) => { - assert_eq!(peer_id.to_bytes(), litep2p_peer.to_bytes()); - assert_eq!(set_id, SetId::from(0usize)); - - notif_count += 1; - if notif_count == 2 { - handle.close_substream(peerse.unwrap()).await; - } - } - SwarmEvent::Behaviour(NotificationsOut::CustomProtocolClosed { .. }) => { - break; - } - event => tracing::info!("unhanled libp2p event: {event:?}"), - }, - event = litep2p.next_event() => match event.unwrap() { - event => tracing::info!("unhanled litep2p event: {event:?}"), - }, - event = handle.next() => match event.unwrap() { - NotificationEvent::ValidateSubstream { protocol, peer, handshake } => { - assert_eq!(protocol, Litep2pProtocol::from("/notif/1")); - assert_eq!(peer.to_bytes(), libp2p_peer.to_bytes()); - assert_eq!(handshake, vec![1, 3, 3, 7]); - - handle.send_validation_result(peer, ValidationResult::Accept).await; - } - NotificationEvent::NotificationStreamOpened { protocol, peer, handshake } => { - assert_eq!(protocol, Litep2pProtocol::from("/notif/1")); - assert_eq!(peer.to_bytes(), libp2p_peer.to_bytes()); - assert_eq!(handshake, vec![1, 3, 3, 7]); - - handle.send_sync_notification(peer, vec![1, 1, 1, 1]).unwrap(); - handle.send_async_notification(peer, vec![2, 2, 2, 2]).await.unwrap(); - peerse = Some(peer); - } - NotificationEvent::NotificationReceived { peer, .. } => { - assert_eq!(peer.to_bytes(), libp2p_peer.to_bytes()); - notif_count += 1; - } - NotificationEvent::NotificationStreamClosed { .. } => { - break; - } - event => tracing::error!("unhanled notification event: {event:?}"), - }, - } - } + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (mut libp2p, mut peer_store_handle) = initialize_libp2p(1u32, 1u32); + let (mut litep2p, mut handle) = initialize_litep2p().await; + + let libp2p_peer = *libp2p.local_peer_id(); + let litep2p_peer = *litep2p.local_peer_id(); + + let address = litep2p.listen_addresses().next().unwrap().clone(); + libp2p.dial(address).unwrap(); + + let mut notif_count = 0; + let mut peerse = None; + + loop { + tokio::select! { + event = libp2p.select_next_some() => match event { + SwarmEvent::ConnectionEstablished { .. } => { + peer_store_handle.add_known_peer(PeerId::from_bytes(&litep2p_peer.to_bytes()).unwrap()); + } + SwarmEvent::Behaviour(NotificationsOut::CustomProtocolOpen { + peer_id, set_id, negotiated_fallback, received_handshake, inbound, .. + }) => { + assert_eq!(peer_id.to_bytes(), litep2p_peer.to_bytes()); + assert_eq!(set_id, SetId::from(0usize)); + assert_eq!(received_handshake, vec![1, 3, 3, 8]); + assert!(negotiated_fallback.is_none()); + assert!(!inbound); + } + SwarmEvent::Behaviour(NotificationsOut::Notification { peer_id, set_id, .. }) => { + assert_eq!(peer_id.to_bytes(), litep2p_peer.to_bytes()); + assert_eq!(set_id, SetId::from(0usize)); + + notif_count += 1; + if notif_count == 2 { + handle.close_substream(peerse.unwrap()).await; + } + } + SwarmEvent::Behaviour(NotificationsOut::CustomProtocolClosed { .. }) => { + break; + } + event => tracing::info!("unhanled libp2p event: {event:?}"), + }, + event = litep2p.next_event() => match event.unwrap() { + event => tracing::info!("unhanled litep2p event: {event:?}"), + }, + event = handle.next() => match event.unwrap() { + NotificationEvent::ValidateSubstream { protocol, peer, handshake } => { + assert_eq!(protocol, Litep2pProtocol::from("/notif/1")); + assert_eq!(peer.to_bytes(), libp2p_peer.to_bytes()); + assert_eq!(handshake, vec![1, 3, 3, 7]); + + handle.send_validation_result(peer, ValidationResult::Accept).await; + } + NotificationEvent::NotificationStreamOpened { protocol, peer, handshake } => { + assert_eq!(protocol, Litep2pProtocol::from("/notif/1")); + assert_eq!(peer.to_bytes(), libp2p_peer.to_bytes()); + assert_eq!(handshake, vec![1, 3, 3, 7]); + + handle.send_sync_notification(peer, vec![1, 1, 1, 1]).unwrap(); + handle.send_async_notification(peer, vec![2, 2, 2, 2]).await.unwrap(); + peerse = Some(peer); + } + NotificationEvent::NotificationReceived { peer, .. } => { + assert_eq!(peer.to_bytes(), libp2p_peer.to_bytes()); + notif_count += 1; + } + NotificationEvent::NotificationStreamClosed { .. } => { + break; + } + event => tracing::error!("unhanled notification event: {event:?}"), + }, + } + } } #[tokio::test] async fn both_nodes_open_substreams() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (mut libp2p, mut peer_store_handle) = initialize_libp2p(1u32, 1u32); - let (mut litep2p, mut handle) = initialize_litep2p().await; - - let libp2p_peer = *libp2p.local_peer_id(); - let litep2p_peer = *litep2p.local_peer_id(); - - let address = litep2p.listen_addresses().next().unwrap().clone(); - libp2p.dial(address).unwrap(); - - let mut libp2p_ready = false; - let mut litep2p_ready = false; - - while !litep2p_ready || !libp2p_ready { - tokio::select! { - event = libp2p.select_next_some() => match event { - SwarmEvent::ConnectionEstablished { .. } => { - peer_store_handle.add_known_peer(PeerId::from_bytes(&litep2p_peer.to_bytes()).unwrap()); - tokio::time::sleep(std::time::Duration::from_millis(200)).await; - handle.open_substream(Litep2pPeerId::from_bytes(&libp2p_peer.to_bytes()).unwrap()).await.unwrap(); - } - SwarmEvent::Behaviour(NotificationsOut::CustomProtocolOpen { - peer_id, set_id, negotiated_fallback, received_handshake, .. - }) => { - assert_eq!(peer_id.to_bytes(), litep2p_peer.to_bytes()); - assert_eq!(set_id, SetId::from(0usize)); - assert_eq!(received_handshake, vec![1, 3, 3, 8]); - assert!(negotiated_fallback.is_none()); - - libp2p_ready = true; - } - event => tracing::info!("unhanled libp2p event: {event:?}"), - }, - event = litep2p.next_event() => match event.unwrap() { - event => tracing::info!("unhanled litep2p event: {event:?}"), - }, - event = handle.next() => match event.unwrap() { - NotificationEvent::ValidateSubstream { protocol, peer, handshake } => { - assert_eq!(protocol, Litep2pProtocol::from("/notif/1")); - assert_eq!(peer.to_bytes(), libp2p_peer.to_bytes()); - assert_eq!(handshake, vec![1, 3, 3, 7]); - - handle.send_validation_result(peer, ValidationResult::Accept).await; - litep2p_ready = true; - } - NotificationEvent::NotificationStreamOpened { protocol, peer, handshake } => { - assert_eq!(protocol, Litep2pProtocol::from("/notif/1")); - assert_eq!(peer.to_bytes(), libp2p_peer.to_bytes()); - assert_eq!(handshake, vec![1, 3, 3, 7]); - - litep2p_ready = true; - } - event => tracing::error!("unhanled notification event: {event:?}"), - } - } - } + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (mut libp2p, mut peer_store_handle) = initialize_libp2p(1u32, 1u32); + let (mut litep2p, mut handle) = initialize_litep2p().await; + + let libp2p_peer = *libp2p.local_peer_id(); + let litep2p_peer = *litep2p.local_peer_id(); + + let address = litep2p.listen_addresses().next().unwrap().clone(); + libp2p.dial(address).unwrap(); + + let mut libp2p_ready = false; + let mut litep2p_ready = false; + + while !litep2p_ready || !libp2p_ready { + tokio::select! { + event = libp2p.select_next_some() => match event { + SwarmEvent::ConnectionEstablished { .. } => { + peer_store_handle.add_known_peer(PeerId::from_bytes(&litep2p_peer.to_bytes()).unwrap()); + tokio::time::sleep(std::time::Duration::from_millis(200)).await; + handle.open_substream(Litep2pPeerId::from_bytes(&libp2p_peer.to_bytes()).unwrap()).await.unwrap(); + } + SwarmEvent::Behaviour(NotificationsOut::CustomProtocolOpen { + peer_id, set_id, negotiated_fallback, received_handshake, .. + }) => { + assert_eq!(peer_id.to_bytes(), litep2p_peer.to_bytes()); + assert_eq!(set_id, SetId::from(0usize)); + assert_eq!(received_handshake, vec![1, 3, 3, 8]); + assert!(negotiated_fallback.is_none()); + + libp2p_ready = true; + } + event => tracing::info!("unhanled libp2p event: {event:?}"), + }, + event = litep2p.next_event() => match event.unwrap() { + event => tracing::info!("unhanled litep2p event: {event:?}"), + }, + event = handle.next() => match event.unwrap() { + NotificationEvent::ValidateSubstream { protocol, peer, handshake } => { + assert_eq!(protocol, Litep2pProtocol::from("/notif/1")); + assert_eq!(peer.to_bytes(), libp2p_peer.to_bytes()); + assert_eq!(handshake, vec![1, 3, 3, 7]); + + handle.send_validation_result(peer, ValidationResult::Accept).await; + litep2p_ready = true; + } + NotificationEvent::NotificationStreamOpened { protocol, peer, handshake } => { + assert_eq!(protocol, Litep2pProtocol::from("/notif/1")); + assert_eq!(peer.to_bytes(), libp2p_peer.to_bytes()); + assert_eq!(handshake, vec![1, 3, 3, 7]); + + litep2p_ready = true; + } + event => tracing::error!("unhanled notification event: {event:?}"), + } + } + } } diff --git a/tests/conformance/substrate/request_response.rs b/tests/conformance/substrate/request_response.rs index a0d12f94..437dd49d 100644 --- a/tests/conformance/substrate/request_response.rs +++ b/tests/conformance/substrate/request_response.rs @@ -20,412 +20,415 @@ use futures::{channel::oneshot, stream::FuturesUnordered, StreamExt}; use libp2p::{ - identity, - swarm::{SwarmBuilder, SwarmEvent}, - PeerId, Swarm, + identity, + swarm::{SwarmBuilder, SwarmEvent}, + PeerId, Swarm, }; use litep2p::{ - config::ConfigBuilder, - protocol::request_response::{ - RequestResponseConfig, RequestResponseError, RequestResponseEvent, RequestResponseHandle, - }, - transport::tcp::config::Config as TcpConfig, - Litep2p, Litep2pEvent, + config::ConfigBuilder, + protocol::request_response::{ + RequestResponseConfig, RequestResponseError, RequestResponseEvent, RequestResponseHandle, + }, + transport::tcp::config::Config as TcpConfig, + Litep2p, Litep2pEvent, }; use sc_network::{ - peer_store::PeerStore, - request_responses::{ - IncomingRequest, OutgoingResponse, ProtocolConfig, RequestFailure, - RequestResponsesBehaviour, - }, - types::ProtocolName, - IfDisconnected, OutboundFailure, + peer_store::PeerStore, + request_responses::{ + IncomingRequest, OutgoingResponse, ProtocolConfig, RequestFailure, + RequestResponsesBehaviour, + }, + types::ProtocolName, + IfDisconnected, OutboundFailure, }; -fn initialize_libp2p( -) -> (Swarm, PeerStore, async_channel::Receiver) { - let local_key = identity::Keypair::generate_ed25519(); - let local_peer_id = PeerId::from(local_key.public()); - let peer_store = PeerStore::new(vec![]); - let peer_store_handle = Box::new(peer_store.handle()); - - let (tx, rx) = async_channel::bounded(64); - let configs = vec![ProtocolConfig { - name: ProtocolName::from("/request/1"), - fallback_names: Vec::new(), - max_request_size: 256, - max_response_size: 2 * 256, - request_timeout: std::time::Duration::from_secs(10), - inbound_queue: Some(tx), - }]; - - let behaviour = RequestResponsesBehaviour::new(configs.into_iter(), peer_store_handle).unwrap(); - - let transport = libp2p::tokio_development_transport(local_key).unwrap(); - let mut swarm = SwarmBuilder::with_tokio_executor(transport, behaviour, local_peer_id).build(); - - swarm.listen_on("/ip6/::1/tcp/0".parse().unwrap()).unwrap(); - - (swarm, peer_store, rx) +fn initialize_libp2p() -> ( + Swarm, + PeerStore, + async_channel::Receiver, +) { + let local_key = identity::Keypair::generate_ed25519(); + let local_peer_id = PeerId::from(local_key.public()); + let peer_store = PeerStore::new(vec![]); + let peer_store_handle = Box::new(peer_store.handle()); + + let (tx, rx) = async_channel::bounded(64); + let configs = vec![ProtocolConfig { + name: ProtocolName::from("/request/1"), + fallback_names: Vec::new(), + max_request_size: 256, + max_response_size: 2 * 256, + request_timeout: std::time::Duration::from_secs(10), + inbound_queue: Some(tx), + }]; + + let behaviour = RequestResponsesBehaviour::new(configs.into_iter(), peer_store_handle).unwrap(); + + let transport = libp2p::tokio_development_transport(local_key).unwrap(); + let mut swarm = SwarmBuilder::with_tokio_executor(transport, behaviour, local_peer_id).build(); + + swarm.listen_on("/ip6/::1/tcp/0".parse().unwrap()).unwrap(); + + (swarm, peer_store, rx) } async fn initialize_litep2p() -> (Litep2p, RequestResponseHandle) { - let (config, handle) = RequestResponseConfig::new( - litep2p::types::protocol::ProtocolName::from("/request/1"), - 2 * 256, - ); - - let litep2p = Litep2p::new( - ConfigBuilder::new() - .with_tcp(TcpConfig { - listen_address: "/ip6/::1/tcp/0".parse().unwrap(), - yamux_config: Default::default(), - }) - .with_request_response_protocol(config) - .build(), - ) - .await - .unwrap(); - - (litep2p, handle) + let (config, handle) = RequestResponseConfig::new( + litep2p::types::protocol::ProtocolName::from("/request/1"), + 2 * 256, + ); + + let litep2p = Litep2p::new( + ConfigBuilder::new() + .with_tcp(TcpConfig { + listen_address: "/ip6/::1/tcp/0".parse().unwrap(), + yamux_config: Default::default(), + }) + .with_request_response_protocol(config) + .build(), + ) + .await + .unwrap(); + + (litep2p, handle) } #[tokio::test] async fn request_works() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (mut libp2p, peer_store, requests) = initialize_libp2p(); - let (mut litep2p, mut handle) = initialize_litep2p().await; - let address = litep2p.listen_addresses().next().unwrap().clone(); - let litep2p_peer = *litep2p.local_peer_id(); - let libp2p_peer = *libp2p.local_peer_id(); - let mut pending_responses = FuturesUnordered::new(); - - tokio::spawn(peer_store.run()); - libp2p.dial(address).unwrap(); - - loop { - tokio::select! { - event = libp2p.select_next_some() => { - match event { - SwarmEvent::NewListenAddr { address, .. } => { - tracing::info!("Listening on {address:?}") - } - event => { - tracing::info!("libp2p: received {event:?}"); - } - } - } - event = litep2p.next_event() => match event.unwrap() { - Litep2pEvent::ConnectionEstablished { peer, .. } => { - // TODO: zzz - tokio::time::sleep(std::time::Duration::from_millis(200)).await; - handle.send_request(peer, vec![0, 1, 2, 3, 4]).await.unwrap(); - } - event => tracing::info!("litep2p: received {event:?}"), - }, - event = handle.next() => match event.unwrap() { - RequestResponseEvent::ResponseReceived { - peer, - request_id, - response, - } => { - assert_eq!(peer.to_bytes(), libp2p_peer.to_bytes()); - assert_eq!(response, vec![5, 6, 7, 8, 9]); - assert_eq!(request_id, 0usize); - } - RequestResponseEvent::RequestReceived { - peer, - request_id, - request - } => { - assert_eq!(peer.to_bytes(), libp2p_peer.to_bytes()); - assert_eq!(request, vec![1, 3, 3, 7]); - handle.send_response(request_id, vec![1, 3, 3, 8]).await.unwrap(); - } - event => tracing::warn!("unhandle event: {event:?}"), - }, - request = requests.recv() => match request { - Ok(request) => { - request.pending_response.send(OutgoingResponse { - result: Ok(vec![5, 6, 7, 8, 9]), - reputation_changes: Vec::new(), - sent_feedback: None - }).unwrap(); - - let (tx, rx) = oneshot::channel(); - pending_responses.push(rx); - - libp2p.behaviour_mut().send_request( - &libp2p::PeerId::from_bytes(&litep2p_peer.to_bytes()).unwrap(), - "/request/1", - vec![1, 3, 3, 7], - tx, - IfDisconnected::ImmediateError, - ); - } - Err(error) => { - tracing::error!("failed to read reqeuest: {error:?}") - } - }, - event = pending_responses.select_next_some(), if !pending_responses.is_empty() => { - match event { - Ok(response) => { - assert_eq!(response.unwrap(), vec![1, 3, 3, 8]); - break - } - Err(error) => panic!("failed to receive response from peer: {error:?}"), - } - } - _ = tokio::time::sleep(std::time::Duration::from_secs(30)) => { - panic!("failed to receive request in time"); - } - } - } + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (mut libp2p, peer_store, requests) = initialize_libp2p(); + let (mut litep2p, mut handle) = initialize_litep2p().await; + let address = litep2p.listen_addresses().next().unwrap().clone(); + let litep2p_peer = *litep2p.local_peer_id(); + let libp2p_peer = *libp2p.local_peer_id(); + let mut pending_responses = FuturesUnordered::new(); + + tokio::spawn(peer_store.run()); + libp2p.dial(address).unwrap(); + + loop { + tokio::select! { + event = libp2p.select_next_some() => { + match event { + SwarmEvent::NewListenAddr { address, .. } => { + tracing::info!("Listening on {address:?}") + } + event => { + tracing::info!("libp2p: received {event:?}"); + } + } + } + event = litep2p.next_event() => match event.unwrap() { + Litep2pEvent::ConnectionEstablished { peer, .. } => { + // TODO: zzz + tokio::time::sleep(std::time::Duration::from_millis(200)).await; + handle.send_request(peer, vec![0, 1, 2, 3, 4]).await.unwrap(); + } + event => tracing::info!("litep2p: received {event:?}"), + }, + event = handle.next() => match event.unwrap() { + RequestResponseEvent::ResponseReceived { + peer, + request_id, + response, + } => { + assert_eq!(peer.to_bytes(), libp2p_peer.to_bytes()); + assert_eq!(response, vec![5, 6, 7, 8, 9]); + assert_eq!(request_id, 0usize); + } + RequestResponseEvent::RequestReceived { + peer, + request_id, + request + } => { + assert_eq!(peer.to_bytes(), libp2p_peer.to_bytes()); + assert_eq!(request, vec![1, 3, 3, 7]); + handle.send_response(request_id, vec![1, 3, 3, 8]).await.unwrap(); + } + event => tracing::warn!("unhandle event: {event:?}"), + }, + request = requests.recv() => match request { + Ok(request) => { + request.pending_response.send(OutgoingResponse { + result: Ok(vec![5, 6, 7, 8, 9]), + reputation_changes: Vec::new(), + sent_feedback: None + }).unwrap(); + + let (tx, rx) = oneshot::channel(); + pending_responses.push(rx); + + libp2p.behaviour_mut().send_request( + &libp2p::PeerId::from_bytes(&litep2p_peer.to_bytes()).unwrap(), + "/request/1", + vec![1, 3, 3, 7], + tx, + IfDisconnected::ImmediateError, + ); + } + Err(error) => { + tracing::error!("failed to read reqeuest: {error:?}") + } + }, + event = pending_responses.select_next_some(), if !pending_responses.is_empty() => { + match event { + Ok(response) => { + assert_eq!(response.unwrap(), vec![1, 3, 3, 8]); + break + } + Err(error) => panic!("failed to receive response from peer: {error:?}"), + } + } + _ = tokio::time::sleep(std::time::Duration::from_secs(30)) => { + panic!("failed to receive request in time"); + } + } + } } #[tokio::test] async fn substrate_reject_request() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (mut libp2p, peer_store, requests) = initialize_libp2p(); - let (mut litep2p, mut handle) = initialize_litep2p().await; - let address = litep2p.listen_addresses().next().unwrap().clone(); - let libp2p_peer = *libp2p.local_peer_id(); - - tokio::spawn(peer_store.run()); - libp2p.dial(address).unwrap(); - - loop { - tokio::select! { - event = libp2p.select_next_some() => { - match event { - SwarmEvent::NewListenAddr { address, .. } => { - tracing::info!("Listening on {address:?}") - } - event => { - tracing::info!("libp2p: received {event:?}"); - } - } - } - event = litep2p.next_event() => match event.unwrap() { - Litep2pEvent::ConnectionEstablished { peer, .. } => { - // TODO: zzz - tokio::time::sleep(std::time::Duration::from_millis(200)).await; - handle.send_request(peer, vec![0, 1, 2, 3, 4]).await.unwrap(); - } - event => tracing::info!("litep2p: received {event:?}"), - }, - event = handle.next() => match event.unwrap() { - RequestResponseEvent::RequestFailed { peer, error, .. } => { - assert_eq!(peer.to_bytes(), libp2p_peer.to_bytes()); - assert_eq!(error, RequestResponseError::Rejected); - break; - } - event => tracing::warn!("unhandle event: {event:?}"), - }, - request = requests.recv() => match request { - Ok(request) => { - drop(request); - } - Err(error) => { - tracing::error!("failed to read reqeuest: {error:?}") - } - }, - _ = tokio::time::sleep(std::time::Duration::from_secs(30)) => { - panic!("failed to receive request in time"); - } - } - } + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (mut libp2p, peer_store, requests) = initialize_libp2p(); + let (mut litep2p, mut handle) = initialize_litep2p().await; + let address = litep2p.listen_addresses().next().unwrap().clone(); + let libp2p_peer = *libp2p.local_peer_id(); + + tokio::spawn(peer_store.run()); + libp2p.dial(address).unwrap(); + + loop { + tokio::select! { + event = libp2p.select_next_some() => { + match event { + SwarmEvent::NewListenAddr { address, .. } => { + tracing::info!("Listening on {address:?}") + } + event => { + tracing::info!("libp2p: received {event:?}"); + } + } + } + event = litep2p.next_event() => match event.unwrap() { + Litep2pEvent::ConnectionEstablished { peer, .. } => { + // TODO: zzz + tokio::time::sleep(std::time::Duration::from_millis(200)).await; + handle.send_request(peer, vec![0, 1, 2, 3, 4]).await.unwrap(); + } + event => tracing::info!("litep2p: received {event:?}"), + }, + event = handle.next() => match event.unwrap() { + RequestResponseEvent::RequestFailed { peer, error, .. } => { + assert_eq!(peer.to_bytes(), libp2p_peer.to_bytes()); + assert_eq!(error, RequestResponseError::Rejected); + break; + } + event => tracing::warn!("unhandle event: {event:?}"), + }, + request = requests.recv() => match request { + Ok(request) => { + drop(request); + } + Err(error) => { + tracing::error!("failed to read reqeuest: {error:?}") + } + }, + _ = tokio::time::sleep(std::time::Duration::from_secs(30)) => { + panic!("failed to receive request in time"); + } + } + } } #[tokio::test] async fn litep2p_reject_request() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (mut libp2p, peer_store, _) = initialize_libp2p(); - let (mut litep2p, mut handle) = initialize_litep2p().await; - let address = litep2p.listen_addresses().next().unwrap().clone(); - let litep2p_peer = *litep2p.local_peer_id(); - let mut pending_responses = FuturesUnordered::new(); - - tokio::spawn(peer_store.run()); - libp2p.dial(address).unwrap(); - - loop { - tokio::select! { - event = libp2p.select_next_some() => { - match event { - SwarmEvent::NewListenAddr { address, .. } => { - tracing::info!("Listening on {address:?}") - } - SwarmEvent::ConnectionEstablished { .. } => { - let (tx, rx) = oneshot::channel(); - pending_responses.push(rx); - - libp2p.behaviour_mut().send_request( - &libp2p::PeerId::from_bytes(&litep2p_peer.to_bytes()).unwrap(), - "/request/1", - vec![1, 3, 3, 7], - tx, - IfDisconnected::ImmediateError, - ); - } - event => { - tracing::info!("libp2p: received {event:?}"); - } - } - } - event = litep2p.next_event() => match event.unwrap() { - event => tracing::info!("litep2p: received {event:?}"), - }, - event = handle.next() => match event.unwrap() { - RequestResponseEvent::RequestReceived { - request_id, - .. - } => { - handle.reject_request(request_id).await; - } - event => tracing::warn!("unhandle event: {event:?}"), - }, - event = pending_responses.select_next_some(), if !pending_responses.is_empty() => { - match event { - Ok(response) => { - assert!(std::matches!(response, Err(RequestFailure::Refused))); - break - } - Err(_) => panic!("failed to read response"), - } - } - _ = tokio::time::sleep(std::time::Duration::from_secs(30)) => { - panic!("failed to receive request in time"); - } - } - } + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (mut libp2p, peer_store, _) = initialize_libp2p(); + let (mut litep2p, mut handle) = initialize_litep2p().await; + let address = litep2p.listen_addresses().next().unwrap().clone(); + let litep2p_peer = *litep2p.local_peer_id(); + let mut pending_responses = FuturesUnordered::new(); + + tokio::spawn(peer_store.run()); + libp2p.dial(address).unwrap(); + + loop { + tokio::select! { + event = libp2p.select_next_some() => { + match event { + SwarmEvent::NewListenAddr { address, .. } => { + tracing::info!("Listening on {address:?}") + } + SwarmEvent::ConnectionEstablished { .. } => { + let (tx, rx) = oneshot::channel(); + pending_responses.push(rx); + + libp2p.behaviour_mut().send_request( + &libp2p::PeerId::from_bytes(&litep2p_peer.to_bytes()).unwrap(), + "/request/1", + vec![1, 3, 3, 7], + tx, + IfDisconnected::ImmediateError, + ); + } + event => { + tracing::info!("libp2p: received {event:?}"); + } + } + } + event = litep2p.next_event() => match event.unwrap() { + event => tracing::info!("litep2p: received {event:?}"), + }, + event = handle.next() => match event.unwrap() { + RequestResponseEvent::RequestReceived { + request_id, + .. + } => { + handle.reject_request(request_id).await; + } + event => tracing::warn!("unhandle event: {event:?}"), + }, + event = pending_responses.select_next_some(), if !pending_responses.is_empty() => { + match event { + Ok(response) => { + assert!(std::matches!(response, Err(RequestFailure::Refused))); + break + } + Err(_) => panic!("failed to read response"), + } + } + _ = tokio::time::sleep(std::time::Duration::from_secs(30)) => { + panic!("failed to receive request in time"); + } + } + } } #[tokio::test] async fn substrate_request_timeout() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (mut libp2p, peer_store, requests) = initialize_libp2p(); - let (mut litep2p, mut handle) = initialize_litep2p().await; - let address = litep2p.listen_addresses().next().unwrap().clone(); - let libp2p_peer = *libp2p.local_peer_id(); - let mut _timeout_request = None; - - tokio::spawn(peer_store.run()); - libp2p.dial(address).unwrap(); - - loop { - tokio::select! { - event = libp2p.select_next_some() => { - match event { - SwarmEvent::NewListenAddr { address, .. } => { - tracing::info!("Listening on {address:?}") - } - event => { - tracing::info!("libp2p: received {event:?}"); - } - } - } - event = litep2p.next_event() => match event.unwrap() { - Litep2pEvent::ConnectionEstablished { peer, .. } => { - // TODO: zzz - tokio::time::sleep(std::time::Duration::from_millis(200)).await; - handle.send_request(peer, vec![0, 1, 2, 3, 4]).await.unwrap(); - } - event => tracing::info!("litep2p: received {event:?}"), - }, - event = handle.next() => match event.unwrap() { - RequestResponseEvent::RequestFailed { peer, error, .. } => { - assert_eq!(peer.to_bytes(), libp2p_peer.to_bytes()); - assert_eq!(error, RequestResponseError::Timeout); - break; - } - event => tracing::warn!("unhandle event: {event:?}"), - }, - request = requests.recv() => match request { - Ok(request) => { - _timeout_request = Some(request); - } - Err(error) => { - tracing::error!("failed to read reqeuest: {error:?}") - } - }, - _ = tokio::time::sleep(std::time::Duration::from_secs(30)) => { - panic!("failed to receive request in time"); - } - } - } + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (mut libp2p, peer_store, requests) = initialize_libp2p(); + let (mut litep2p, mut handle) = initialize_litep2p().await; + let address = litep2p.listen_addresses().next().unwrap().clone(); + let libp2p_peer = *libp2p.local_peer_id(); + let mut _timeout_request = None; + + tokio::spawn(peer_store.run()); + libp2p.dial(address).unwrap(); + + loop { + tokio::select! { + event = libp2p.select_next_some() => { + match event { + SwarmEvent::NewListenAddr { address, .. } => { + tracing::info!("Listening on {address:?}") + } + event => { + tracing::info!("libp2p: received {event:?}"); + } + } + } + event = litep2p.next_event() => match event.unwrap() { + Litep2pEvent::ConnectionEstablished { peer, .. } => { + // TODO: zzz + tokio::time::sleep(std::time::Duration::from_millis(200)).await; + handle.send_request(peer, vec![0, 1, 2, 3, 4]).await.unwrap(); + } + event => tracing::info!("litep2p: received {event:?}"), + }, + event = handle.next() => match event.unwrap() { + RequestResponseEvent::RequestFailed { peer, error, .. } => { + assert_eq!(peer.to_bytes(), libp2p_peer.to_bytes()); + assert_eq!(error, RequestResponseError::Timeout); + break; + } + event => tracing::warn!("unhandle event: {event:?}"), + }, + request = requests.recv() => match request { + Ok(request) => { + _timeout_request = Some(request); + } + Err(error) => { + tracing::error!("failed to read reqeuest: {error:?}") + } + }, + _ = tokio::time::sleep(std::time::Duration::from_secs(30)) => { + panic!("failed to receive request in time"); + } + } + } } #[tokio::test] async fn litep2p_request_timeout() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (mut libp2p, peer_store, _) = initialize_libp2p(); - let (mut litep2p, mut handle) = initialize_litep2p().await; - let address = litep2p.listen_addresses().next().unwrap().clone(); - let litep2p_peer = *litep2p.local_peer_id(); - let mut pending_responses = FuturesUnordered::new(); - - tokio::spawn(peer_store.run()); - libp2p.dial(address).unwrap(); - - loop { - tokio::select! { - event = libp2p.select_next_some() => { - match event { - SwarmEvent::NewListenAddr { address, .. } => { - tracing::info!("Listening on {address:?}") - } - SwarmEvent::ConnectionEstablished { .. } => { - let (tx, rx) = oneshot::channel(); - pending_responses.push(rx); - - libp2p.behaviour_mut().send_request( - &libp2p::PeerId::from_bytes(&litep2p_peer.to_bytes()).unwrap(), - "/request/1", - vec![1, 3, 3, 7], - tx, - IfDisconnected::ImmediateError, - ); - } - event => { - tracing::info!("libp2p: received {event:?}"); - } - } - } - event = litep2p.next_event() => match event.unwrap() { - event => tracing::info!("litep2p: received {event:?}"), - }, - event = handle.next() => match event.unwrap() { - RequestResponseEvent::RequestReceived { .. } => {}, - event => tracing::warn!("unhandle event: {event:?}"), - }, - event = pending_responses.select_next_some(), if !pending_responses.is_empty() => { - match event { - Ok(response) => { - assert!(std::matches!(response, Err(RequestFailure::Network(OutboundFailure::Timeout)))); - break - } - Err(_) => panic!("failed to read response"), - } - } - _ = tokio::time::sleep(std::time::Duration::from_secs(30)) => { - panic!("failed to receive request in time"); - } - } - } + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (mut libp2p, peer_store, _) = initialize_libp2p(); + let (mut litep2p, mut handle) = initialize_litep2p().await; + let address = litep2p.listen_addresses().next().unwrap().clone(); + let litep2p_peer = *litep2p.local_peer_id(); + let mut pending_responses = FuturesUnordered::new(); + + tokio::spawn(peer_store.run()); + libp2p.dial(address).unwrap(); + + loop { + tokio::select! { + event = libp2p.select_next_some() => { + match event { + SwarmEvent::NewListenAddr { address, .. } => { + tracing::info!("Listening on {address:?}") + } + SwarmEvent::ConnectionEstablished { .. } => { + let (tx, rx) = oneshot::channel(); + pending_responses.push(rx); + + libp2p.behaviour_mut().send_request( + &libp2p::PeerId::from_bytes(&litep2p_peer.to_bytes()).unwrap(), + "/request/1", + vec![1, 3, 3, 7], + tx, + IfDisconnected::ImmediateError, + ); + } + event => { + tracing::info!("libp2p: received {event:?}"); + } + } + } + event = litep2p.next_event() => match event.unwrap() { + event => tracing::info!("litep2p: received {event:?}"), + }, + event = handle.next() => match event.unwrap() { + RequestResponseEvent::RequestReceived { .. } => {}, + event => tracing::warn!("unhandle event: {event:?}"), + }, + event = pending_responses.select_next_some(), if !pending_responses.is_empty() => { + match event { + Ok(response) => { + assert!(std::matches!(response, Err(RequestFailure::Network(OutboundFailure::Timeout)))); + break + } + Err(_) => panic!("failed to read response"), + } + } + _ = tokio::time::sleep(std::time::Duration::from_secs(30)) => { + panic!("failed to receive request in time"); + } + } + } } diff --git a/tests/connection/mod.rs b/tests/connection/mod.rs index 3e4a8043..6699706b 100644 --- a/tests/connection/mod.rs +++ b/tests/connection/mod.rs @@ -19,15 +19,15 @@ // DEALINGS IN THE SOFTWARE. use litep2p::{ - config::ConfigBuilder, - crypto::ed25519::Keypair, - error::{AddressError, Error}, - protocol::libp2p::ping::{Config as PingConfig, PingEvent}, - transport::{ - quic::config::Config as QuicConfig, tcp::config::Config as TcpConfig, - websocket::config::Config as WebSocketConfig, - }, - Litep2p, Litep2pEvent, PeerId, + config::ConfigBuilder, + crypto::ed25519::Keypair, + error::{AddressError, Error}, + protocol::libp2p::ping::{Config as PingConfig, PingEvent}, + transport::{ + quic::config::Config as QuicConfig, tcp::config::Config as TcpConfig, + websocket::config::Config as WebSocketConfig, + }, + Litep2p, Litep2pEvent, PeerId, }; use futures::{Stream, StreamExt}; @@ -40,1239 +40,1315 @@ use tokio::net::{TcpListener, UdpSocket}; mod protocol_dial_invalid_address; enum Transport { - Tcp(TcpConfig), - Quic(QuicConfig), - WebSocket(WebSocketConfig), + Tcp(TcpConfig), + Quic(QuicConfig), + WebSocket(WebSocketConfig), } #[tokio::test] async fn two_litep2ps_work_tcp() { - two_litep2ps_work( - Transport::Tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }), - Transport::Tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }), - ) - .await + two_litep2ps_work( + Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }), + Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }), + ) + .await } #[tokio::test] async fn two_litep2ps_work_quic() { - two_litep2ps_work(Transport::Quic(Default::default()), Transport::Quic(Default::default())) - .await; + two_litep2ps_work( + Transport::Quic(Default::default()), + Transport::Quic(Default::default()), + ) + .await; } #[tokio::test] async fn two_litep2ps_work_websocket() { - two_litep2ps_work( - Transport::WebSocket(WebSocketConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], - ..Default::default() - }), - Transport::WebSocket(WebSocketConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], - ..Default::default() - }), - ) - .await; + two_litep2ps_work( + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }), + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }), + ) + .await; } async fn two_litep2ps_work(transport1: Transport, transport2: Transport) { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (ping_config1, _ping_event_stream1) = PingConfig::default(); - let config1 = ConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_libp2p_ping(ping_config1); - - let config1 = match transport1 { - Transport::Tcp(config) => config1.with_tcp(config), - Transport::Quic(config) => config1.with_quic(config), - Transport::WebSocket(config) => config1.with_websocket(config), - } - .build(); - - let (ping_config2, _ping_event_stream2) = PingConfig::default(); - let config2 = ConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_libp2p_ping(ping_config2); - - let config2 = match transport2 { - Transport::Tcp(config) => config2.with_tcp(config), - Transport::Quic(config) => config2.with_quic(config), - Transport::WebSocket(config) => config2.with_websocket(config), - } - .build(); - - let mut litep2p1 = Litep2p::new(config1).unwrap(); - let mut litep2p2 = Litep2p::new(config2).unwrap(); - - let address = litep2p2.listen_addresses().next().unwrap().clone(); - litep2p1.dial_address(address).await.unwrap(); - - let (res1, res2) = tokio::join!(litep2p1.next_event(), litep2p2.next_event()); - - assert!(std::matches!(res1, Some(Litep2pEvent::ConnectionEstablished { .. }))); - assert!(std::matches!(res2, Some(Litep2pEvent::ConnectionEstablished { .. }))); + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (ping_config1, _ping_event_stream1) = PingConfig::default(); + let config1 = ConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_libp2p_ping(ping_config1); + + let config1 = match transport1 { + Transport::Tcp(config) => config1.with_tcp(config), + Transport::Quic(config) => config1.with_quic(config), + Transport::WebSocket(config) => config1.with_websocket(config), + } + .build(); + + let (ping_config2, _ping_event_stream2) = PingConfig::default(); + let config2 = ConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_libp2p_ping(ping_config2); + + let config2 = match transport2 { + Transport::Tcp(config) => config2.with_tcp(config), + Transport::Quic(config) => config2.with_quic(config), + Transport::WebSocket(config) => config2.with_websocket(config), + } + .build(); + + let mut litep2p1 = Litep2p::new(config1).unwrap(); + let mut litep2p2 = Litep2p::new(config2).unwrap(); + + let address = litep2p2.listen_addresses().next().unwrap().clone(); + litep2p1.dial_address(address).await.unwrap(); + + let (res1, res2) = tokio::join!(litep2p1.next_event(), litep2p2.next_event()); + + assert!(std::matches!( + res1, + Some(Litep2pEvent::ConnectionEstablished { .. }) + )); + assert!(std::matches!( + res2, + Some(Litep2pEvent::ConnectionEstablished { .. }) + )); } #[tokio::test] async fn dial_failure_tcp() { - dial_failure( - Transport::Tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }), - Transport::Tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }), - Multiaddr::empty() - .with(Protocol::Ip6(std::net::Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1))) - .with(Protocol::Tcp(1)), - ) - .await + dial_failure( + Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }), + Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }), + Multiaddr::empty() + .with(Protocol::Ip6(std::net::Ipv6Addr::new( + 0, 0, 0, 0, 0, 0, 0, 1, + ))) + .with(Protocol::Tcp(1)), + ) + .await } #[tokio::test] async fn dial_failure_quic() { - dial_failure( - Transport::Quic(Default::default()), - Transport::Quic(Default::default()), - Multiaddr::empty() - .with(Protocol::Ip6(std::net::Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1))) - .with(Protocol::Udp(1)) - .with(Protocol::QuicV1), - ) - .await; + dial_failure( + Transport::Quic(Default::default()), + Transport::Quic(Default::default()), + Multiaddr::empty() + .with(Protocol::Ip6(std::net::Ipv6Addr::new( + 0, 0, 0, 0, 0, 0, 0, 1, + ))) + .with(Protocol::Udp(1)) + .with(Protocol::QuicV1), + ) + .await; } #[tokio::test] async fn dial_failure_websocket() { - dial_failure( - Transport::WebSocket(WebSocketConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], - ..Default::default() - }), - Transport::WebSocket(WebSocketConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], - ..Default::default() - }), - Multiaddr::empty() - .with(Protocol::Ip6(std::net::Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1))) - .with(Protocol::Tcp(1)) - .with(Protocol::Ws(std::borrow::Cow::Owned("/".to_string()))), - ) - .await; + dial_failure( + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }), + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }), + Multiaddr::empty() + .with(Protocol::Ip6(std::net::Ipv6Addr::new( + 0, 0, 0, 0, 0, 0, 0, 1, + ))) + .with(Protocol::Tcp(1)) + .with(Protocol::Ws(std::borrow::Cow::Owned("/".to_string()))), + ) + .await; } async fn dial_failure(transport1: Transport, transport2: Transport, dial_address: Multiaddr) { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (ping_config1, _ping_event_stream1) = PingConfig::default(); - let config1 = ConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_libp2p_ping(ping_config1); - - let config1 = match transport1 { - Transport::Tcp(config) => config1.with_tcp(config), - Transport::Quic(config) => config1.with_quic(config), - Transport::WebSocket(config) => config1.with_websocket(config), - } - .build(); - - let (ping_config2, _ping_event_stream2) = PingConfig::default(); - let config2 = ConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_libp2p_ping(ping_config2); - - let config2 = match transport2 { - Transport::Tcp(config) => config2.with_tcp(config), - Transport::Quic(config) => config2.with_quic(config), - Transport::WebSocket(config) => config2.with_websocket(config), - } - .build(); - - let mut litep2p1 = Litep2p::new(config1).unwrap(); - let mut litep2p2 = Litep2p::new(config2).unwrap(); - - let address = dial_address - .with(Protocol::P2p(Multihash::from_bytes(&litep2p2.local_peer_id().to_bytes()).unwrap())); - - litep2p1.dial_address(address).await.unwrap(); - - tokio::spawn(async move { - loop { - let _ = litep2p2.next_event().await; - } - }); - - assert!(std::matches!(litep2p1.next_event().await, Some(Litep2pEvent::DialFailure { .. }))); + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (ping_config1, _ping_event_stream1) = PingConfig::default(); + let config1 = ConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_libp2p_ping(ping_config1); + + let config1 = match transport1 { + Transport::Tcp(config) => config1.with_tcp(config), + Transport::Quic(config) => config1.with_quic(config), + Transport::WebSocket(config) => config1.with_websocket(config), + } + .build(); + + let (ping_config2, _ping_event_stream2) = PingConfig::default(); + let config2 = ConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_libp2p_ping(ping_config2); + + let config2 = match transport2 { + Transport::Tcp(config) => config2.with_tcp(config), + Transport::Quic(config) => config2.with_quic(config), + Transport::WebSocket(config) => config2.with_websocket(config), + } + .build(); + + let mut litep2p1 = Litep2p::new(config1).unwrap(); + let mut litep2p2 = Litep2p::new(config2).unwrap(); + + let address = dial_address.with(Protocol::P2p( + Multihash::from_bytes(&litep2p2.local_peer_id().to_bytes()).unwrap(), + )); + + litep2p1.dial_address(address).await.unwrap(); + + tokio::spawn(async move { + loop { + let _ = litep2p2.next_event().await; + } + }); + + assert!(std::matches!( + litep2p1.next_event().await, + Some(Litep2pEvent::DialFailure { .. }) + )); } #[tokio::test] async fn connect_over_dns() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let keypair1 = Keypair::generate(); - let (ping_config1, _ping_event_stream1) = PingConfig::default(); - - let config1 = ConfigBuilder::new() - .with_keypair(keypair1) - .with_tcp(TcpConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0".parse().unwrap()], - ..Default::default() - }) - .with_libp2p_ping(ping_config1) - .build(); - - let keypair2 = Keypair::generate(); - let (ping_config2, _ping_event_stream2) = PingConfig::default(); - - let config2 = ConfigBuilder::new() - .with_keypair(keypair2) - .with_tcp(TcpConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0".parse().unwrap()], - ..Default::default() - }) - .with_libp2p_ping(ping_config2) - .build(); - - let mut litep2p1 = Litep2p::new(config1).unwrap(); - let mut litep2p2 = Litep2p::new(config2).unwrap(); - let peer2 = *litep2p2.local_peer_id(); - - let address = litep2p2.listen_addresses().next().unwrap().clone(); - let tcp = address.iter().skip(1).next().unwrap(); - - let mut new_address = Multiaddr::empty(); - new_address.push(Protocol::Dns("localhost".into())); - new_address.push(tcp); - new_address.push(Protocol::P2p(Multihash::from_bytes(&peer2.to_bytes()).unwrap())); - - litep2p1.dial_address(new_address).await.unwrap(); - let (res1, res2) = tokio::join!(litep2p1.next_event(), litep2p2.next_event()); - - assert!(std::matches!(res1, Some(Litep2pEvent::ConnectionEstablished { .. }))); - assert!(std::matches!(res2, Some(Litep2pEvent::ConnectionEstablished { .. }))); + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let keypair1 = Keypair::generate(); + let (ping_config1, _ping_event_stream1) = PingConfig::default(); + + let config1 = ConfigBuilder::new() + .with_keypair(keypair1) + .with_tcp(TcpConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0".parse().unwrap()], + ..Default::default() + }) + .with_libp2p_ping(ping_config1) + .build(); + + let keypair2 = Keypair::generate(); + let (ping_config2, _ping_event_stream2) = PingConfig::default(); + + let config2 = ConfigBuilder::new() + .with_keypair(keypair2) + .with_tcp(TcpConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0".parse().unwrap()], + ..Default::default() + }) + .with_libp2p_ping(ping_config2) + .build(); + + let mut litep2p1 = Litep2p::new(config1).unwrap(); + let mut litep2p2 = Litep2p::new(config2).unwrap(); + let peer2 = *litep2p2.local_peer_id(); + + let address = litep2p2.listen_addresses().next().unwrap().clone(); + let tcp = address.iter().skip(1).next().unwrap(); + + let mut new_address = Multiaddr::empty(); + new_address.push(Protocol::Dns("localhost".into())); + new_address.push(tcp); + new_address.push(Protocol::P2p( + Multihash::from_bytes(&peer2.to_bytes()).unwrap(), + )); + + litep2p1.dial_address(new_address).await.unwrap(); + let (res1, res2) = tokio::join!(litep2p1.next_event(), litep2p2.next_event()); + + assert!(std::matches!( + res1, + Some(Litep2pEvent::ConnectionEstablished { .. }) + )); + assert!(std::matches!( + res2, + Some(Litep2pEvent::ConnectionEstablished { .. }) + )); } #[tokio::test] async fn connection_timeout_tcp() { - // create tcp listener but don't accept any inbound connections - let listener = TcpListener::bind("[::1]:0").await.unwrap(); - let address = listener.local_addr().unwrap(); - let address = Multiaddr::empty() - .with(Protocol::from(address.ip())) - .with(Protocol::Tcp(address.port())) - .with(Protocol::P2p(Multihash::from_bytes(&PeerId::random().to_bytes()).unwrap())); - - connection_timeout( - Transport::Tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }), - address, - ) - .await + // create tcp listener but don't accept any inbound connections + let listener = TcpListener::bind("[::1]:0").await.unwrap(); + let address = listener.local_addr().unwrap(); + let address = Multiaddr::empty() + .with(Protocol::from(address.ip())) + .with(Protocol::Tcp(address.port())) + .with(Protocol::P2p( + Multihash::from_bytes(&PeerId::random().to_bytes()).unwrap(), + )); + + connection_timeout( + Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }), + address, + ) + .await } #[tokio::test] async fn connection_timeout_quic() { - // create udp socket but don't respond to any inbound datagrams - let listener = UdpSocket::bind("127.0.0.1:0").await.unwrap(); - let address = listener.local_addr().unwrap(); - let address = Multiaddr::empty() - .with(Protocol::from(address.ip())) - .with(Protocol::Udp(address.port())) - .with(Protocol::QuicV1) - .with(Protocol::P2p(Multihash::from_bytes(&PeerId::random().to_bytes()).unwrap())); - - connection_timeout(Transport::Quic(Default::default()), address).await; + // create udp socket but don't respond to any inbound datagrams + let listener = UdpSocket::bind("127.0.0.1:0").await.unwrap(); + let address = listener.local_addr().unwrap(); + let address = Multiaddr::empty() + .with(Protocol::from(address.ip())) + .with(Protocol::Udp(address.port())) + .with(Protocol::QuicV1) + .with(Protocol::P2p( + Multihash::from_bytes(&PeerId::random().to_bytes()).unwrap(), + )); + + connection_timeout(Transport::Quic(Default::default()), address).await; } #[tokio::test] async fn connection_timeout_websocket() { - // create tcp listener but don't accept any inbound connections - let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); - let address = listener.local_addr().unwrap(); - let address = Multiaddr::empty() - .with(Protocol::from(address.ip())) - .with(Protocol::Tcp(address.port())) - .with(Protocol::Ws(std::borrow::Cow::Owned("/".to_string()))) - .with(Protocol::P2p(Multihash::from_bytes(&PeerId::random().to_bytes()).unwrap())); - - connection_timeout( - Transport::WebSocket(WebSocketConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], - ..Default::default() - }), - address, - ) - .await; + // create tcp listener but don't accept any inbound connections + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let address = listener.local_addr().unwrap(); + let address = Multiaddr::empty() + .with(Protocol::from(address.ip())) + .with(Protocol::Tcp(address.port())) + .with(Protocol::Ws(std::borrow::Cow::Owned("/".to_string()))) + .with(Protocol::P2p( + Multihash::from_bytes(&PeerId::random().to_bytes()).unwrap(), + )); + + connection_timeout( + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }), + address, + ) + .await; } async fn connection_timeout(transport: Transport, address: Multiaddr) { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (ping_config, _ping_event_stream) = PingConfig::default(); - let litep2p_config = ConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_libp2p_ping(ping_config); - - let litep2p_config = match transport { - Transport::Tcp(config) => litep2p_config.with_tcp(config), - Transport::Quic(config) => litep2p_config.with_quic(config), - Transport::WebSocket(config) => litep2p_config.with_websocket(config), - } - .build(); - - let mut litep2p = Litep2p::new(litep2p_config).unwrap(); - - litep2p.dial_address(address.clone()).await.unwrap(); - - let Some(Litep2pEvent::DialFailure { address: dial_address, error }) = - litep2p.next_event().await - else { - panic!("invalid event received"); - }; - - assert_eq!(dial_address, address); - println!("{error:?}"); - assert!(std::matches!(error, Error::Timeout)); + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (ping_config, _ping_event_stream) = PingConfig::default(); + let litep2p_config = ConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_libp2p_ping(ping_config); + + let litep2p_config = match transport { + Transport::Tcp(config) => litep2p_config.with_tcp(config), + Transport::Quic(config) => litep2p_config.with_quic(config), + Transport::WebSocket(config) => litep2p_config.with_websocket(config), + } + .build(); + + let mut litep2p = Litep2p::new(litep2p_config).unwrap(); + + litep2p.dial_address(address.clone()).await.unwrap(); + + let Some(Litep2pEvent::DialFailure { + address: dial_address, + error, + }) = litep2p.next_event().await + else { + panic!("invalid event received"); + }; + + assert_eq!(dial_address, address); + println!("{error:?}"); + assert!(std::matches!(error, Error::Timeout)); } #[tokio::test] async fn dial_quic_peer_id_missing() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (ping_config, _ping_event_stream) = PingConfig::default(); - let config = ConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_quic(Default::default()) - .with_libp2p_ping(ping_config) - .build(); - - let mut litep2p = Litep2p::new(config).unwrap(); - - // create udp socket but don't respond to any inbound datagrams - let listener = UdpSocket::bind("127.0.0.1:0").await.unwrap(); - let address = listener.local_addr().unwrap(); - let address = Multiaddr::empty() - .with(Protocol::from(address.ip())) - .with(Protocol::Udp(address.port())) - .with(Protocol::QuicV1); - - match litep2p.dial_address(address.clone()).await { - Err(Error::AddressError(AddressError::PeerIdMissing)) => {}, - state => panic!("dial not supposed to succeed {state:?}"), - } + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (ping_config, _ping_event_stream) = PingConfig::default(); + let config = ConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_quic(Default::default()) + .with_libp2p_ping(ping_config) + .build(); + + let mut litep2p = Litep2p::new(config).unwrap(); + + // create udp socket but don't respond to any inbound datagrams + let listener = UdpSocket::bind("127.0.0.1:0").await.unwrap(); + let address = listener.local_addr().unwrap(); + let address = Multiaddr::empty() + .with(Protocol::from(address.ip())) + .with(Protocol::Udp(address.port())) + .with(Protocol::QuicV1); + + match litep2p.dial_address(address.clone()).await { + Err(Error::AddressError(AddressError::PeerIdMissing)) => {} + state => panic!("dial not supposed to succeed {state:?}"), + } } #[tokio::test] async fn dial_self_tcp() { - dial_self(Transport::Tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - })) - .await + dial_self(Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + })) + .await } #[tokio::test] async fn dial_self_quic() { - dial_self(Transport::Quic(Default::default())).await; + dial_self(Transport::Quic(Default::default())).await; } #[tokio::test] async fn dial_self_websocket() { - dial_self(Transport::WebSocket(WebSocketConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], - ..Default::default() - })) - .await; + dial_self(Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + })) + .await; } async fn dial_self(transport: Transport) { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (ping_config, _ping_event_stream) = PingConfig::default(); - let litep2p_config = ConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_libp2p_ping(ping_config); - - let litep2p_config = match transport { - Transport::Tcp(config) => litep2p_config.with_tcp(config), - Transport::Quic(config) => litep2p_config.with_quic(config), - Transport::WebSocket(config) => litep2p_config.with_websocket(config), - } - .build(); - - let mut litep2p = Litep2p::new(litep2p_config).unwrap(); - let address = litep2p.listen_addresses().next().unwrap().clone(); - - // dial without peer id attached - assert!(std::matches!( - litep2p.dial_address(address.clone()).await, - Err(Error::TriedToDialSelf) - )); + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (ping_config, _ping_event_stream) = PingConfig::default(); + let litep2p_config = ConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_libp2p_ping(ping_config); + + let litep2p_config = match transport { + Transport::Tcp(config) => litep2p_config.with_tcp(config), + Transport::Quic(config) => litep2p_config.with_quic(config), + Transport::WebSocket(config) => litep2p_config.with_websocket(config), + } + .build(); + + let mut litep2p = Litep2p::new(litep2p_config).unwrap(); + let address = litep2p.listen_addresses().next().unwrap().clone(); + + // dial without peer id attached + assert!(std::matches!( + litep2p.dial_address(address.clone()).await, + Err(Error::TriedToDialSelf) + )); } #[tokio::test] async fn attempt_to_dial_using_unsupported_transport() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (ping_config, _ping_event_stream) = PingConfig::default(); - let config = ConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_quic(Default::default()) - .with_libp2p_ping(ping_config) - .build(); - - let mut litep2p = Litep2p::new(config).unwrap(); - let address = Multiaddr::empty() - .with(Protocol::from(std::net::Ipv4Addr::new(127, 0, 0, 1))) - .with(Protocol::Tcp(8888)) - .with(Protocol::P2p(Multihash::from_bytes(&PeerId::random().to_bytes()).unwrap())); - - assert!(std::matches!( - litep2p.dial_address(address.clone()).await, - Err(Error::TransportNotSupported(_)) - )); + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (ping_config, _ping_event_stream) = PingConfig::default(); + let config = ConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_quic(Default::default()) + .with_libp2p_ping(ping_config) + .build(); + + let mut litep2p = Litep2p::new(config).unwrap(); + let address = Multiaddr::empty() + .with(Protocol::from(std::net::Ipv4Addr::new(127, 0, 0, 1))) + .with(Protocol::Tcp(8888)) + .with(Protocol::P2p( + Multihash::from_bytes(&PeerId::random().to_bytes()).unwrap(), + )); + + assert!(std::matches!( + litep2p.dial_address(address.clone()).await, + Err(Error::TransportNotSupported(_)) + )); } #[tokio::test] async fn keep_alive_timeout_tcp() { - keep_alive_timeout( - Transport::Tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }), - Transport::Tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }), - ) - .await + keep_alive_timeout( + Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }), + Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }), + ) + .await } #[tokio::test] async fn keep_alive_timeout_quic() { - keep_alive_timeout(Transport::Quic(Default::default()), Transport::Quic(Default::default())) - .await; + keep_alive_timeout( + Transport::Quic(Default::default()), + Transport::Quic(Default::default()), + ) + .await; } #[tokio::test] async fn keep_alive_timeout_websocket() { - keep_alive_timeout( - Transport::WebSocket(WebSocketConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], - ..Default::default() - }), - Transport::WebSocket(WebSocketConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], - ..Default::default() - }), - ) - .await; + keep_alive_timeout( + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }), + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }), + ) + .await; } async fn keep_alive_timeout(transport1: Transport, transport2: Transport) { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (ping_config1, mut ping_event_stream1) = PingConfig::default(); - let config1 = ConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_libp2p_ping(ping_config1); - - let config1 = match transport1 { - Transport::Tcp(config) => config1.with_tcp(config), - Transport::Quic(config) => config1.with_quic(config), - Transport::WebSocket(config) => config1.with_websocket(config), - } - .build(); - let mut litep2p1 = Litep2p::new(config1).unwrap(); - - let (ping_config2, mut ping_event_stream2) = PingConfig::default(); - let config2 = ConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_libp2p_ping(ping_config2); - - let config2 = match transport2 { - Transport::Tcp(config) => config2.with_tcp(config), - Transport::Quic(config) => config2.with_quic(config), - Transport::WebSocket(config) => config2.with_websocket(config), - } - .build(); - let mut litep2p2 = Litep2p::new(config2).unwrap(); - - let address1 = litep2p1.listen_addresses().next().unwrap().clone(); - litep2p2.dial_address(address1).await.unwrap(); - let mut litep2p1_ping = false; - let mut litep2p2_ping = false; - - loop { - tokio::select! { - event = litep2p1.next_event() => match event { - Some(Litep2pEvent::ConnectionClosed { .. }) if litep2p1_ping || litep2p2_ping => { - break; - } - _ => {} - }, - event = litep2p2.next_event() => match event { - Some(Litep2pEvent::ConnectionClosed { .. }) if litep2p1_ping || litep2p2_ping => { - break; - } - _ => {} - }, - _event = ping_event_stream1.next() => { - tracing::warn!("ping1 received"); - litep2p1_ping = true; - } - _event = ping_event_stream2.next() => { - tracing::warn!("ping2 received"); - litep2p2_ping = true; - } - } - } + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (ping_config1, mut ping_event_stream1) = PingConfig::default(); + let config1 = ConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_libp2p_ping(ping_config1); + + let config1 = match transport1 { + Transport::Tcp(config) => config1.with_tcp(config), + Transport::Quic(config) => config1.with_quic(config), + Transport::WebSocket(config) => config1.with_websocket(config), + } + .build(); + let mut litep2p1 = Litep2p::new(config1).unwrap(); + + let (ping_config2, mut ping_event_stream2) = PingConfig::default(); + let config2 = ConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_libp2p_ping(ping_config2); + + let config2 = match transport2 { + Transport::Tcp(config) => config2.with_tcp(config), + Transport::Quic(config) => config2.with_quic(config), + Transport::WebSocket(config) => config2.with_websocket(config), + } + .build(); + let mut litep2p2 = Litep2p::new(config2).unwrap(); + + let address1 = litep2p1.listen_addresses().next().unwrap().clone(); + litep2p2.dial_address(address1).await.unwrap(); + let mut litep2p1_ping = false; + let mut litep2p2_ping = false; + + loop { + tokio::select! { + event = litep2p1.next_event() => match event { + Some(Litep2pEvent::ConnectionClosed { .. }) if litep2p1_ping || litep2p2_ping => { + break; + } + _ => {} + }, + event = litep2p2.next_event() => match event { + Some(Litep2pEvent::ConnectionClosed { .. }) if litep2p1_ping || litep2p2_ping => { + break; + } + _ => {} + }, + _event = ping_event_stream1.next() => { + tracing::warn!("ping1 received"); + litep2p1_ping = true; + } + _event = ping_event_stream2.next() => { + tracing::warn!("ping2 received"); + litep2p2_ping = true; + } + } + } } #[tokio::test] async fn simultaneous_dial_tcp() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (ping_config1, mut ping_event_stream1) = PingConfig::default(); - let config1 = ConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }) - .with_libp2p_ping(ping_config1) - .build(); - let mut litep2p1 = Litep2p::new(config1).unwrap(); - - let (ping_config2, mut ping_event_stream2) = PingConfig::default(); - let config2 = ConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }) - .with_libp2p_ping(ping_config2) - .build(); - let mut litep2p2 = Litep2p::new(config2).unwrap(); - - let address1 = litep2p1.listen_addresses().next().unwrap().clone(); - let address2 = litep2p2.listen_addresses().next().unwrap().clone(); - - let (res1, res2) = - tokio::join!(litep2p1.dial_address(address2), litep2p2.dial_address(address1)); - assert!(std::matches!((res1, res2), (Ok(()), Ok(())))); - - let mut ping_received1 = false; - let mut ping_received2 = false; - - while !ping_received1 || !ping_received2 { - tokio::select! { - _ = litep2p1.next_event() => {} - _ = litep2p2.next_event() => {} - event = ping_event_stream1.next() => { - if event.is_some() { - ping_received1 = true; - } - } - event = ping_event_stream2.next() => { - if event.is_some() { - ping_received2 = true; - } - } - } - } + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (ping_config1, mut ping_event_stream1) = PingConfig::default(); + let config1 = ConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }) + .with_libp2p_ping(ping_config1) + .build(); + let mut litep2p1 = Litep2p::new(config1).unwrap(); + + let (ping_config2, mut ping_event_stream2) = PingConfig::default(); + let config2 = ConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }) + .with_libp2p_ping(ping_config2) + .build(); + let mut litep2p2 = Litep2p::new(config2).unwrap(); + + let address1 = litep2p1.listen_addresses().next().unwrap().clone(); + let address2 = litep2p2.listen_addresses().next().unwrap().clone(); + + let (res1, res2) = tokio::join!( + litep2p1.dial_address(address2), + litep2p2.dial_address(address1) + ); + assert!(std::matches!((res1, res2), (Ok(()), Ok(())))); + + let mut ping_received1 = false; + let mut ping_received2 = false; + + while !ping_received1 || !ping_received2 { + tokio::select! { + _ = litep2p1.next_event() => {} + _ = litep2p2.next_event() => {} + event = ping_event_stream1.next() => { + if event.is_some() { + ping_received1 = true; + } + } + event = ping_event_stream2.next() => { + if event.is_some() { + ping_received2 = true; + } + } + } + } } #[tokio::test] async fn simultaneous_dial_quic() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (ping_config1, mut ping_event_stream1) = PingConfig::default(); - let config1 = ConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_quic(Default::default()) - .with_libp2p_ping(ping_config1) - .build(); - let mut litep2p1 = Litep2p::new(config1).unwrap(); - - let (ping_config2, mut ping_event_stream2) = PingConfig::default(); - let config2 = ConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_quic(Default::default()) - .with_libp2p_ping(ping_config2) - .build(); - let mut litep2p2 = Litep2p::new(config2).unwrap(); - - let address1 = litep2p1.listen_addresses().next().unwrap().clone(); - let address2 = litep2p2.listen_addresses().next().unwrap().clone(); - - let (res1, res2) = - tokio::join!(litep2p1.dial_address(address2), litep2p2.dial_address(address1)); - assert!(std::matches!((res1, res2), (Ok(()), Ok(())))); - - let mut ping_received1 = false; - let mut ping_received2 = false; - - while !ping_received1 || !ping_received2 { - tokio::select! { - _ = litep2p1.next_event() => {} - _ = litep2p2.next_event() => {} - event = ping_event_stream1.next() => { - if event.is_some() { - ping_received1 = true; - } - } - event = ping_event_stream2.next() => { - if event.is_some() { - ping_received2 = true; - } - } - } - } + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (ping_config1, mut ping_event_stream1) = PingConfig::default(); + let config1 = ConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_quic(Default::default()) + .with_libp2p_ping(ping_config1) + .build(); + let mut litep2p1 = Litep2p::new(config1).unwrap(); + + let (ping_config2, mut ping_event_stream2) = PingConfig::default(); + let config2 = ConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_quic(Default::default()) + .with_libp2p_ping(ping_config2) + .build(); + let mut litep2p2 = Litep2p::new(config2).unwrap(); + + let address1 = litep2p1.listen_addresses().next().unwrap().clone(); + let address2 = litep2p2.listen_addresses().next().unwrap().clone(); + + let (res1, res2) = tokio::join!( + litep2p1.dial_address(address2), + litep2p2.dial_address(address1) + ); + assert!(std::matches!((res1, res2), (Ok(()), Ok(())))); + + let mut ping_received1 = false; + let mut ping_received2 = false; + + while !ping_received1 || !ping_received2 { + tokio::select! { + _ = litep2p1.next_event() => {} + _ = litep2p2.next_event() => {} + event = ping_event_stream1.next() => { + if event.is_some() { + ping_received1 = true; + } + } + event = ping_event_stream2.next() => { + if event.is_some() { + ping_received2 = true; + } + } + } + } } #[tokio::test] async fn simultaneous_dial_ipv6_quic() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (ping_config1, mut ping_event_stream1) = PingConfig::default(); - let config1 = ConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_quic(Default::default()) - .with_libp2p_ping(ping_config1) - .build(); - let mut litep2p1 = Litep2p::new(config1).unwrap(); - - let (ping_config2, mut ping_event_stream2) = PingConfig::default(); - let config2 = ConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_quic(Default::default()) - .with_libp2p_ping(ping_config2) - .build(); - let mut litep2p2 = Litep2p::new(config2).unwrap(); - - let address1 = litep2p1.listen_addresses().next().unwrap().clone(); - let address2 = litep2p2.listen_addresses().next().unwrap().clone(); - - let (res1, res2) = - tokio::join!(litep2p1.dial_address(address2), litep2p2.dial_address(address1)); - assert!(std::matches!((res1, res2), (Ok(()), Ok(())))); - - let mut ping_received1 = false; - let mut ping_received2 = false; - - while !ping_received1 || !ping_received2 { - tokio::select! { - _ = litep2p1.next_event() => {} - _ = litep2p2.next_event() => {} - event = ping_event_stream1.next() => { - if event.is_some() { - ping_received1 = true; - } - } - event = ping_event_stream2.next() => { - if event.is_some() { - ping_received2 = true; - } - } - } - } + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (ping_config1, mut ping_event_stream1) = PingConfig::default(); + let config1 = ConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_quic(Default::default()) + .with_libp2p_ping(ping_config1) + .build(); + let mut litep2p1 = Litep2p::new(config1).unwrap(); + + let (ping_config2, mut ping_event_stream2) = PingConfig::default(); + let config2 = ConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_quic(Default::default()) + .with_libp2p_ping(ping_config2) + .build(); + let mut litep2p2 = Litep2p::new(config2).unwrap(); + + let address1 = litep2p1.listen_addresses().next().unwrap().clone(); + let address2 = litep2p2.listen_addresses().next().unwrap().clone(); + + let (res1, res2) = tokio::join!( + litep2p1.dial_address(address2), + litep2p2.dial_address(address1) + ); + assert!(std::matches!((res1, res2), (Ok(()), Ok(())))); + + let mut ping_received1 = false; + let mut ping_received2 = false; + + while !ping_received1 || !ping_received2 { + tokio::select! { + _ = litep2p1.next_event() => {} + _ = litep2p2.next_event() => {} + event = ping_event_stream1.next() => { + if event.is_some() { + ping_received1 = true; + } + } + event = ping_event_stream2.next() => { + if event.is_some() { + ping_received2 = true; + } + } + } + } } #[tokio::test] async fn websocket_over_ipv6() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (ping_config1, mut ping_event_stream1) = PingConfig::default(); - let config1 = ConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_websocket(WebSocketConfig { - listen_addresses: vec!["/ip6/::1/tcp/0/ws".parse().unwrap()], - ..Default::default() - }) - .with_libp2p_ping(ping_config1) - .build(); - let mut litep2p1 = Litep2p::new(config1).unwrap(); - - let (ping_config2, mut ping_event_stream2) = PingConfig::default(); - let config2 = ConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_websocket(WebSocketConfig { - listen_addresses: vec!["/ip6/::1/tcp/0/ws".parse().unwrap()], - ..Default::default() - }) - .with_libp2p_ping(ping_config2) - .build(); - let mut litep2p2 = Litep2p::new(config2).unwrap(); - - let address2 = litep2p2.listen_addresses().next().unwrap().clone(); - litep2p1.dial_address(address2).await.unwrap(); - - let mut ping_received1 = false; - let mut ping_received2 = false; - - while !ping_received1 || !ping_received2 { - tokio::select! { - _ = litep2p1.next_event() => {} - _ = litep2p2.next_event() => {} - event = ping_event_stream1.next() => { - if event.is_some() { - ping_received1 = true; - } - } - event = ping_event_stream2.next() => { - if event.is_some() { - ping_received2 = true; - } - } - } - } + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (ping_config1, mut ping_event_stream1) = PingConfig::default(); + let config1 = ConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_websocket(WebSocketConfig { + listen_addresses: vec!["/ip6/::1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }) + .with_libp2p_ping(ping_config1) + .build(); + let mut litep2p1 = Litep2p::new(config1).unwrap(); + + let (ping_config2, mut ping_event_stream2) = PingConfig::default(); + let config2 = ConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_websocket(WebSocketConfig { + listen_addresses: vec!["/ip6/::1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }) + .with_libp2p_ping(ping_config2) + .build(); + let mut litep2p2 = Litep2p::new(config2).unwrap(); + + let address2 = litep2p2.listen_addresses().next().unwrap().clone(); + litep2p1.dial_address(address2).await.unwrap(); + + let mut ping_received1 = false; + let mut ping_received2 = false; + + while !ping_received1 || !ping_received2 { + tokio::select! { + _ = litep2p1.next_event() => {} + _ = litep2p2.next_event() => {} + event = ping_event_stream1.next() => { + if event.is_some() { + ping_received1 = true; + } + } + event = ping_event_stream2.next() => { + if event.is_some() { + ping_received2 = true; + } + } + } + } } #[tokio::test] async fn tcp_dns_resolution() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (ping_config1, mut ping_event_stream1) = PingConfig::default(); - let config1 = ConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_tcp(TcpConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0".parse().unwrap()], - ..Default::default() - }) - .with_libp2p_ping(ping_config1) - .build(); - let mut litep2p1 = Litep2p::new(config1).unwrap(); - - let (ping_config2, mut ping_event_stream2) = PingConfig::default(); - let config2 = ConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_tcp(TcpConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0".parse().unwrap()], - ..Default::default() - }) - .with_libp2p_ping(ping_config2) - .build(); - let mut litep2p2 = Litep2p::new(config2).unwrap(); - - let address = litep2p2.listen_addresses().next().unwrap().clone(); - let tcp = address.iter().skip(1).next().unwrap(); - let peer2 = *litep2p2.local_peer_id(); - - let mut new_address = Multiaddr::empty(); - new_address.push(Protocol::Dns("localhost".into())); - new_address.push(tcp); - new_address.push(Protocol::P2p(Multihash::from_bytes(&peer2.to_bytes()).unwrap())); - litep2p1.dial_address(new_address).await.unwrap(); - - let mut ping_received1 = false; - let mut ping_received2 = false; - - while !ping_received1 || !ping_received2 { - tokio::select! { - _ = litep2p1.next_event() => {} - _ = litep2p2.next_event() => {} - event = ping_event_stream1.next() => { - if event.is_some() { - ping_received1 = true; - } - } - event = ping_event_stream2.next() => { - if event.is_some() { - ping_received2 = true; - } - } - } - } + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (ping_config1, mut ping_event_stream1) = PingConfig::default(); + let config1 = ConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_tcp(TcpConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0".parse().unwrap()], + ..Default::default() + }) + .with_libp2p_ping(ping_config1) + .build(); + let mut litep2p1 = Litep2p::new(config1).unwrap(); + + let (ping_config2, mut ping_event_stream2) = PingConfig::default(); + let config2 = ConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_tcp(TcpConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0".parse().unwrap()], + ..Default::default() + }) + .with_libp2p_ping(ping_config2) + .build(); + let mut litep2p2 = Litep2p::new(config2).unwrap(); + + let address = litep2p2.listen_addresses().next().unwrap().clone(); + let tcp = address.iter().skip(1).next().unwrap(); + let peer2 = *litep2p2.local_peer_id(); + + let mut new_address = Multiaddr::empty(); + new_address.push(Protocol::Dns("localhost".into())); + new_address.push(tcp); + new_address.push(Protocol::P2p( + Multihash::from_bytes(&peer2.to_bytes()).unwrap(), + )); + litep2p1.dial_address(new_address).await.unwrap(); + + let mut ping_received1 = false; + let mut ping_received2 = false; + + while !ping_received1 || !ping_received2 { + tokio::select! { + _ = litep2p1.next_event() => {} + _ = litep2p2.next_event() => {} + event = ping_event_stream1.next() => { + if event.is_some() { + ping_received1 = true; + } + } + event = ping_event_stream2.next() => { + if event.is_some() { + ping_received2 = true; + } + } + } + } } #[tokio::test] async fn websocket_dns_resolution() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (ping_config1, mut ping_event_stream1) = PingConfig::default(); - let config1 = ConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_websocket(WebSocketConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], - ..Default::default() - }) - .with_libp2p_ping(ping_config1) - .build(); - let mut litep2p1 = Litep2p::new(config1).unwrap(); - - let (ping_config2, mut ping_event_stream2) = PingConfig::default(); - let config2 = ConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_websocket(WebSocketConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], - ..Default::default() - }) - .with_libp2p_ping(ping_config2) - .build(); - let mut litep2p2 = Litep2p::new(config2).unwrap(); - - let address = litep2p2.listen_addresses().next().unwrap().clone(); - let tcp = address.iter().skip(1).next().unwrap(); - let peer2 = *litep2p2.local_peer_id(); - - let mut new_address = Multiaddr::empty(); - new_address.push(Protocol::Dns("localhost".into())); - new_address.push(tcp); - new_address.push(Protocol::Ws(std::borrow::Cow::Owned("/".to_string()))); - new_address.push(Protocol::P2p(Multihash::from_bytes(&peer2.to_bytes()).unwrap())); - litep2p1.dial_address(new_address).await.unwrap(); - - let mut ping_received1 = false; - let mut ping_received2 = false; - - while !ping_received1 || !ping_received2 { - tokio::select! { - _ = litep2p1.next_event() => {} - _ = litep2p2.next_event() => {} - event = ping_event_stream1.next() => { - if event.is_some() { - ping_received1 = true; - } - } - event = ping_event_stream2.next() => { - if event.is_some() { - ping_received2 = true; - } - } - } - } + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (ping_config1, mut ping_event_stream1) = PingConfig::default(); + let config1 = ConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_websocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }) + .with_libp2p_ping(ping_config1) + .build(); + let mut litep2p1 = Litep2p::new(config1).unwrap(); + + let (ping_config2, mut ping_event_stream2) = PingConfig::default(); + let config2 = ConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_websocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }) + .with_libp2p_ping(ping_config2) + .build(); + let mut litep2p2 = Litep2p::new(config2).unwrap(); + + let address = litep2p2.listen_addresses().next().unwrap().clone(); + let tcp = address.iter().skip(1).next().unwrap(); + let peer2 = *litep2p2.local_peer_id(); + + let mut new_address = Multiaddr::empty(); + new_address.push(Protocol::Dns("localhost".into())); + new_address.push(tcp); + new_address.push(Protocol::Ws(std::borrow::Cow::Owned("/".to_string()))); + new_address.push(Protocol::P2p( + Multihash::from_bytes(&peer2.to_bytes()).unwrap(), + )); + litep2p1.dial_address(new_address).await.unwrap(); + + let mut ping_received1 = false; + let mut ping_received2 = false; + + while !ping_received1 || !ping_received2 { + tokio::select! { + _ = litep2p1.next_event() => {} + _ = litep2p2.next_event() => {} + event = ping_event_stream1.next() => { + if event.is_some() { + ping_received1 = true; + } + } + event = ping_event_stream2.next() => { + if event.is_some() { + ping_received2 = true; + } + } + } + } } #[tokio::test] async fn multiple_listen_addresses_tcp() { - multiple_listen_addresses( - Transport::Tcp(TcpConfig { - listen_addresses: vec![ - "/ip6/::1/tcp/0".parse().unwrap(), - "/ip4/127.0.0.1/tcp/0".parse().unwrap(), - ], - ..Default::default() - }), - Transport::Tcp(TcpConfig { listen_addresses: vec![], ..Default::default() }), - Transport::Tcp(TcpConfig { listen_addresses: vec![], ..Default::default() }), - ) - .await + multiple_listen_addresses( + Transport::Tcp(TcpConfig { + listen_addresses: vec![ + "/ip6/::1/tcp/0".parse().unwrap(), + "/ip4/127.0.0.1/tcp/0".parse().unwrap(), + ], + ..Default::default() + }), + Transport::Tcp(TcpConfig { + listen_addresses: vec![], + ..Default::default() + }), + Transport::Tcp(TcpConfig { + listen_addresses: vec![], + ..Default::default() + }), + ) + .await } #[tokio::test] async fn multiple_listen_addresses_quic() { - multiple_listen_addresses( - Transport::Quic(QuicConfig { - listen_addresses: vec![ - "/ip4/127.0.0.1/udp/0/quic-v1".parse().unwrap(), - "/ip6/::1/udp/0/quic-v1".parse().unwrap(), - ], - ..Default::default() - }), - Transport::Quic(QuicConfig { listen_addresses: vec![], ..Default::default() }), - Transport::Quic(QuicConfig { listen_addresses: vec![], ..Default::default() }), - ) - .await; + multiple_listen_addresses( + Transport::Quic(QuicConfig { + listen_addresses: vec![ + "/ip4/127.0.0.1/udp/0/quic-v1".parse().unwrap(), + "/ip6/::1/udp/0/quic-v1".parse().unwrap(), + ], + ..Default::default() + }), + Transport::Quic(QuicConfig { + listen_addresses: vec![], + ..Default::default() + }), + Transport::Quic(QuicConfig { + listen_addresses: vec![], + ..Default::default() + }), + ) + .await; } #[tokio::test] async fn multiple_listen_addresses_websocket() { - multiple_listen_addresses( - Transport::WebSocket(WebSocketConfig { - listen_addresses: vec![ - "/ip4/127.0.0.1/tcp/0/ws".parse().unwrap(), - "/ip6/::1/tcp/0/ws".parse().unwrap(), - ], - ..Default::default() - }), - Transport::WebSocket(WebSocketConfig { listen_addresses: vec![], ..Default::default() }), - Transport::WebSocket(WebSocketConfig { listen_addresses: vec![], ..Default::default() }), - ) - .await; + multiple_listen_addresses( + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec![ + "/ip4/127.0.0.1/tcp/0/ws".parse().unwrap(), + "/ip6/::1/tcp/0/ws".parse().unwrap(), + ], + ..Default::default() + }), + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec![], + ..Default::default() + }), + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec![], + ..Default::default() + }), + ) + .await; } async fn make_dummy_litep2p( - transport: Transport, + transport: Transport, ) -> (Litep2p, Box + Send + Unpin>) { - let (ping_config, ping_event_stream) = PingConfig::default(); - let litep2p_config = ConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_libp2p_ping(ping_config); - - let litep2p_config = match transport { - Transport::Tcp(config) => litep2p_config.with_tcp(config), - Transport::Quic(config) => litep2p_config.with_quic(config), - Transport::WebSocket(config) => litep2p_config.with_websocket(config), - } - .build(); - - (Litep2p::new(litep2p_config).unwrap(), ping_event_stream) + let (ping_config, ping_event_stream) = PingConfig::default(); + let litep2p_config = ConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_libp2p_ping(ping_config); + + let litep2p_config = match transport { + Transport::Tcp(config) => litep2p_config.with_tcp(config), + Transport::Quic(config) => litep2p_config.with_quic(config), + Transport::WebSocket(config) => litep2p_config.with_websocket(config), + } + .build(); + + (Litep2p::new(litep2p_config).unwrap(), ping_event_stream) } async fn multiple_listen_addresses( - transport1: Transport, - transport2: Transport, - transport3: Transport, + transport1: Transport, + transport2: Transport, + transport3: Transport, ) { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (mut litep2p1, _event_stream) = make_dummy_litep2p(transport1).await; - let (mut litep2p2, _event_stream) = make_dummy_litep2p(transport2).await; - let (mut litep2p3, _event_stream) = make_dummy_litep2p(transport3).await; - - let mut address_iter = litep2p1.listen_addresses(); - let address1 = address_iter.next().unwrap().clone(); - let address2 = address_iter.next().unwrap().clone(); - drop(address_iter); - - tokio::spawn(async move { - loop { - let _ = litep2p1.next_event().await; - } - }); - - let (res1, res2) = - tokio::join!(litep2p2.dial_address(address1), litep2p3.dial_address(address2),); - assert!(res1.is_ok() && res2.is_ok()); - - let (res1, res2) = tokio::join!(litep2p2.next_event(), litep2p3.next_event()); - - assert!(std::matches!(res1, Some(Litep2pEvent::ConnectionEstablished { .. }))); - assert!(std::matches!(res2, Some(Litep2pEvent::ConnectionEstablished { .. }))); + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (mut litep2p1, _event_stream) = make_dummy_litep2p(transport1).await; + let (mut litep2p2, _event_stream) = make_dummy_litep2p(transport2).await; + let (mut litep2p3, _event_stream) = make_dummy_litep2p(transport3).await; + + let mut address_iter = litep2p1.listen_addresses(); + let address1 = address_iter.next().unwrap().clone(); + let address2 = address_iter.next().unwrap().clone(); + drop(address_iter); + + tokio::spawn(async move { + loop { + let _ = litep2p1.next_event().await; + } + }); + + let (res1, res2) = tokio::join!( + litep2p2.dial_address(address1), + litep2p3.dial_address(address2), + ); + assert!(res1.is_ok() && res2.is_ok()); + + let (res1, res2) = tokio::join!(litep2p2.next_event(), litep2p3.next_event()); + + assert!(std::matches!( + res1, + Some(Litep2pEvent::ConnectionEstablished { .. }) + )); + assert!(std::matches!( + res2, + Some(Litep2pEvent::ConnectionEstablished { .. }) + )); } #[tokio::test] async fn port_in_use_tcp() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let listener = TcpListener::bind("[::1]:0").await.unwrap(); - let address = listener.local_addr().unwrap(); - - let _litep2p = Litep2p::new( - ConfigBuilder::new() - .with_tcp(TcpConfig { - listen_addresses: vec![Multiaddr::empty() - .with(Protocol::from(address.ip())) - .with(Protocol::Tcp(address.port()))], - ..Default::default() - }) - .build(), - ) - .unwrap(); + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let listener = TcpListener::bind("[::1]:0").await.unwrap(); + let address = listener.local_addr().unwrap(); + + let _litep2p = Litep2p::new( + ConfigBuilder::new() + .with_tcp(TcpConfig { + listen_addresses: vec![Multiaddr::empty() + .with(Protocol::from(address.ip())) + .with(Protocol::Tcp(address.port()))], + ..Default::default() + }) + .build(), + ) + .unwrap(); } #[tokio::test] async fn port_in_use_websocket() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let listener = TcpListener::bind("[::1]:0").await.unwrap(); - let address = listener.local_addr().unwrap(); - - let _litep2p = Litep2p::new( - ConfigBuilder::new() - .with_websocket(WebSocketConfig { - listen_addresses: vec![Multiaddr::empty() - .with(Protocol::from(address.ip())) - .with(Protocol::Tcp(address.port())) - .with(Protocol::Ws(std::borrow::Cow::Owned("/".to_string())))], - ..Default::default() - }) - .build(), - ) - .unwrap(); + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let listener = TcpListener::bind("[::1]:0").await.unwrap(); + let address = listener.local_addr().unwrap(); + + let _litep2p = Litep2p::new( + ConfigBuilder::new() + .with_websocket(WebSocketConfig { + listen_addresses: vec![Multiaddr::empty() + .with(Protocol::from(address.ip())) + .with(Protocol::Tcp(address.port())) + .with(Protocol::Ws(std::borrow::Cow::Owned("/".to_string())))], + ..Default::default() + }) + .build(), + ) + .unwrap(); } #[tokio::test] async fn dial_over_multiple_addresses() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - // let (mut litep2p1, _event_stream) = make_dummy_litep2p(transport1).await; - // let (mut litep2p2, _event_stream) = make_dummy_litep2p(transport2).await; - // let (mut litep2p3, _event_stream) = make_dummy_litep2p(transport3).await; - - // let mut address_iter = litep2p1.listen_addresses(); - // let address1 = address_iter.next().unwrap().clone(); - // let address2 = address_iter.next().unwrap().clone(); - // drop(address_iter); - - // tokio::spawn(async move { - // loop { - // let _ = litep2p1.next_event().await; - // } - // }); - - // let (res1, res2) = tokio::join!( - // litep2p2.dial_address(address1), - // litep2p3.dial_address(address2), - // ); - // assert!(res1.is_ok() && res2.is_ok()); - - // let (res1, res2) = tokio::join!(litep2p2.next_event(), litep2p3.next_event()); - - // assert!(std::matches!( - // res1, - // Some(Litep2pEvent::ConnectionEstablished { .. }) - // )); - // assert!(std::matches!( - // res2, - // Some(Litep2pEvent::ConnectionEstablished { .. }) - // )); + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + // let (mut litep2p1, _event_stream) = make_dummy_litep2p(transport1).await; + // let (mut litep2p2, _event_stream) = make_dummy_litep2p(transport2).await; + // let (mut litep2p3, _event_stream) = make_dummy_litep2p(transport3).await; + + // let mut address_iter = litep2p1.listen_addresses(); + // let address1 = address_iter.next().unwrap().clone(); + // let address2 = address_iter.next().unwrap().clone(); + // drop(address_iter); + + // tokio::spawn(async move { + // loop { + // let _ = litep2p1.next_event().await; + // } + // }); + + // let (res1, res2) = tokio::join!( + // litep2p2.dial_address(address1), + // litep2p3.dial_address(address2), + // ); + // assert!(res1.is_ok() && res2.is_ok()); + + // let (res1, res2) = tokio::join!(litep2p2.next_event(), litep2p3.next_event()); + + // assert!(std::matches!( + // res1, + // Some(Litep2pEvent::ConnectionEstablished { .. }) + // )); + // assert!(std::matches!( + // res2, + // Some(Litep2pEvent::ConnectionEstablished { .. }) + // )); } #[tokio::test] async fn unspecified_listen_address_tcp() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (ping_config1, _ping_event_stream1) = PingConfig::default(); - let config1 = ConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_tcp(TcpConfig { - listen_addresses: vec![ - "/ip4/0.0.0.0/tcp/0".parse().unwrap(), - "/ip6/::/tcp/0".parse().unwrap(), - ], - ..Default::default() - }) - .with_libp2p_ping(ping_config1) - .build(); - - let mut litep2p1 = Litep2p::new(config1).unwrap(); - let peer1 = *litep2p1.local_peer_id(); - - for address in litep2p1.listen_addresses() { - tracing::info!("address: {address:?}"); - } - - let listen_address = litep2p1.listen_addresses().collect::>(); - - let ip4_port = listen_address.iter().find_map(|address| { - let mut iter = address.iter(); - match iter.next() { - Some(Protocol::Ip4(_)) => match iter.next() { - Some(Protocol::Tcp(port)) => Some(port), - _ => panic!("invalid protocol"), - }, - _ => None, - } - }); - let ip6_port = listen_address.iter().find_map(|address| { - let mut iter = address.iter(); - match iter.next() { - Some(Protocol::Ip6(_)) => match iter.next() { - Some(Protocol::Tcp(port)) => Some(port), - _ => panic!("invalid protocol"), - }, - _ => None, - } - }); - - tokio::spawn(async move { while let Some(_) = litep2p1.next_event().await {} }); - - let network_interfaces = NetworkInterface::show().unwrap(); - for iface in network_interfaces.iter() { - for address in &iface.addr { - let (ping_config2, _ping_event_stream2) = PingConfig::default(); - let config = ConfigBuilder::new().with_libp2p_ping(ping_config2); - - let (mut litep2p, dial_address) = match address { - network_interface::Addr::V4(record) => { - if ip4_port.is_none() { - continue; - } - - let config = config - .with_tcp(TcpConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0".parse().unwrap()], - ..Default::default() - }) - .build(); - - ( - Litep2p::new(config).unwrap(), - Multiaddr::empty() - .with(Protocol::Ip4(record.ip)) - .with(Protocol::Tcp(ip4_port.unwrap())) - .with(Protocol::P2p(Multihash::from(peer1))), - ) - }, - network_interface::Addr::V6(record) => { - if record.ip.segments()[0] == 0xfe80 || ip6_port.is_none() { - continue; - } - - let config = config.with_tcp(Default::default()).build(); - - ( - Litep2p::new(config).unwrap(), - Multiaddr::empty() - .with(Protocol::Ip6(record.ip)) - .with(Protocol::Tcp(ip6_port.unwrap())) - .with(Protocol::P2p(Multihash::from(peer1))), - ) - }, - }; - - litep2p.dial_address(dial_address).await.unwrap(); - match litep2p.next_event().await { - Some(Litep2pEvent::ConnectionEstablished { .. }) => {}, - event => panic!("invalid event: {event:?}"), - } - } - } + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (ping_config1, _ping_event_stream1) = PingConfig::default(); + let config1 = ConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_tcp(TcpConfig { + listen_addresses: vec![ + "/ip4/0.0.0.0/tcp/0".parse().unwrap(), + "/ip6/::/tcp/0".parse().unwrap(), + ], + ..Default::default() + }) + .with_libp2p_ping(ping_config1) + .build(); + + let mut litep2p1 = Litep2p::new(config1).unwrap(); + let peer1 = *litep2p1.local_peer_id(); + + for address in litep2p1.listen_addresses() { + tracing::info!("address: {address:?}"); + } + + let listen_address = litep2p1.listen_addresses().collect::>(); + + let ip4_port = listen_address.iter().find_map(|address| { + let mut iter = address.iter(); + match iter.next() { + Some(Protocol::Ip4(_)) => match iter.next() { + Some(Protocol::Tcp(port)) => Some(port), + _ => panic!("invalid protocol"), + }, + _ => None, + } + }); + let ip6_port = listen_address.iter().find_map(|address| { + let mut iter = address.iter(); + match iter.next() { + Some(Protocol::Ip6(_)) => match iter.next() { + Some(Protocol::Tcp(port)) => Some(port), + _ => panic!("invalid protocol"), + }, + _ => None, + } + }); + + tokio::spawn(async move { while let Some(_) = litep2p1.next_event().await {} }); + + let network_interfaces = NetworkInterface::show().unwrap(); + for iface in network_interfaces.iter() { + for address in &iface.addr { + let (ping_config2, _ping_event_stream2) = PingConfig::default(); + let config = ConfigBuilder::new().with_libp2p_ping(ping_config2); + + let (mut litep2p, dial_address) = match address { + network_interface::Addr::V4(record) => { + if ip4_port.is_none() { + continue; + } + + let config = config + .with_tcp(TcpConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0".parse().unwrap()], + ..Default::default() + }) + .build(); + + ( + Litep2p::new(config).unwrap(), + Multiaddr::empty() + .with(Protocol::Ip4(record.ip)) + .with(Protocol::Tcp(ip4_port.unwrap())) + .with(Protocol::P2p(Multihash::from(peer1))), + ) + } + network_interface::Addr::V6(record) => { + if record.ip.segments()[0] == 0xfe80 || ip6_port.is_none() { + continue; + } + + let config = config.with_tcp(Default::default()).build(); + + ( + Litep2p::new(config).unwrap(), + Multiaddr::empty() + .with(Protocol::Ip6(record.ip)) + .with(Protocol::Tcp(ip6_port.unwrap())) + .with(Protocol::P2p(Multihash::from(peer1))), + ) + } + }; + + litep2p.dial_address(dial_address).await.unwrap(); + match litep2p.next_event().await { + Some(Litep2pEvent::ConnectionEstablished { .. }) => {} + event => panic!("invalid event: {event:?}"), + } + } + } } #[tokio::test] async fn unspecified_listen_address_websocket() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (ping_config1, _ping_event_stream1) = PingConfig::default(); - let config1 = ConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_websocket(WebSocketConfig { - listen_addresses: vec![ - "/ip4/0.0.0.0/tcp/0/ws".parse().unwrap(), - "/ip6/::/tcp/0/ws".parse().unwrap(), - ], - ..Default::default() - }) - .with_libp2p_ping(ping_config1) - .build(); - - let mut litep2p1 = Litep2p::new(config1).unwrap(); - let peer1 = *litep2p1.local_peer_id(); - - for address in litep2p1.listen_addresses() { - tracing::info!("address: {address:?}"); - } - - let listen_address = litep2p1.listen_addresses().collect::>(); - - let ip4_port = listen_address.iter().find_map(|address| { - let mut iter = address.iter(); - match iter.next() { - Some(Protocol::Ip4(_)) => match iter.next() { - Some(Protocol::Tcp(port)) => Some(port), - _ => panic!("invalid protocol"), - }, - _ => None, - } - }); - let ip6_port = listen_address.iter().find_map(|address| { - let mut iter = address.iter(); - match iter.next() { - Some(Protocol::Ip6(_)) => match iter.next() { - Some(Protocol::Tcp(port)) => Some(port), - _ => panic!("invalid protocol"), - }, - _ => None, - } - }); - - tokio::spawn(async move { while let Some(_) = litep2p1.next_event().await {} }); - - let network_interfaces = NetworkInterface::show().unwrap(); - for iface in network_interfaces.iter() { - for address in &iface.addr { - let (ping_config2, _ping_event_stream2) = PingConfig::default(); - let config = ConfigBuilder::new().with_libp2p_ping(ping_config2); - - let (mut litep2p, dial_address) = match address { - network_interface::Addr::V4(record) => { - if ip4_port.is_none() { - continue; - } - - let config = config - .with_websocket(WebSocketConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], - ..Default::default() - }) - .build(); - - ( - Litep2p::new(config).unwrap(), - Multiaddr::empty() - .with(Protocol::Ip4(record.ip)) - .with(Protocol::Tcp(ip4_port.unwrap())) - .with(Protocol::Ws(std::borrow::Cow::Owned("/".to_string()))) - .with(Protocol::P2p(Multihash::from(peer1))), - ) - }, - network_interface::Addr::V6(record) => { - if record.ip.segments()[0] == 0xfe80 || ip6_port.is_none() { - continue; - } - - let config = config.with_websocket(Default::default()).build(); - - ( - Litep2p::new(config).unwrap(), - Multiaddr::empty() - .with(Protocol::Ip6(record.ip)) - .with(Protocol::Tcp(ip6_port.unwrap())) - .with(Protocol::Ws(std::borrow::Cow::Owned("/".to_string()))) - .with(Protocol::P2p(Multihash::from(peer1))), - ) - }, - }; - - litep2p.dial_address(dial_address).await.unwrap(); - match litep2p.next_event().await { - Some(Litep2pEvent::ConnectionEstablished { .. }) => {}, - event => panic!("invalid event: {event:?}"), - } - } - } + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (ping_config1, _ping_event_stream1) = PingConfig::default(); + let config1 = ConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_websocket(WebSocketConfig { + listen_addresses: vec![ + "/ip4/0.0.0.0/tcp/0/ws".parse().unwrap(), + "/ip6/::/tcp/0/ws".parse().unwrap(), + ], + ..Default::default() + }) + .with_libp2p_ping(ping_config1) + .build(); + + let mut litep2p1 = Litep2p::new(config1).unwrap(); + let peer1 = *litep2p1.local_peer_id(); + + for address in litep2p1.listen_addresses() { + tracing::info!("address: {address:?}"); + } + + let listen_address = litep2p1.listen_addresses().collect::>(); + + let ip4_port = listen_address.iter().find_map(|address| { + let mut iter = address.iter(); + match iter.next() { + Some(Protocol::Ip4(_)) => match iter.next() { + Some(Protocol::Tcp(port)) => Some(port), + _ => panic!("invalid protocol"), + }, + _ => None, + } + }); + let ip6_port = listen_address.iter().find_map(|address| { + let mut iter = address.iter(); + match iter.next() { + Some(Protocol::Ip6(_)) => match iter.next() { + Some(Protocol::Tcp(port)) => Some(port), + _ => panic!("invalid protocol"), + }, + _ => None, + } + }); + + tokio::spawn(async move { while let Some(_) = litep2p1.next_event().await {} }); + + let network_interfaces = NetworkInterface::show().unwrap(); + for iface in network_interfaces.iter() { + for address in &iface.addr { + let (ping_config2, _ping_event_stream2) = PingConfig::default(); + let config = ConfigBuilder::new().with_libp2p_ping(ping_config2); + + let (mut litep2p, dial_address) = match address { + network_interface::Addr::V4(record) => { + if ip4_port.is_none() { + continue; + } + + let config = config + .with_websocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }) + .build(); + + ( + Litep2p::new(config).unwrap(), + Multiaddr::empty() + .with(Protocol::Ip4(record.ip)) + .with(Protocol::Tcp(ip4_port.unwrap())) + .with(Protocol::Ws(std::borrow::Cow::Owned("/".to_string()))) + .with(Protocol::P2p(Multihash::from(peer1))), + ) + } + network_interface::Addr::V6(record) => { + if record.ip.segments()[0] == 0xfe80 || ip6_port.is_none() { + continue; + } + + let config = config.with_websocket(Default::default()).build(); + + ( + Litep2p::new(config).unwrap(), + Multiaddr::empty() + .with(Protocol::Ip6(record.ip)) + .with(Protocol::Tcp(ip6_port.unwrap())) + .with(Protocol::Ws(std::borrow::Cow::Owned("/".to_string()))) + .with(Protocol::P2p(Multihash::from(peer1))), + ) + } + }; + + litep2p.dial_address(dial_address).await.unwrap(); + match litep2p.next_event().await { + Some(Litep2pEvent::ConnectionEstablished { .. }) => {} + event => panic!("invalid event: {event:?}"), + } + } + } } diff --git a/tests/connection/protocol_dial_invalid_address.rs b/tests/connection/protocol_dial_invalid_address.rs index 2e1c453d..e836ce33 100644 --- a/tests/connection/protocol_dial_invalid_address.rs +++ b/tests/connection/protocol_dial_invalid_address.rs @@ -19,13 +19,13 @@ // DEALINGS IN THE SOFTWARE. use litep2p::{ - codec::ProtocolCodec, - config::ConfigBuilder, - crypto::ed25519::Keypair, - protocol::{TransportEvent, TransportService, UserProtocol}, - transport::tcp::config::Config as TcpConfig, - types::protocol::ProtocolName, - Litep2p, PeerId, + codec::ProtocolCodec, + config::ConfigBuilder, + crypto::ed25519::Keypair, + protocol::{TransportEvent, TransportService, UserProtocol}, + transport::tcp::config::Config as TcpConfig, + types::protocol::ProtocolName, + Litep2p, PeerId, }; use futures::StreamExt; @@ -35,110 +35,118 @@ use tokio::sync::oneshot; #[derive(Debug)] struct CustomProtocol { - dial_address: Multiaddr, - protocol: ProtocolName, - codec: ProtocolCodec, - tx: oneshot::Sender<()>, + dial_address: Multiaddr, + protocol: ProtocolName, + codec: ProtocolCodec, + tx: oneshot::Sender<()>, } impl CustomProtocol { - pub fn new(dial_address: Multiaddr) -> (Self, oneshot::Receiver<()>) { - let (tx, rx) = oneshot::channel(); - - ( - Self { - dial_address, - protocol: ProtocolName::from("/custom-protocol/1"), - codec: ProtocolCodec::UnsignedVarint(None), - tx, - }, - rx, - ) - } + pub fn new(dial_address: Multiaddr) -> (Self, oneshot::Receiver<()>) { + let (tx, rx) = oneshot::channel(); + + ( + Self { + dial_address, + protocol: ProtocolName::from("/custom-protocol/1"), + codec: ProtocolCodec::UnsignedVarint(None), + tx, + }, + rx, + ) + } } #[async_trait::async_trait] impl UserProtocol for CustomProtocol { - fn protocol(&self) -> ProtocolName { - self.protocol.clone() - } - - fn codec(&self) -> ProtocolCodec { - self.codec.clone() - } - - async fn run(mut self: Box, mut service: TransportService) -> litep2p::Result<()> { - if service.dial_address(self.dial_address.clone()).is_err() { - self.tx.send(()).unwrap(); - return Ok(()); - } - - loop { - while let Some(event) = service.next().await { - if let TransportEvent::DialFailure { .. } = event { - self.tx.send(()).unwrap(); - return Ok(()); - } - } - } - } + fn protocol(&self) -> ProtocolName { + self.protocol.clone() + } + + fn codec(&self) -> ProtocolCodec { + self.codec.clone() + } + + async fn run(mut self: Box, mut service: TransportService) -> litep2p::Result<()> { + if service.dial_address(self.dial_address.clone()).is_err() { + self.tx.send(()).unwrap(); + return Ok(()); + } + + loop { + while let Some(event) = service.next().await { + if let TransportEvent::DialFailure { .. } = event { + self.tx.send(()).unwrap(); + return Ok(()); + } + } + } + } } #[tokio::test] async fn protocol_dial_invalid_dns_address() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - let address = Multiaddr::empty() - .with(Protocol::Dns(std::borrow::Cow::Owned( - "address.that.doesnt.exist.hopefully.pls".to_string(), - ))) - .with(Protocol::Tcp(8888)) - .with(Protocol::P2p(Multihash::from_bytes(&PeerId::random().to_bytes()).unwrap())); - - let (custom_protocol, rx) = CustomProtocol::new(address); - let custom_protocol = Box::new(custom_protocol); - let config1 = ConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_tcp(TcpConfig { ..Default::default() }) - .with_user_protocol(custom_protocol) - .build(); - - let mut litep2p1 = Litep2p::new(config1).unwrap(); - - tokio::spawn(async move { - loop { - let _ = litep2p1.next_event().await; - } - }); - - let _ = rx.await.unwrap(); + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + let address = Multiaddr::empty() + .with(Protocol::Dns(std::borrow::Cow::Owned( + "address.that.doesnt.exist.hopefully.pls".to_string(), + ))) + .with(Protocol::Tcp(8888)) + .with(Protocol::P2p( + Multihash::from_bytes(&PeerId::random().to_bytes()).unwrap(), + )); + + let (custom_protocol, rx) = CustomProtocol::new(address); + let custom_protocol = Box::new(custom_protocol); + let config1 = ConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_tcp(TcpConfig { + ..Default::default() + }) + .with_user_protocol(custom_protocol) + .build(); + + let mut litep2p1 = Litep2p::new(config1).unwrap(); + + tokio::spawn(async move { + loop { + let _ = litep2p1.next_event().await; + } + }); + + let _ = rx.await.unwrap(); } #[tokio::test] async fn protocol_dial_peer_id_missing() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - let address = Multiaddr::empty() - .with(Protocol::Dns(std::borrow::Cow::Owned("google.com".to_string()))) - .with(Protocol::Tcp(8888)); - - let (custom_protocol, rx) = CustomProtocol::new(address); - let custom_protocol = Box::new(custom_protocol); - let config1 = ConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_tcp(TcpConfig { ..Default::default() }) - .with_user_protocol(custom_protocol) - .build(); - - let mut litep2p1 = Litep2p::new(config1).unwrap(); - - tokio::spawn(async move { - loop { - let _ = litep2p1.next_event().await; - } - }); - - let _ = rx.await.unwrap(); + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + let address = Multiaddr::empty() + .with(Protocol::Dns(std::borrow::Cow::Owned( + "google.com".to_string(), + ))) + .with(Protocol::Tcp(8888)); + + let (custom_protocol, rx) = CustomProtocol::new(address); + let custom_protocol = Box::new(custom_protocol); + let config1 = ConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_tcp(TcpConfig { + ..Default::default() + }) + .with_user_protocol(custom_protocol) + .build(); + + let mut litep2p1 = Litep2p::new(config1).unwrap(); + + tokio::spawn(async move { + loop { + let _ = litep2p1.next_event().await; + } + }); + + let _ = rx.await.unwrap(); } diff --git a/tests/custom_executor.rs b/tests/custom_executor.rs index 877b01d7..171219c5 100644 --- a/tests/custom_executor.rs +++ b/tests/custom_executor.rs @@ -19,20 +19,20 @@ // DEALINGS IN THE SOFTWARE. use litep2p::{ - config::ConfigBuilder, - crypto::ed25519::Keypair, - executor::Executor, - protocol::{ - notification::{ - Config as NotificationConfig, Direction, NotificationEvent, ValidationResult, - }, - request_response::{ - ConfigBuilder as RequestResponseConfigBuilder, DialOptions, RequestResponseEvent, - }, - }, - transport::tcp::config::Config as TcpConfig, - types::protocol::ProtocolName, - Litep2p, Litep2pEvent, + config::ConfigBuilder, + crypto::ed25519::Keypair, + executor::Executor, + protocol::{ + notification::{ + Config as NotificationConfig, Direction, NotificationEvent, ValidationResult, + }, + request_response::{ + ConfigBuilder as RequestResponseConfigBuilder, DialOptions, RequestResponseEvent, + }, + }, + transport::tcp::config::Config as TcpConfig, + types::protocol::ProtocolName, + Litep2p, Litep2pEvent, }; use bytes::BytesMut; @@ -42,235 +42,246 @@ use tokio::sync::mpsc::{channel, Receiver, Sender}; use std::{future::Future, pin::Pin, sync::Arc}; struct TaskExecutor { - rx: Receiver + Send>>>, - futures: FuturesUnordered>, + rx: Receiver + Send>>>, + futures: FuturesUnordered>, } impl TaskExecutor { - pub fn new() -> (Self, Sender + Send>>>) { - let (tx, rx) = channel(64); + pub fn new() -> (Self, Sender + Send>>>) { + let (tx, rx) = channel(64); - (Self { rx, futures: FuturesUnordered::new() }, tx) - } + ( + Self { + rx, + futures: FuturesUnordered::new(), + }, + tx, + ) + } - async fn next(&mut self) { - tokio::select! { - future = self.rx.recv() => { - self.futures.push(future.unwrap()); - } - _ = self.futures.next(), if !self.futures.is_empty() => {} - } - } + async fn next(&mut self) { + tokio::select! { + future = self.rx.recv() => { + self.futures.push(future.unwrap()); + } + _ = self.futures.next(), if !self.futures.is_empty() => {} + } + } } struct TaskExecutorHandle { - tx: Sender + Send>>>, + tx: Sender + Send>>>, } impl Executor for TaskExecutorHandle { - fn run(&self, future: Pin + Send>>) { - let _ = self.tx.try_send(future); - } + fn run(&self, future: Pin + Send>>) { + let _ = self.tx.try_send(future); + } - fn run_with_name(&self, _: &'static str, future: Pin + Send>>) { - let _ = self.tx.try_send(future); - } + fn run_with_name(&self, _: &'static str, future: Pin + Send>>) { + let _ = self.tx.try_send(future); + } } #[tokio::test] async fn custom_executor() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); - let (mut executor, sender) = TaskExecutor::new(); + let (mut executor, sender) = TaskExecutor::new(); - tokio::spawn(async move { - loop { - executor.next().await - } - }); + tokio::spawn(async move { + loop { + executor.next().await + } + }); - let (notif_config1, mut handle1) = NotificationConfig::new( - ProtocolName::from("/notif/1"), - 1024usize, - vec![1, 2, 3, 4], - Vec::new(), - false, - 64, - 64, - true, - ); - let (req_resp_config1, mut req_resp_handle1) = - RequestResponseConfigBuilder::new(ProtocolName::from("/protocol/1")) - .with_max_size(1024) - .build(); + let (notif_config1, mut handle1) = NotificationConfig::new( + ProtocolName::from("/notif/1"), + 1024usize, + vec![1, 2, 3, 4], + Vec::new(), + false, + 64, + 64, + true, + ); + let (req_resp_config1, mut req_resp_handle1) = + RequestResponseConfigBuilder::new(ProtocolName::from("/protocol/1")) + .with_max_size(1024) + .build(); - let handle = TaskExecutorHandle { tx: sender.clone() }; - let config1 = ConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_notification_protocol(notif_config1) - .with_request_response_protocol(req_resp_config1) - .with_executor(Arc::new(handle)) - .with_tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }) - .build(); + let handle = TaskExecutorHandle { tx: sender.clone() }; + let config1 = ConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_notification_protocol(notif_config1) + .with_request_response_protocol(req_resp_config1) + .with_executor(Arc::new(handle)) + .with_tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }) + .build(); - let (notif_config2, mut handle2) = NotificationConfig::new( - ProtocolName::from("/notif/1"), - 1024usize, - vec![1, 2, 3, 4], - Vec::new(), - false, - 64, - 64, - true, - ); - let (req_resp_config2, mut req_resp_handle2) = - RequestResponseConfigBuilder::new(ProtocolName::from("/protocol/1")) - .with_max_size(1024) - .build(); + let (notif_config2, mut handle2) = NotificationConfig::new( + ProtocolName::from("/notif/1"), + 1024usize, + vec![1, 2, 3, 4], + Vec::new(), + false, + 64, + 64, + true, + ); + let (req_resp_config2, mut req_resp_handle2) = + RequestResponseConfigBuilder::new(ProtocolName::from("/protocol/1")) + .with_max_size(1024) + .build(); - let handle = TaskExecutorHandle { tx: sender }; - let config2 = ConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_notification_protocol(notif_config2) - .with_request_response_protocol(req_resp_config2) - .with_executor(Arc::new(handle)) - .with_tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }) - .build(); + let handle = TaskExecutorHandle { tx: sender }; + let config2 = ConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_notification_protocol(notif_config2) + .with_request_response_protocol(req_resp_config2) + .with_executor(Arc::new(handle)) + .with_tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }) + .build(); - let mut litep2p1 = Litep2p::new(config1).unwrap(); - let mut litep2p2 = Litep2p::new(config2).unwrap(); + let mut litep2p1 = Litep2p::new(config1).unwrap(); + let mut litep2p2 = Litep2p::new(config2).unwrap(); - let peer1 = *litep2p1.local_peer_id(); - let peer2 = *litep2p2.local_peer_id(); + let peer1 = *litep2p1.local_peer_id(); + let peer2 = *litep2p2.local_peer_id(); - // wait until peers have connected and spawn the litep2p objects in the background - let address = litep2p2.listen_addresses().next().unwrap().clone(); - litep2p1.dial_address(address).await.unwrap(); + // wait until peers have connected and spawn the litep2p objects in the background + let address = litep2p2.listen_addresses().next().unwrap().clone(); + litep2p1.dial_address(address).await.unwrap(); - let mut litep2p1_connected = false; - let mut litep2p2_connected = false; + let mut litep2p1_connected = false; + let mut litep2p2_connected = false; - loop { - tokio::select! { - event = litep2p1.next_event() => match event.unwrap() { - Litep2pEvent::ConnectionEstablished { .. } => { - litep2p1_connected = true; - } - _ => {}, - }, - event = litep2p2.next_event() => match event.unwrap() { - Litep2pEvent::ConnectionEstablished { .. } => { - litep2p2_connected = true; - } - _ => {}, - } - } + loop { + tokio::select! { + event = litep2p1.next_event() => match event.unwrap() { + Litep2pEvent::ConnectionEstablished { .. } => { + litep2p1_connected = true; + } + _ => {}, + }, + event = litep2p2.next_event() => match event.unwrap() { + Litep2pEvent::ConnectionEstablished { .. } => { + litep2p2_connected = true; + } + _ => {}, + } + } - if litep2p1_connected && litep2p2_connected { - tokio::time::sleep(std::time::Duration::from_millis(200)).await; - break; - } - } - tokio::spawn(async move { - loop { - tokio::select! { - _ = litep2p1.next_event() => {}, - _ = litep2p2.next_event() => {}, - } - } - }); + if litep2p1_connected && litep2p2_connected { + tokio::time::sleep(std::time::Duration::from_millis(200)).await; + break; + } + } + tokio::spawn(async move { + loop { + tokio::select! { + _ = litep2p1.next_event() => {}, + _ = litep2p2.next_event() => {}, + } + } + }); - // open substream for `peer2` and accept it - handle1.open_substream(peer2).await.unwrap(); - assert_eq!( - handle2.next().await.unwrap(), - NotificationEvent::ValidateSubstream { - protocol: ProtocolName::from("/notif/1"), - fallback: None, - peer: peer1, - handshake: vec![1, 2, 3, 4], - } - ); - handle2.send_validation_result(peer1, ValidationResult::Accept); + // open substream for `peer2` and accept it + handle1.open_substream(peer2).await.unwrap(); + assert_eq!( + handle2.next().await.unwrap(), + NotificationEvent::ValidateSubstream { + protocol: ProtocolName::from("/notif/1"), + fallback: None, + peer: peer1, + handshake: vec![1, 2, 3, 4], + } + ); + handle2.send_validation_result(peer1, ValidationResult::Accept); - assert_eq!( - handle1.next().await.unwrap(), - NotificationEvent::ValidateSubstream { - protocol: ProtocolName::from("/notif/1"), - fallback: None, - peer: peer2, - handshake: vec![1, 2, 3, 4], - } - ); - handle1.send_validation_result(peer2, ValidationResult::Accept); + assert_eq!( + handle1.next().await.unwrap(), + NotificationEvent::ValidateSubstream { + protocol: ProtocolName::from("/notif/1"), + fallback: None, + peer: peer2, + handshake: vec![1, 2, 3, 4], + } + ); + handle1.send_validation_result(peer2, ValidationResult::Accept); - assert_eq!( - handle2.next().await.unwrap(), - NotificationEvent::NotificationStreamOpened { - protocol: ProtocolName::from("/notif/1"), - direction: Direction::Inbound, - fallback: None, - peer: peer1, - handshake: vec![1, 2, 3, 4], - } - ); - assert_eq!( - handle1.next().await.unwrap(), - NotificationEvent::NotificationStreamOpened { - protocol: ProtocolName::from("/notif/1"), - fallback: None, - direction: Direction::Outbound, - peer: peer2, - handshake: vec![1, 2, 3, 4], - } - ); + assert_eq!( + handle2.next().await.unwrap(), + NotificationEvent::NotificationStreamOpened { + protocol: ProtocolName::from("/notif/1"), + direction: Direction::Inbound, + fallback: None, + peer: peer1, + handshake: vec![1, 2, 3, 4], + } + ); + assert_eq!( + handle1.next().await.unwrap(), + NotificationEvent::NotificationStreamOpened { + protocol: ProtocolName::from("/notif/1"), + fallback: None, + direction: Direction::Outbound, + peer: peer2, + handshake: vec![1, 2, 3, 4], + } + ); - handle1.send_sync_notification(peer2, vec![1, 3, 3, 7]).unwrap(); - handle2.send_sync_notification(peer1, vec![1, 3, 3, 8]).unwrap(); + handle1.send_sync_notification(peer2, vec![1, 3, 3, 7]).unwrap(); + handle2.send_sync_notification(peer1, vec![1, 3, 3, 8]).unwrap(); - assert_eq!( - handle2.next().await.unwrap(), - NotificationEvent::NotificationReceived { - peer: peer1, - notification: BytesMut::from(&[1, 3, 3, 7][..]), - } - ); - assert_eq!( - handle1.next().await.unwrap(), - NotificationEvent::NotificationReceived { - peer: peer2, - notification: BytesMut::from(&[1, 3, 3, 8][..]), - } - ); + assert_eq!( + handle2.next().await.unwrap(), + NotificationEvent::NotificationReceived { + peer: peer1, + notification: BytesMut::from(&[1, 3, 3, 7][..]), + } + ); + assert_eq!( + handle1.next().await.unwrap(), + NotificationEvent::NotificationReceived { + peer: peer2, + notification: BytesMut::from(&[1, 3, 3, 8][..]), + } + ); - // verify that the request-response protocol works as well - req_resp_handle1 - .send_request(peer2, vec![1, 2, 3, 4], DialOptions::Reject) - .await - .unwrap(); + // verify that the request-response protocol works as well + req_resp_handle1 + .send_request(peer2, vec![1, 2, 3, 4], DialOptions::Reject) + .await + .unwrap(); - match req_resp_handle2.next().await.unwrap() { - RequestResponseEvent::RequestReceived { peer, request_id, request, .. } => { - assert_eq!(peer, peer1); - assert_eq!(request, vec![1, 2, 3, 4]); - req_resp_handle2.send_response(request_id, vec![1, 3, 3, 7]); - }, - event => panic!("unexpected event: {event:?}"), - } + match req_resp_handle2.next().await.unwrap() { + RequestResponseEvent::RequestReceived { + peer, + request_id, + request, + .. + } => { + assert_eq!(peer, peer1); + assert_eq!(request, vec![1, 2, 3, 4]); + req_resp_handle2.send_response(request_id, vec![1, 3, 3, 7]); + } + event => panic!("unexpected event: {event:?}"), + } - match req_resp_handle1.next().await.unwrap() { - RequestResponseEvent::ResponseReceived { peer, response, .. } => { - assert_eq!(peer, peer2); - assert_eq!(response, vec![1, 3, 3, 7]); - }, - event => panic!("unexpected event: {event:?}"), - } + match req_resp_handle1.next().await.unwrap() { + RequestResponseEvent::ResponseReceived { peer, response, .. } => { + assert_eq!(peer, peer2); + assert_eq!(response, vec![1, 3, 3, 7]); + } + event => panic!("unexpected event: {event:?}"), + } } diff --git a/tests/protocol/identify.rs b/tests/protocol/identify.rs index f4ea3a9b..07f29d15 100644 --- a/tests/protocol/identify.rs +++ b/tests/protocol/identify.rs @@ -20,242 +20,261 @@ use futures::{FutureExt, StreamExt}; use litep2p::{ - config::ConfigBuilder, - crypto::ed25519::Keypair, - protocol::libp2p::{ - identify::{Config, IdentifyEvent}, - ping::Config as PingConfig, - }, - transport::{ - quic::config::Config as QuicConfig, tcp::config::Config as TcpConfig, - websocket::config::Config as WebSocketConfig, - }, - Litep2p, Litep2pEvent, + config::ConfigBuilder, + crypto::ed25519::Keypair, + protocol::libp2p::{ + identify::{Config, IdentifyEvent}, + ping::Config as PingConfig, + }, + transport::{ + quic::config::Config as QuicConfig, tcp::config::Config as TcpConfig, + websocket::config::Config as WebSocketConfig, + }, + Litep2p, Litep2pEvent, }; enum Transport { - Quic(QuicConfig), - Tcp(TcpConfig), - WebSocket(WebSocketConfig), + Quic(QuicConfig), + Tcp(TcpConfig), + WebSocket(WebSocketConfig), } #[tokio::test] async fn identify_supported_tcp() { - identify_supported(Transport::Tcp(Default::default()), Transport::Tcp(Default::default())).await + identify_supported( + Transport::Tcp(Default::default()), + Transport::Tcp(Default::default()), + ) + .await } #[tokio::test] async fn identify_supported_quic() { - identify_supported(Transport::Quic(Default::default()), Transport::Quic(Default::default())) - .await + identify_supported( + Transport::Quic(Default::default()), + Transport::Quic(Default::default()), + ) + .await } #[tokio::test] async fn identify_supported_websocket() { - identify_supported( - Transport::WebSocket(Default::default()), - Transport::WebSocket(Default::default()), - ) - .await + identify_supported( + Transport::WebSocket(Default::default()), + Transport::WebSocket(Default::default()), + ) + .await } async fn identify_supported(transport1: Transport, transport2: Transport) { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (identify_config1, mut identify_event_stream1) = - Config::new("/proto/1".to_string(), Some("agent v1".to_string()), Vec::new()); - let config_builder = ConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_libp2p_identify(identify_config1); - - let config1 = match transport1 { - Transport::Tcp(config) => config_builder.with_tcp(config), - Transport::Quic(config) => config_builder.with_quic(config), - Transport::WebSocket(config) => config_builder.with_websocket(config), - } - .build(); - - let (identify_config2, mut identify_event_stream2) = - Config::new("/proto/2".to_string(), Some("agent v2".to_string()), Vec::new()); - let config_builder = ConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_libp2p_identify(identify_config2); - - let config2 = match transport2 { - Transport::Tcp(config) => config_builder.with_tcp(config), - Transport::Quic(config) => config_builder.with_quic(config), - Transport::WebSocket(config) => config_builder.with_websocket(config), - } - .build(); - - let mut litep2p1 = Litep2p::new(config1).unwrap(); - let mut litep2p2 = Litep2p::new(config2).unwrap(); - - let address1 = litep2p1.listen_addresses().next().unwrap().clone(); - let address2 = litep2p2.listen_addresses().next().unwrap().clone(); - - tracing::info!("listen address of peer1: {address1}"); - tracing::info!("listen address of peer2: {address2}"); - - litep2p1.dial_address(address2).await.unwrap(); - - let mut litep2p1_done = false; - let mut litep2p2_done = false; - - loop { - tokio::select! { - _event = litep2p1.next_event() => {} - _event = litep2p2.next_event() => {} - event = identify_event_stream1.next() => { - let IdentifyEvent::PeerIdentified { observed_address, protocol_version, user_agent, .. } = event.unwrap(); - tracing::info!("peer2 observed: {observed_address:?}"); - - assert_eq!(protocol_version, Some("/proto/2".to_string())); - assert_eq!(user_agent, Some("agent v2".to_string())); - - litep2p1_done = true; - - if litep2p1_done && litep2p2_done { - break - } - } - event = identify_event_stream2.next() => { - let IdentifyEvent::PeerIdentified { observed_address, protocol_version, user_agent, .. } = event.unwrap(); - tracing::info!("peer1 observed: {observed_address:?}"); - - assert_eq!(protocol_version, Some("/proto/1".to_string())); - assert_eq!(user_agent, Some("agent v1".to_string())); - - litep2p2_done = true; - - if litep2p1_done && litep2p2_done { - break - } - } - } - } - - let mut litep2p1_done = false; - let mut litep2p2_done = false; - - while !litep2p1_done || !litep2p2_done { - tokio::select! { - event = litep2p1.next_event() => match event.unwrap() { - Litep2pEvent::ConnectionClosed { .. } => { - litep2p1_done = true; - } - _ => {} - }, - event = litep2p2.next_event() => match event.unwrap() { - Litep2pEvent::ConnectionClosed { .. } => { - litep2p2_done = true; - } - _ => {} - } - } - } + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (identify_config1, mut identify_event_stream1) = Config::new( + "/proto/1".to_string(), + Some("agent v1".to_string()), + Vec::new(), + ); + let config_builder = ConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_libp2p_identify(identify_config1); + + let config1 = match transport1 { + Transport::Tcp(config) => config_builder.with_tcp(config), + Transport::Quic(config) => config_builder.with_quic(config), + Transport::WebSocket(config) => config_builder.with_websocket(config), + } + .build(); + + let (identify_config2, mut identify_event_stream2) = Config::new( + "/proto/2".to_string(), + Some("agent v2".to_string()), + Vec::new(), + ); + let config_builder = ConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_libp2p_identify(identify_config2); + + let config2 = match transport2 { + Transport::Tcp(config) => config_builder.with_tcp(config), + Transport::Quic(config) => config_builder.with_quic(config), + Transport::WebSocket(config) => config_builder.with_websocket(config), + } + .build(); + + let mut litep2p1 = Litep2p::new(config1).unwrap(); + let mut litep2p2 = Litep2p::new(config2).unwrap(); + + let address1 = litep2p1.listen_addresses().next().unwrap().clone(); + let address2 = litep2p2.listen_addresses().next().unwrap().clone(); + + tracing::info!("listen address of peer1: {address1}"); + tracing::info!("listen address of peer2: {address2}"); + + litep2p1.dial_address(address2).await.unwrap(); + + let mut litep2p1_done = false; + let mut litep2p2_done = false; + + loop { + tokio::select! { + _event = litep2p1.next_event() => {} + _event = litep2p2.next_event() => {} + event = identify_event_stream1.next() => { + let IdentifyEvent::PeerIdentified { observed_address, protocol_version, user_agent, .. } = event.unwrap(); + tracing::info!("peer2 observed: {observed_address:?}"); + + assert_eq!(protocol_version, Some("/proto/2".to_string())); + assert_eq!(user_agent, Some("agent v2".to_string())); + + litep2p1_done = true; + + if litep2p1_done && litep2p2_done { + break + } + } + event = identify_event_stream2.next() => { + let IdentifyEvent::PeerIdentified { observed_address, protocol_version, user_agent, .. } = event.unwrap(); + tracing::info!("peer1 observed: {observed_address:?}"); + + assert_eq!(protocol_version, Some("/proto/1".to_string())); + assert_eq!(user_agent, Some("agent v1".to_string())); + + litep2p2_done = true; + + if litep2p1_done && litep2p2_done { + break + } + } + } + } + + let mut litep2p1_done = false; + let mut litep2p2_done = false; + + while !litep2p1_done || !litep2p2_done { + tokio::select! { + event = litep2p1.next_event() => match event.unwrap() { + Litep2pEvent::ConnectionClosed { .. } => { + litep2p1_done = true; + } + _ => {} + }, + event = litep2p2.next_event() => match event.unwrap() { + Litep2pEvent::ConnectionClosed { .. } => { + litep2p2_done = true; + } + _ => {} + } + } + } } #[tokio::test] async fn identify_not_supported_tcp() { - identify_not_supported(Transport::Tcp(Default::default()), Transport::Tcp(Default::default())) - .await + identify_not_supported( + Transport::Tcp(Default::default()), + Transport::Tcp(Default::default()), + ) + .await } #[tokio::test] async fn identify_not_supported_quic() { - identify_not_supported(Transport::Quic(Default::default()), Transport::Quic(Default::default())) - .await + identify_not_supported( + Transport::Quic(Default::default()), + Transport::Quic(Default::default()), + ) + .await } #[tokio::test] async fn identify_not_supported_websocket() { - identify_not_supported( - Transport::WebSocket(Default::default()), - Transport::WebSocket(Default::default()), - ) - .await + identify_not_supported( + Transport::WebSocket(Default::default()), + Transport::WebSocket(Default::default()), + ) + .await } async fn identify_not_supported(transport1: Transport, transport2: Transport) { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (ping_config, _event_stream) = PingConfig::default(); - let config1 = match transport1 { - Transport::Tcp(config) => ConfigBuilder::new().with_tcp(config), - Transport::Quic(config) => ConfigBuilder::new().with_quic(config), - Transport::WebSocket(config) => ConfigBuilder::new().with_websocket(config), - } - .with_keypair(Keypair::generate()) - .with_libp2p_ping(ping_config) - .build(); - - let (identify_config2, mut identify_event_stream2) = - Config::new("litep2p".to_string(), None, Vec::new()); - let config_builder = ConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_libp2p_identify(identify_config2); - - let config2 = match transport2 { - Transport::Tcp(config) => config_builder.with_tcp(config), - Transport::Quic(config) => config_builder.with_quic(config), - Transport::WebSocket(config) => config_builder.with_websocket(config), - } - .build(); - - let mut litep2p1 = Litep2p::new(config1).unwrap(); - let mut litep2p2 = Litep2p::new(config2).unwrap(); - let address = litep2p2.listen_addresses().next().unwrap().clone(); - - litep2p1.dial_address(address).await.unwrap(); - - let mut litep2p1_done = false; - let mut litep2p2_done = false; - - while !litep2p1_done || !litep2p2_done { - tokio::select! { - event = litep2p1.next_event() => match event.unwrap() { - Litep2pEvent::ConnectionEstablished { .. } => { - tracing::error!("litep2p1 connection established"); - litep2p1_done = true; - } - _ => {} - }, - event = litep2p2.next_event() => match event.unwrap() { - Litep2pEvent::ConnectionEstablished { .. } => { - tracing::error!("litep2p2 connection established"); - litep2p2_done = true; - } - _ => {} - } - } - } - - let mut litep2p1_done = false; - let mut litep2p2_done = false; - - while !litep2p1_done || !litep2p2_done { - tokio::select! { - event = litep2p1.next_event() => match event.unwrap() { - Litep2pEvent::ConnectionClosed { .. } => { - tracing::error!("litep2p1 connection closed"); - litep2p1_done = true; - } - _ => {} - }, - event = litep2p2.next_event() => match event.unwrap() { - Litep2pEvent::ConnectionClosed { .. } => { - tracing::error!("litep2p2 connection closed"); - litep2p2_done = true; - } - _ => {} - } - } - } - - assert!(identify_event_stream2.next().now_or_never().is_none()); + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (ping_config, _event_stream) = PingConfig::default(); + let config1 = match transport1 { + Transport::Tcp(config) => ConfigBuilder::new().with_tcp(config), + Transport::Quic(config) => ConfigBuilder::new().with_quic(config), + Transport::WebSocket(config) => ConfigBuilder::new().with_websocket(config), + } + .with_keypair(Keypair::generate()) + .with_libp2p_ping(ping_config) + .build(); + + let (identify_config2, mut identify_event_stream2) = + Config::new("litep2p".to_string(), None, Vec::new()); + let config_builder = ConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_libp2p_identify(identify_config2); + + let config2 = match transport2 { + Transport::Tcp(config) => config_builder.with_tcp(config), + Transport::Quic(config) => config_builder.with_quic(config), + Transport::WebSocket(config) => config_builder.with_websocket(config), + } + .build(); + + let mut litep2p1 = Litep2p::new(config1).unwrap(); + let mut litep2p2 = Litep2p::new(config2).unwrap(); + let address = litep2p2.listen_addresses().next().unwrap().clone(); + + litep2p1.dial_address(address).await.unwrap(); + + let mut litep2p1_done = false; + let mut litep2p2_done = false; + + while !litep2p1_done || !litep2p2_done { + tokio::select! { + event = litep2p1.next_event() => match event.unwrap() { + Litep2pEvent::ConnectionEstablished { .. } => { + tracing::error!("litep2p1 connection established"); + litep2p1_done = true; + } + _ => {} + }, + event = litep2p2.next_event() => match event.unwrap() { + Litep2pEvent::ConnectionEstablished { .. } => { + tracing::error!("litep2p2 connection established"); + litep2p2_done = true; + } + _ => {} + } + } + } + + let mut litep2p1_done = false; + let mut litep2p2_done = false; + + while !litep2p1_done || !litep2p2_done { + tokio::select! { + event = litep2p1.next_event() => match event.unwrap() { + Litep2pEvent::ConnectionClosed { .. } => { + tracing::error!("litep2p1 connection closed"); + litep2p1_done = true; + } + _ => {} + }, + event = litep2p2.next_event() => match event.unwrap() { + Litep2pEvent::ConnectionClosed { .. } => { + tracing::error!("litep2p2 connection closed"); + litep2p2_done = true; + } + _ => {} + } + } + } + + assert!(identify_event_stream2.next().now_or_never().is_none()); } diff --git a/tests/protocol/kademlia.rs b/tests/protocol/kademlia.rs index dd8d099f..4d9ca922 100644 --- a/tests/protocol/kademlia.rs +++ b/tests/protocol/kademlia.rs @@ -22,100 +22,103 @@ use bytes::Bytes; use futures::StreamExt; use litep2p::{ - config::ConfigBuilder, - crypto::ed25519::Keypair, - protocol::libp2p::kademlia::{ConfigBuilder as KademliaConfigBuilder, RecordKey}, - transport::tcp::config::Config as TcpConfig, - Litep2p, PeerId, + config::ConfigBuilder, + crypto::ed25519::Keypair, + protocol::libp2p::kademlia::{ConfigBuilder as KademliaConfigBuilder, RecordKey}, + transport::tcp::config::Config as TcpConfig, + Litep2p, PeerId, }; fn spawn_litep2p(port: u16) { - let (kad_config1, _kad_handle1) = KademliaConfigBuilder::new().build(); - let config1 = ConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_tcp(TcpConfig { - listen_addresses: vec![format!("/ip6/::1/tcp/{port}").parse().unwrap()], - ..Default::default() - }) - .with_libp2p_kademlia(kad_config1) - .build(); + let (kad_config1, _kad_handle1) = KademliaConfigBuilder::new().build(); + let config1 = ConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_tcp(TcpConfig { + listen_addresses: vec![format!("/ip6/::1/tcp/{port}").parse().unwrap()], + ..Default::default() + }) + .with_libp2p_kademlia(kad_config1) + .build(); - let mut litep2p1 = Litep2p::new(config1).unwrap(); + let mut litep2p1 = Litep2p::new(config1).unwrap(); - tokio::spawn(async move { while let Some(_) = litep2p1.next_event().await {} }); + tokio::spawn(async move { while let Some(_) = litep2p1.next_event().await {} }); } #[tokio::test] #[ignore] async fn kademlia_supported() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); - let (kad_config1, _kad_handle1) = KademliaConfigBuilder::new().build(); - let config1 = ConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }) - .with_libp2p_kademlia(kad_config1) - .build(); + let (kad_config1, _kad_handle1) = KademliaConfigBuilder::new().build(); + let config1 = ConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }) + .with_libp2p_kademlia(kad_config1) + .build(); - let mut litep2p1 = Litep2p::new(config1).unwrap(); + let mut litep2p1 = Litep2p::new(config1).unwrap(); - for port in 9000..9003 { - spawn_litep2p(port); - } + for port in 9000..9003 { + spawn_litep2p(port); + } - loop { - tokio::select! { - event = litep2p1.next_event() => { - tracing::info!("litep2p event received: {event:?}"); - } - // event = kad_handle1.next() => { - // tracing::info!("kademlia event received: {event:?}"); - // } - } - } + loop { + tokio::select! { + event = litep2p1.next_event() => { + tracing::info!("litep2p event received: {event:?}"); + } + // event = kad_handle1.next() => { + // tracing::info!("kademlia event received: {event:?}"); + // } + } + } } #[tokio::test] #[ignore] async fn put_value() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); - let (kad_config1, mut kad_handle1) = KademliaConfigBuilder::new().build(); - let config1 = ConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }) - .with_libp2p_kademlia(kad_config1) - .build(); + let (kad_config1, mut kad_handle1) = KademliaConfigBuilder::new().build(); + let config1 = ConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }) + .with_libp2p_kademlia(kad_config1) + .build(); - let mut litep2p1 = Litep2p::new(config1).unwrap(); + let mut litep2p1 = Litep2p::new(config1).unwrap(); - for i in 0..10 { - kad_handle1 - .add_known_peer(PeerId::random(), vec![format!("/ip6/::/tcp/{i}").parse().unwrap()]) - .await; - } + for i in 0..10 { + kad_handle1 + .add_known_peer( + PeerId::random(), + vec![format!("/ip6/::/tcp/{i}").parse().unwrap()], + ) + .await; + } - // let key = RecordKey::new(&Bytes::from(vec![1, 3, 3, 7])); - // kad_handle1.put_value(key, vec![1, 2, 3, 4]).await; + // let key = RecordKey::new(&Bytes::from(vec![1, 3, 3, 7])); + // kad_handle1.put_value(key, vec![1, 2, 3, 4]).await; - // loop { - // tokio::select! { - // event = litep2p1.next_event() => { - // tracing::info!("litep2p event received: {event:?}"); - // } - // event = kad_handle1.next() => { - // tracing::info!("kademlia event received: {event:?}"); - // } - // } - // } + // loop { + // tokio::select! { + // event = litep2p1.next_event() => { + // tracing::info!("litep2p event received: {event:?}"); + // } + // event = kad_handle1.next() => { + // tracing::info!("kademlia event received: {event:?}"); + // } + // } + // } } diff --git a/tests/protocol/notification.rs b/tests/protocol/notification.rs index 74974774..f40fc3f6 100644 --- a/tests/protocol/notification.rs +++ b/tests/protocol/notification.rs @@ -19,19 +19,19 @@ // DEALINGS IN THE SOFTWARE. use litep2p::{ - config::ConfigBuilder as Litep2pConfigBuilder, - crypto::ed25519::Keypair, - error::Error, - protocol::notification::{ - Config as NotificationConfig, ConfigBuilder, Direction, NotificationError, - NotificationEvent, NotificationHandle, ValidationResult, - }, - transport::{ - quic::config::Config as QuicConfig, tcp::config::Config as TcpConfig, - websocket::config::Config as WebSocketConfig, - }, - types::protocol::ProtocolName, - Litep2p, Litep2pEvent, PeerId, + config::ConfigBuilder as Litep2pConfigBuilder, + crypto::ed25519::Keypair, + error::Error, + protocol::notification::{ + Config as NotificationConfig, ConfigBuilder, Direction, NotificationError, + NotificationEvent, NotificationHandle, ValidationResult, + }, + transport::{ + quic::config::Config as QuicConfig, tcp::config::Config as TcpConfig, + websocket::config::Config as WebSocketConfig, + }, + types::protocol::ProtocolName, + Litep2p, Litep2pEvent, PeerId, }; use bytes::BytesMut; @@ -40,4054 +40,4112 @@ use multiaddr::{Multiaddr, Protocol}; use multihash::Multihash; use std::{ - net::{Ipv4Addr, Ipv6Addr}, - task::Poll, - time::Duration, + net::{Ipv4Addr, Ipv6Addr}, + task::Poll, + time::Duration, }; enum Transport { - Tcp(TcpConfig), - Quic(QuicConfig), - WebSocket(WebSocketConfig), + Tcp(TcpConfig), + Quic(QuicConfig), + WebSocket(WebSocketConfig), } async fn connect_peers(litep2p1: &mut Litep2p, litep2p2: &mut Litep2p) { - let address = litep2p2.listen_addresses().next().unwrap().clone(); - litep2p1.dial_address(address).await.unwrap(); - - let mut litep2p1_connected = false; - let mut litep2p2_connected = false; - - loop { - tokio::select! { - event = litep2p1.next_event() => match event.unwrap() { - Litep2pEvent::ConnectionEstablished { .. } => { - litep2p1_connected = true; - } - _ => {}, - }, - event = litep2p2.next_event() => match event.unwrap() { - Litep2pEvent::ConnectionEstablished { .. } => { - litep2p2_connected = true; - } - _ => {}, - } - } - - if litep2p1_connected && litep2p2_connected { - break; - } - } - - tokio::time::sleep(std::time::Duration::from_millis(100)).await; + let address = litep2p2.listen_addresses().next().unwrap().clone(); + litep2p1.dial_address(address).await.unwrap(); + + let mut litep2p1_connected = false; + let mut litep2p2_connected = false; + + loop { + tokio::select! { + event = litep2p1.next_event() => match event.unwrap() { + Litep2pEvent::ConnectionEstablished { .. } => { + litep2p1_connected = true; + } + _ => {}, + }, + event = litep2p2.next_event() => match event.unwrap() { + Litep2pEvent::ConnectionEstablished { .. } => { + litep2p2_connected = true; + } + _ => {}, + } + } + + if litep2p1_connected && litep2p2_connected { + break; + } + } + + tokio::time::sleep(std::time::Duration::from_millis(100)).await; } async fn make_default_litep2p(transport: Transport) -> (Litep2p, NotificationHandle) { - let (notif_config, handle) = NotificationConfig::new( - ProtocolName::from("/notif/1"), - 1024usize, - vec![1, 2, 3, 4], - Vec::new(), - false, - 64, - 64, - true, - ); - let config = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_notification_protocol(notif_config); - - let config = match transport { - Transport::Tcp(transport_config) => config.with_tcp(transport_config), - Transport::Quic(transport_config) => config.with_quic(transport_config), - Transport::WebSocket(transport_config) => config.with_websocket(transport_config), - } - .build(); - - (Litep2p::new(config).unwrap(), handle) + let (notif_config, handle) = NotificationConfig::new( + ProtocolName::from("/notif/1"), + 1024usize, + vec![1, 2, 3, 4], + Vec::new(), + false, + 64, + 64, + true, + ); + let config = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_notification_protocol(notif_config); + + let config = match transport { + Transport::Tcp(transport_config) => config.with_tcp(transport_config), + Transport::Quic(transport_config) => config.with_quic(transport_config), + Transport::WebSocket(transport_config) => config.with_websocket(transport_config), + } + .build(); + + (Litep2p::new(config).unwrap(), handle) } #[tokio::test] async fn open_substreams_tcp() { - open_substreams( - Transport::Tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }), - Transport::Tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }), - ) - .await + open_substreams( + Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }), + Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }), + ) + .await } #[tokio::test] async fn open_substreams_quic() { - open_substreams(Transport::Quic(Default::default()), Transport::Quic(Default::default())).await; + open_substreams( + Transport::Quic(Default::default()), + Transport::Quic(Default::default()), + ) + .await; } #[tokio::test] async fn open_substreams_websocket() { - open_substreams( - Transport::WebSocket(WebSocketConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], - ..Default::default() - }), - Transport::WebSocket(WebSocketConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], - ..Default::default() - }), - ) - .await; + open_substreams( + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }), + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }), + ) + .await; } async fn open_substreams(transport1: Transport, transport2: Transport) { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (notif_config1, mut handle1) = NotificationConfig::new( - ProtocolName::from("/notif/1"), - 1024usize, - vec![1, 2, 3, 4], - Vec::new(), - false, - 64, - 64, - true, - ); - let config1 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_notification_protocol(notif_config1); - - let config1 = match transport1 { - Transport::Tcp(config) => config1.with_tcp(config), - Transport::Quic(config) => config1.with_quic(config), - Transport::WebSocket(config) => config1.with_websocket(config), - } - .build(); - - let (notif_config2, mut handle2) = NotificationConfig::new( - ProtocolName::from("/notif/1"), - 1024usize, - vec![1, 2, 3, 4], - Vec::new(), - false, - 64, - 64, - true, - ); - let config2 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_notification_protocol(notif_config2); - - let config2 = match transport2 { - Transport::Tcp(config) => config2.with_tcp(config), - Transport::Quic(config) => config2.with_quic(config), - Transport::WebSocket(config) => config2.with_websocket(config), - } - .build(); - - let mut litep2p1 = Litep2p::new(config1).unwrap(); - let mut litep2p2 = Litep2p::new(config2).unwrap(); - - let peer1 = *litep2p1.local_peer_id(); - let peer2 = *litep2p2.local_peer_id(); - - // wait until peers have connected and spawn the litep2p objects in the background - connect_peers(&mut litep2p1, &mut litep2p2).await; - tokio::spawn(async move { - loop { - tokio::select! { - _ = litep2p1.next_event() => {}, - _ = litep2p2.next_event() => {}, - } - } - }); - - // open substream for `peer2` and accept it - handle1.open_substream(peer2).await.unwrap(); - assert_eq!( - handle2.next().await.unwrap(), - NotificationEvent::ValidateSubstream { - protocol: ProtocolName::from("/notif/1"), - fallback: None, - peer: peer1, - handshake: vec![1, 2, 3, 4], - } - ); - handle2.send_validation_result(peer1, ValidationResult::Accept); - - assert_eq!( - handle1.next().await.unwrap(), - NotificationEvent::ValidateSubstream { - protocol: ProtocolName::from("/notif/1"), - fallback: None, - peer: peer2, - handshake: vec![1, 2, 3, 4], - } - ); - handle1.send_validation_result(peer2, ValidationResult::Accept); - - assert_eq!( - handle2.next().await.unwrap(), - NotificationEvent::NotificationStreamOpened { - protocol: ProtocolName::from("/notif/1"), - direction: Direction::Inbound, - fallback: None, - peer: peer1, - handshake: vec![1, 2, 3, 4], - } - ); - assert_eq!( - handle1.next().await.unwrap(), - NotificationEvent::NotificationStreamOpened { - protocol: ProtocolName::from("/notif/1"), - fallback: None, - direction: Direction::Outbound, - peer: peer2, - handshake: vec![1, 2, 3, 4], - } - ); - - handle1.send_sync_notification(peer2, vec![1, 3, 3, 7]).unwrap(); - handle2.send_sync_notification(peer1, vec![1, 3, 3, 8]).unwrap(); - - assert_eq!( - handle2.next().await.unwrap(), - NotificationEvent::NotificationReceived { - peer: peer1, - notification: BytesMut::from(&[1, 3, 3, 7][..]), - } - ); - assert_eq!( - handle1.next().await.unwrap(), - NotificationEvent::NotificationReceived { - peer: peer2, - notification: BytesMut::from(&[1, 3, 3, 8][..]), - } - ); + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (notif_config1, mut handle1) = NotificationConfig::new( + ProtocolName::from("/notif/1"), + 1024usize, + vec![1, 2, 3, 4], + Vec::new(), + false, + 64, + 64, + true, + ); + let config1 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_notification_protocol(notif_config1); + + let config1 = match transport1 { + Transport::Tcp(config) => config1.with_tcp(config), + Transport::Quic(config) => config1.with_quic(config), + Transport::WebSocket(config) => config1.with_websocket(config), + } + .build(); + + let (notif_config2, mut handle2) = NotificationConfig::new( + ProtocolName::from("/notif/1"), + 1024usize, + vec![1, 2, 3, 4], + Vec::new(), + false, + 64, + 64, + true, + ); + let config2 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_notification_protocol(notif_config2); + + let config2 = match transport2 { + Transport::Tcp(config) => config2.with_tcp(config), + Transport::Quic(config) => config2.with_quic(config), + Transport::WebSocket(config) => config2.with_websocket(config), + } + .build(); + + let mut litep2p1 = Litep2p::new(config1).unwrap(); + let mut litep2p2 = Litep2p::new(config2).unwrap(); + + let peer1 = *litep2p1.local_peer_id(); + let peer2 = *litep2p2.local_peer_id(); + + // wait until peers have connected and spawn the litep2p objects in the background + connect_peers(&mut litep2p1, &mut litep2p2).await; + tokio::spawn(async move { + loop { + tokio::select! { + _ = litep2p1.next_event() => {}, + _ = litep2p2.next_event() => {}, + } + } + }); + + // open substream for `peer2` and accept it + handle1.open_substream(peer2).await.unwrap(); + assert_eq!( + handle2.next().await.unwrap(), + NotificationEvent::ValidateSubstream { + protocol: ProtocolName::from("/notif/1"), + fallback: None, + peer: peer1, + handshake: vec![1, 2, 3, 4], + } + ); + handle2.send_validation_result(peer1, ValidationResult::Accept); + + assert_eq!( + handle1.next().await.unwrap(), + NotificationEvent::ValidateSubstream { + protocol: ProtocolName::from("/notif/1"), + fallback: None, + peer: peer2, + handshake: vec![1, 2, 3, 4], + } + ); + handle1.send_validation_result(peer2, ValidationResult::Accept); + + assert_eq!( + handle2.next().await.unwrap(), + NotificationEvent::NotificationStreamOpened { + protocol: ProtocolName::from("/notif/1"), + direction: Direction::Inbound, + fallback: None, + peer: peer1, + handshake: vec![1, 2, 3, 4], + } + ); + assert_eq!( + handle1.next().await.unwrap(), + NotificationEvent::NotificationStreamOpened { + protocol: ProtocolName::from("/notif/1"), + fallback: None, + direction: Direction::Outbound, + peer: peer2, + handshake: vec![1, 2, 3, 4], + } + ); + + handle1.send_sync_notification(peer2, vec![1, 3, 3, 7]).unwrap(); + handle2.send_sync_notification(peer1, vec![1, 3, 3, 8]).unwrap(); + + assert_eq!( + handle2.next().await.unwrap(), + NotificationEvent::NotificationReceived { + peer: peer1, + notification: BytesMut::from(&[1, 3, 3, 7][..]), + } + ); + assert_eq!( + handle1.next().await.unwrap(), + NotificationEvent::NotificationReceived { + peer: peer2, + notification: BytesMut::from(&[1, 3, 3, 8][..]), + } + ); } #[tokio::test] async fn reject_substream_tcp() { - reject_substream( - Transport::Tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }), - Transport::Tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }), - ) - .await; + reject_substream( + Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }), + Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }), + ) + .await; } #[tokio::test] async fn reject_substream_quic() { - reject_substream(Transport::Quic(Default::default()), Transport::Quic(Default::default())) - .await; + reject_substream( + Transport::Quic(Default::default()), + Transport::Quic(Default::default()), + ) + .await; } #[tokio::test] async fn reject_substream_websocket() { - reject_substream( - Transport::WebSocket(WebSocketConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], - ..Default::default() - }), - Transport::WebSocket(WebSocketConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], - ..Default::default() - }), - ) - .await; + reject_substream( + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }), + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }), + ) + .await; } async fn reject_substream(transport1: Transport, transport2: Transport) { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (notif_config1, mut handle1) = NotificationConfig::new( - ProtocolName::from("/notif/1"), - 1024usize, - vec![1, 2, 3, 4], - Vec::new(), - false, - 64, - 64, - true, - ); - let config1 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_notification_protocol(notif_config1); - - let config1 = match transport1 { - Transport::Tcp(config) => config1.with_tcp(config), - Transport::Quic(config) => config1.with_quic(config), - Transport::WebSocket(config) => config1.with_websocket(config), - } - .build(); - - let (notif_config2, mut handle2) = NotificationConfig::new( - ProtocolName::from("/notif/1"), - 1024usize, - vec![1, 2, 3, 4], - Vec::new(), - false, - 64, - 64, - true, - ); - let config2 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_notification_protocol(notif_config2); - - let config2 = match transport2 { - Transport::Tcp(config) => config2.with_tcp(config), - Transport::Quic(config) => config2.with_quic(config), - Transport::WebSocket(config) => config2.with_websocket(config), - } - .build(); - - let mut litep2p1 = Litep2p::new(config1).unwrap(); - let mut litep2p2 = Litep2p::new(config2).unwrap(); - - let peer1 = *litep2p1.local_peer_id(); - let peer2 = *litep2p2.local_peer_id(); - - // wait until peers have connected and spawn the litep2p objects in the background - connect_peers(&mut litep2p1, &mut litep2p2).await; - tokio::spawn(async move { - loop { - tokio::select! { - _ = litep2p1.next_event() => {}, - _ = litep2p2.next_event() => {}, - } - } - }); - - // open substream for `peer2` and accept it - handle1.open_substream(peer2).await.unwrap(); - assert_eq!( - handle2.next().await.unwrap(), - NotificationEvent::ValidateSubstream { - protocol: ProtocolName::from("/notif/1"), - fallback: None, - peer: peer1, - handshake: vec![1, 2, 3, 4], - } - ); - handle2.send_validation_result(peer1, ValidationResult::Reject); - - assert_eq!( - handle1.next().await.unwrap(), - NotificationEvent::NotificationStreamOpenFailure { - peer: peer2, - error: NotificationError::Rejected, - } - ); + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (notif_config1, mut handle1) = NotificationConfig::new( + ProtocolName::from("/notif/1"), + 1024usize, + vec![1, 2, 3, 4], + Vec::new(), + false, + 64, + 64, + true, + ); + let config1 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_notification_protocol(notif_config1); + + let config1 = match transport1 { + Transport::Tcp(config) => config1.with_tcp(config), + Transport::Quic(config) => config1.with_quic(config), + Transport::WebSocket(config) => config1.with_websocket(config), + } + .build(); + + let (notif_config2, mut handle2) = NotificationConfig::new( + ProtocolName::from("/notif/1"), + 1024usize, + vec![1, 2, 3, 4], + Vec::new(), + false, + 64, + 64, + true, + ); + let config2 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_notification_protocol(notif_config2); + + let config2 = match transport2 { + Transport::Tcp(config) => config2.with_tcp(config), + Transport::Quic(config) => config2.with_quic(config), + Transport::WebSocket(config) => config2.with_websocket(config), + } + .build(); + + let mut litep2p1 = Litep2p::new(config1).unwrap(); + let mut litep2p2 = Litep2p::new(config2).unwrap(); + + let peer1 = *litep2p1.local_peer_id(); + let peer2 = *litep2p2.local_peer_id(); + + // wait until peers have connected and spawn the litep2p objects in the background + connect_peers(&mut litep2p1, &mut litep2p2).await; + tokio::spawn(async move { + loop { + tokio::select! { + _ = litep2p1.next_event() => {}, + _ = litep2p2.next_event() => {}, + } + } + }); + + // open substream for `peer2` and accept it + handle1.open_substream(peer2).await.unwrap(); + assert_eq!( + handle2.next().await.unwrap(), + NotificationEvent::ValidateSubstream { + protocol: ProtocolName::from("/notif/1"), + fallback: None, + peer: peer1, + handshake: vec![1, 2, 3, 4], + } + ); + handle2.send_validation_result(peer1, ValidationResult::Reject); + + assert_eq!( + handle1.next().await.unwrap(), + NotificationEvent::NotificationStreamOpenFailure { + peer: peer2, + error: NotificationError::Rejected, + } + ); } #[tokio::test] async fn notification_stream_closed_tcp() { - notification_stream_closed( - Transport::Tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }), - Transport::Tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }), - ) - .await + notification_stream_closed( + Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }), + Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }), + ) + .await } #[tokio::test] async fn notification_stream_closed_quic() { - notification_stream_closed( - Transport::Quic(Default::default()), - Transport::Quic(Default::default()), - ) - .await; + notification_stream_closed( + Transport::Quic(Default::default()), + Transport::Quic(Default::default()), + ) + .await; } #[tokio::test] async fn notification_stream_closed_websocket() { - notification_stream_closed( - Transport::WebSocket(WebSocketConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], - ..Default::default() - }), - Transport::WebSocket(WebSocketConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], - ..Default::default() - }), - ) - .await; + notification_stream_closed( + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }), + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }), + ) + .await; } async fn notification_stream_closed(transport1: Transport, transport2: Transport) { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (notif_config1, mut handle1) = NotificationConfig::new( - ProtocolName::from("/notif/1"), - 1024usize, - vec![1, 2, 3, 4], - Vec::new(), - false, - 64, - 64, - true, - ); - let config1 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_notification_protocol(notif_config1); - - let config1 = match transport1 { - Transport::Tcp(config) => config1.with_tcp(config), - Transport::Quic(config) => config1.with_quic(config), - Transport::WebSocket(config) => config1.with_websocket(config), - } - .build(); - - let (notif_config2, mut handle2) = NotificationConfig::new( - ProtocolName::from("/notif/1"), - 1024usize, - vec![1, 2, 3, 4], - Vec::new(), - false, - 64, - 64, - true, - ); - let config2 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_notification_protocol(notif_config2); - - let config2 = match transport2 { - Transport::Tcp(config) => config2.with_tcp(config), - Transport::Quic(config) => config2.with_quic(config), - Transport::WebSocket(config) => config2.with_websocket(config), - } - .build(); - - let mut litep2p1 = Litep2p::new(config1).unwrap(); - let mut litep2p2 = Litep2p::new(config2).unwrap(); - - let peer1 = *litep2p1.local_peer_id(); - let peer2 = *litep2p2.local_peer_id(); - - // wait until peers have connected and spawn the litep2p objects in the background - connect_peers(&mut litep2p1, &mut litep2p2).await; - tokio::spawn(async move { - loop { - tokio::select! { - _ = litep2p1.next_event() => {}, - _ = litep2p2.next_event() => {}, - } - } - }); - - // open substream for `peer2` and accept it - handle1.open_substream(peer2).await.unwrap(); - assert_eq!( - handle2.next().await.unwrap(), - NotificationEvent::ValidateSubstream { - protocol: ProtocolName::from("/notif/1"), - fallback: None, - peer: peer1, - handshake: vec![1, 2, 3, 4], - } - ); - handle2.send_validation_result(peer1, ValidationResult::Accept); - - assert_eq!( - handle1.next().await.unwrap(), - NotificationEvent::ValidateSubstream { - protocol: ProtocolName::from("/notif/1"), - fallback: None, - peer: peer2, - handshake: vec![1, 2, 3, 4], - } - ); - handle1.send_validation_result(peer2, ValidationResult::Accept); - - assert_eq!( - handle2.next().await.unwrap(), - NotificationEvent::NotificationStreamOpened { - protocol: ProtocolName::from("/notif/1"), - fallback: None, - direction: Direction::Inbound, - peer: peer1, - handshake: vec![1, 2, 3, 4], - } - ); - assert_eq!( - handle1.next().await.unwrap(), - NotificationEvent::NotificationStreamOpened { - protocol: ProtocolName::from("/notif/1"), - fallback: None, - direction: Direction::Outbound, - peer: peer2, - handshake: vec![1, 2, 3, 4], - } - ); - - handle1.send_sync_notification(peer2, vec![1, 3, 3, 7]).unwrap(); - handle2.send_sync_notification(peer1, vec![1, 3, 3, 8]).unwrap(); - - assert_eq!( - handle2.next().await.unwrap(), - NotificationEvent::NotificationReceived { - peer: peer1, - notification: BytesMut::from(&[1, 3, 3, 7][..]), - } - ); - assert_eq!( - handle1.next().await.unwrap(), - NotificationEvent::NotificationReceived { - peer: peer2, - notification: BytesMut::from(&[1, 3, 3, 8][..]), - } - ); - - handle1.close_substream(peer2).await; - - match handle2.next().await.unwrap() { - NotificationEvent::NotificationStreamClosed { peer } => assert_eq!(peer, peer1), - _ => panic!("invalid event received"), - } + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (notif_config1, mut handle1) = NotificationConfig::new( + ProtocolName::from("/notif/1"), + 1024usize, + vec![1, 2, 3, 4], + Vec::new(), + false, + 64, + 64, + true, + ); + let config1 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_notification_protocol(notif_config1); + + let config1 = match transport1 { + Transport::Tcp(config) => config1.with_tcp(config), + Transport::Quic(config) => config1.with_quic(config), + Transport::WebSocket(config) => config1.with_websocket(config), + } + .build(); + + let (notif_config2, mut handle2) = NotificationConfig::new( + ProtocolName::from("/notif/1"), + 1024usize, + vec![1, 2, 3, 4], + Vec::new(), + false, + 64, + 64, + true, + ); + let config2 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_notification_protocol(notif_config2); + + let config2 = match transport2 { + Transport::Tcp(config) => config2.with_tcp(config), + Transport::Quic(config) => config2.with_quic(config), + Transport::WebSocket(config) => config2.with_websocket(config), + } + .build(); + + let mut litep2p1 = Litep2p::new(config1).unwrap(); + let mut litep2p2 = Litep2p::new(config2).unwrap(); + + let peer1 = *litep2p1.local_peer_id(); + let peer2 = *litep2p2.local_peer_id(); + + // wait until peers have connected and spawn the litep2p objects in the background + connect_peers(&mut litep2p1, &mut litep2p2).await; + tokio::spawn(async move { + loop { + tokio::select! { + _ = litep2p1.next_event() => {}, + _ = litep2p2.next_event() => {}, + } + } + }); + + // open substream for `peer2` and accept it + handle1.open_substream(peer2).await.unwrap(); + assert_eq!( + handle2.next().await.unwrap(), + NotificationEvent::ValidateSubstream { + protocol: ProtocolName::from("/notif/1"), + fallback: None, + peer: peer1, + handshake: vec![1, 2, 3, 4], + } + ); + handle2.send_validation_result(peer1, ValidationResult::Accept); + + assert_eq!( + handle1.next().await.unwrap(), + NotificationEvent::ValidateSubstream { + protocol: ProtocolName::from("/notif/1"), + fallback: None, + peer: peer2, + handshake: vec![1, 2, 3, 4], + } + ); + handle1.send_validation_result(peer2, ValidationResult::Accept); + + assert_eq!( + handle2.next().await.unwrap(), + NotificationEvent::NotificationStreamOpened { + protocol: ProtocolName::from("/notif/1"), + fallback: None, + direction: Direction::Inbound, + peer: peer1, + handshake: vec![1, 2, 3, 4], + } + ); + assert_eq!( + handle1.next().await.unwrap(), + NotificationEvent::NotificationStreamOpened { + protocol: ProtocolName::from("/notif/1"), + fallback: None, + direction: Direction::Outbound, + peer: peer2, + handshake: vec![1, 2, 3, 4], + } + ); + + handle1.send_sync_notification(peer2, vec![1, 3, 3, 7]).unwrap(); + handle2.send_sync_notification(peer1, vec![1, 3, 3, 8]).unwrap(); + + assert_eq!( + handle2.next().await.unwrap(), + NotificationEvent::NotificationReceived { + peer: peer1, + notification: BytesMut::from(&[1, 3, 3, 7][..]), + } + ); + assert_eq!( + handle1.next().await.unwrap(), + NotificationEvent::NotificationReceived { + peer: peer2, + notification: BytesMut::from(&[1, 3, 3, 8][..]), + } + ); + + handle1.close_substream(peer2).await; + + match handle2.next().await.unwrap() { + NotificationEvent::NotificationStreamClosed { peer } => assert_eq!(peer, peer1), + _ => panic!("invalid event received"), + } } #[tokio::test] async fn reconnect_after_disconnect_tcp() { - reconnect_after_disconnect( - Transport::Tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }), - Transport::Tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }), - ) - .await + reconnect_after_disconnect( + Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }), + Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }), + ) + .await } #[tokio::test] async fn reconnect_after_disconnect_quic() { - reconnect_after_disconnect( - Transport::Quic(Default::default()), - Transport::Quic(Default::default()), - ) - .await; + reconnect_after_disconnect( + Transport::Quic(Default::default()), + Transport::Quic(Default::default()), + ) + .await; } #[tokio::test] async fn reconnect_after_disconnect_websocket() { - reconnect_after_disconnect( - Transport::WebSocket(WebSocketConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], - ..Default::default() - }), - Transport::WebSocket(WebSocketConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], - ..Default::default() - }), - ) - .await; + reconnect_after_disconnect( + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }), + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }), + ) + .await; } async fn reconnect_after_disconnect(transport1: Transport, transport2: Transport) { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (notif_config1, mut handle1) = NotificationConfig::new( - ProtocolName::from("/notif/1"), - 1024usize, - vec![1, 2, 3, 4], - Vec::new(), - false, - 64, - 64, - true, - ); - let config1 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_notification_protocol(notif_config1); - - let config1 = match transport1 { - Transport::Tcp(config) => config1.with_tcp(config), - Transport::Quic(config) => config1.with_quic(config), - Transport::WebSocket(config) => config1.with_websocket(config), - } - .build(); - - let (notif_config2, mut handle2) = NotificationConfig::new( - ProtocolName::from("/notif/1"), - 1024usize, - vec![1, 2, 3, 4], - Vec::new(), - false, - 64, - 64, - true, - ); - let config2 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_notification_protocol(notif_config2); - - let config2 = match transport2 { - Transport::Tcp(config) => config2.with_tcp(config), - Transport::Quic(config) => config2.with_quic(config), - Transport::WebSocket(config) => config2.with_websocket(config), - } - .build(); - - let mut litep2p1 = Litep2p::new(config1).unwrap(); - let mut litep2p2 = Litep2p::new(config2).unwrap(); - - let peer1 = *litep2p1.local_peer_id(); - let peer2 = *litep2p2.local_peer_id(); - - // wait until peers have connected and spawn the litep2p objects in the background - connect_peers(&mut litep2p1, &mut litep2p2).await; - tokio::spawn(async move { - loop { - tokio::select! { - _ = litep2p1.next_event() => {}, - _ = litep2p2.next_event() => {}, - } - } - }); - - // open substream for `peer2` and accept it - handle1.open_substream(peer2).await.unwrap(); - - // accept the inbound substreams - assert_eq!( - handle2.next().await.unwrap(), - NotificationEvent::ValidateSubstream { - protocol: ProtocolName::from("/notif/1"), - fallback: None, - peer: peer1, - handshake: vec![1, 2, 3, 4], - } - ); - handle2.send_validation_result(peer1, ValidationResult::Accept); - - // accept the inbound substreams - assert_eq!( - handle1.next().await.unwrap(), - NotificationEvent::ValidateSubstream { - protocol: ProtocolName::from("/notif/1"), - fallback: None, - peer: peer2, - handshake: vec![1, 2, 3, 4], - } - ); - handle1.send_validation_result(peer2, ValidationResult::Accept); - - assert_eq!( - handle2.next().await.unwrap(), - NotificationEvent::NotificationStreamOpened { - protocol: ProtocolName::from("/notif/1"), - fallback: None, - direction: Direction::Inbound, - peer: peer1, - handshake: vec![1, 2, 3, 4], - } - ); - assert_eq!( - handle1.next().await.unwrap(), - NotificationEvent::NotificationStreamOpened { - protocol: ProtocolName::from("/notif/1"), - fallback: None, - direction: Direction::Outbound, - peer: peer2, - handshake: vec![1, 2, 3, 4], - } - ); - - // close the substream - handle2.close_substream(peer1).await; - - match handle2.next().await.unwrap() { - NotificationEvent::NotificationStreamClosed { peer } => assert_eq!(peer, peer1), - _ => panic!("invalid event received"), - } - - match handle1.next().await.unwrap() { - NotificationEvent::NotificationStreamClosed { peer } => assert_eq!(peer, peer2), - _ => panic!("invalid event received"), - } - - // open the substream - handle2.open_substream(peer1).await.unwrap(); - - assert_eq!( - handle1.next().await.unwrap(), - NotificationEvent::ValidateSubstream { - protocol: ProtocolName::from("/notif/1"), - fallback: None, - peer: peer2, - handshake: vec![1, 2, 3, 4], - } - ); - handle1.send_validation_result(peer2, ValidationResult::Accept); - - assert_eq!( - handle2.next().await.unwrap(), - NotificationEvent::ValidateSubstream { - protocol: ProtocolName::from("/notif/1"), - fallback: None, - peer: peer1, - handshake: vec![1, 2, 3, 4], - } - ); - handle2.send_validation_result(peer1, ValidationResult::Accept); - - // verify that both peers get the open event - assert_eq!( - handle2.next().await.unwrap(), - NotificationEvent::NotificationStreamOpened { - protocol: ProtocolName::from("/notif/1"), - fallback: None, - direction: Direction::Outbound, - peer: peer1, - handshake: vec![1, 2, 3, 4], - } - ); - assert_eq!( - handle1.next().await.unwrap(), - NotificationEvent::NotificationStreamOpened { - protocol: ProtocolName::from("/notif/1"), - fallback: None, - direction: Direction::Inbound, - peer: peer2, - handshake: vec![1, 2, 3, 4], - } - ); - - // send notifications to verify that the connection works again - handle1.send_sync_notification(peer2, vec![1, 3, 3, 7]).unwrap(); - handle2.send_sync_notification(peer1, vec![1, 3, 3, 8]).unwrap(); - - assert_eq!( - handle2.next().await.unwrap(), - NotificationEvent::NotificationReceived { - peer: peer1, - notification: BytesMut::from(&[1, 3, 3, 7][..]), - } - ); - assert_eq!( - handle1.next().await.unwrap(), - NotificationEvent::NotificationReceived { - peer: peer2, - notification: BytesMut::from(&[1, 3, 3, 8][..]), - } - ); + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (notif_config1, mut handle1) = NotificationConfig::new( + ProtocolName::from("/notif/1"), + 1024usize, + vec![1, 2, 3, 4], + Vec::new(), + false, + 64, + 64, + true, + ); + let config1 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_notification_protocol(notif_config1); + + let config1 = match transport1 { + Transport::Tcp(config) => config1.with_tcp(config), + Transport::Quic(config) => config1.with_quic(config), + Transport::WebSocket(config) => config1.with_websocket(config), + } + .build(); + + let (notif_config2, mut handle2) = NotificationConfig::new( + ProtocolName::from("/notif/1"), + 1024usize, + vec![1, 2, 3, 4], + Vec::new(), + false, + 64, + 64, + true, + ); + let config2 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_notification_protocol(notif_config2); + + let config2 = match transport2 { + Transport::Tcp(config) => config2.with_tcp(config), + Transport::Quic(config) => config2.with_quic(config), + Transport::WebSocket(config) => config2.with_websocket(config), + } + .build(); + + let mut litep2p1 = Litep2p::new(config1).unwrap(); + let mut litep2p2 = Litep2p::new(config2).unwrap(); + + let peer1 = *litep2p1.local_peer_id(); + let peer2 = *litep2p2.local_peer_id(); + + // wait until peers have connected and spawn the litep2p objects in the background + connect_peers(&mut litep2p1, &mut litep2p2).await; + tokio::spawn(async move { + loop { + tokio::select! { + _ = litep2p1.next_event() => {}, + _ = litep2p2.next_event() => {}, + } + } + }); + + // open substream for `peer2` and accept it + handle1.open_substream(peer2).await.unwrap(); + + // accept the inbound substreams + assert_eq!( + handle2.next().await.unwrap(), + NotificationEvent::ValidateSubstream { + protocol: ProtocolName::from("/notif/1"), + fallback: None, + peer: peer1, + handshake: vec![1, 2, 3, 4], + } + ); + handle2.send_validation_result(peer1, ValidationResult::Accept); + + // accept the inbound substreams + assert_eq!( + handle1.next().await.unwrap(), + NotificationEvent::ValidateSubstream { + protocol: ProtocolName::from("/notif/1"), + fallback: None, + peer: peer2, + handshake: vec![1, 2, 3, 4], + } + ); + handle1.send_validation_result(peer2, ValidationResult::Accept); + + assert_eq!( + handle2.next().await.unwrap(), + NotificationEvent::NotificationStreamOpened { + protocol: ProtocolName::from("/notif/1"), + fallback: None, + direction: Direction::Inbound, + peer: peer1, + handshake: vec![1, 2, 3, 4], + } + ); + assert_eq!( + handle1.next().await.unwrap(), + NotificationEvent::NotificationStreamOpened { + protocol: ProtocolName::from("/notif/1"), + fallback: None, + direction: Direction::Outbound, + peer: peer2, + handshake: vec![1, 2, 3, 4], + } + ); + + // close the substream + handle2.close_substream(peer1).await; + + match handle2.next().await.unwrap() { + NotificationEvent::NotificationStreamClosed { peer } => assert_eq!(peer, peer1), + _ => panic!("invalid event received"), + } + + match handle1.next().await.unwrap() { + NotificationEvent::NotificationStreamClosed { peer } => assert_eq!(peer, peer2), + _ => panic!("invalid event received"), + } + + // open the substream + handle2.open_substream(peer1).await.unwrap(); + + assert_eq!( + handle1.next().await.unwrap(), + NotificationEvent::ValidateSubstream { + protocol: ProtocolName::from("/notif/1"), + fallback: None, + peer: peer2, + handshake: vec![1, 2, 3, 4], + } + ); + handle1.send_validation_result(peer2, ValidationResult::Accept); + + assert_eq!( + handle2.next().await.unwrap(), + NotificationEvent::ValidateSubstream { + protocol: ProtocolName::from("/notif/1"), + fallback: None, + peer: peer1, + handshake: vec![1, 2, 3, 4], + } + ); + handle2.send_validation_result(peer1, ValidationResult::Accept); + + // verify that both peers get the open event + assert_eq!( + handle2.next().await.unwrap(), + NotificationEvent::NotificationStreamOpened { + protocol: ProtocolName::from("/notif/1"), + fallback: None, + direction: Direction::Outbound, + peer: peer1, + handshake: vec![1, 2, 3, 4], + } + ); + assert_eq!( + handle1.next().await.unwrap(), + NotificationEvent::NotificationStreamOpened { + protocol: ProtocolName::from("/notif/1"), + fallback: None, + direction: Direction::Inbound, + peer: peer2, + handshake: vec![1, 2, 3, 4], + } + ); + + // send notifications to verify that the connection works again + handle1.send_sync_notification(peer2, vec![1, 3, 3, 7]).unwrap(); + handle2.send_sync_notification(peer1, vec![1, 3, 3, 8]).unwrap(); + + assert_eq!( + handle2.next().await.unwrap(), + NotificationEvent::NotificationReceived { + peer: peer1, + notification: BytesMut::from(&[1, 3, 3, 7][..]), + } + ); + assert_eq!( + handle1.next().await.unwrap(), + NotificationEvent::NotificationReceived { + peer: peer2, + notification: BytesMut::from(&[1, 3, 3, 8][..]), + } + ); } #[tokio::test] async fn set_new_handshake_tcp() { - set_new_handshake( - Transport::Tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }), - Transport::Tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }), - ) - .await + set_new_handshake( + Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }), + Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }), + ) + .await } #[tokio::test] async fn set_new_handshake_quic() { - set_new_handshake(Transport::Quic(Default::default()), Transport::Quic(Default::default())) - .await; + set_new_handshake( + Transport::Quic(Default::default()), + Transport::Quic(Default::default()), + ) + .await; } #[tokio::test] async fn set_new_handshake_websocket() { - set_new_handshake( - Transport::WebSocket(WebSocketConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], - ..Default::default() - }), - Transport::WebSocket(WebSocketConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], - ..Default::default() - }), - ) - .await; + set_new_handshake( + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }), + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }), + ) + .await; } async fn set_new_handshake(transport1: Transport, transport2: Transport) { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (notif_config1, mut handle1) = NotificationConfig::new( - ProtocolName::from("/notif/1"), - 1024usize, - vec![1, 2, 3, 4], - Vec::new(), - false, - 64, - 64, - true, - ); - let config1 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_notification_protocol(notif_config1); - - let config1 = match transport1 { - Transport::Tcp(config) => config1.with_tcp(config), - Transport::Quic(config) => config1.with_quic(config), - Transport::WebSocket(config) => config1.with_websocket(config), - } - .build(); - - let (notif_config2, mut handle2) = NotificationConfig::new( - ProtocolName::from("/notif/1"), - 1024usize, - vec![1, 2, 3, 4], - Vec::new(), - false, - 64, - 64, - true, - ); - let config2 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_notification_protocol(notif_config2); - - let config2 = match transport2 { - Transport::Tcp(config) => config2.with_tcp(config), - Transport::Quic(config) => config2.with_quic(config), - Transport::WebSocket(config) => config2.with_websocket(config), - } - .build(); - - let mut litep2p1 = Litep2p::new(config1).unwrap(); - let mut litep2p2 = Litep2p::new(config2).unwrap(); - - let peer1 = *litep2p1.local_peer_id(); - let peer2 = *litep2p2.local_peer_id(); - - // wait until peers have connected and spawn the litep2p objects in the background - connect_peers(&mut litep2p1, &mut litep2p2).await; - tokio::spawn(async move { - loop { - tokio::select! { - _ = litep2p1.next_event() => {}, - _ = litep2p2.next_event() => {}, - } - } - }); - - // open substream for `peer2` and accept it - handle1.open_substream(peer2).await.unwrap(); - - // accept the substreams - assert_eq!( - handle2.next().await.unwrap(), - NotificationEvent::ValidateSubstream { - protocol: ProtocolName::from("/notif/1"), - fallback: None, - peer: peer1, - handshake: vec![1, 2, 3, 4], - } - ); - handle2.send_validation_result(peer1, ValidationResult::Accept); - - // accept the substreams - assert_eq!( - handle1.next().await.unwrap(), - NotificationEvent::ValidateSubstream { - protocol: ProtocolName::from("/notif/1"), - fallback: None, - peer: peer2, - handshake: vec![1, 2, 3, 4], - } - ); - handle1.send_validation_result(peer2, ValidationResult::Accept); - - assert_eq!( - handle2.next().await.unwrap(), - NotificationEvent::NotificationStreamOpened { - protocol: ProtocolName::from("/notif/1"), - fallback: None, - direction: Direction::Inbound, - peer: peer1, - handshake: vec![1, 2, 3, 4], - } - ); - assert_eq!( - handle1.next().await.unwrap(), - NotificationEvent::NotificationStreamOpened { - protocol: ProtocolName::from("/notif/1"), - fallback: None, - direction: Direction::Outbound, - peer: peer2, - handshake: vec![1, 2, 3, 4], - } - ); - - // close the substream - handle2.close_substream(peer1).await; - - match handle2.next().await.unwrap() { - NotificationEvent::NotificationStreamClosed { peer } => assert_eq!(peer, peer1), - _ => panic!("invalid event received"), - } - - match handle1.next().await.unwrap() { - NotificationEvent::NotificationStreamClosed { peer } => assert_eq!(peer, peer2), - _ => panic!("invalid event received"), - } - - // set new handshakes and open the substream - handle1.set_handshake(vec![5, 5, 5, 5]); - handle2.set_handshake(vec![6, 6, 6, 6]); - handle2.open_substream(peer1).await.unwrap(); - - // accept the substreams - assert_eq!( - handle1.next().await.unwrap(), - NotificationEvent::ValidateSubstream { - protocol: ProtocolName::from("/notif/1"), - fallback: None, - peer: peer2, - handshake: vec![6, 6, 6, 6], - } - ); - handle1.send_validation_result(peer2, ValidationResult::Accept); - - // accept the substreams - assert_eq!( - handle2.next().await.unwrap(), - NotificationEvent::ValidateSubstream { - protocol: ProtocolName::from("/notif/1"), - fallback: None, - peer: peer1, - handshake: vec![5, 5, 5, 5], - } - ); - handle2.send_validation_result(peer1, ValidationResult::Accept); - - // verify that both peers get the open event - assert_eq!( - handle2.next().await.unwrap(), - NotificationEvent::NotificationStreamOpened { - protocol: ProtocolName::from("/notif/1"), - fallback: None, - direction: Direction::Outbound, - peer: peer1, - handshake: vec![5, 5, 5, 5], - } - ); - assert_eq!( - handle1.next().await.unwrap(), - NotificationEvent::NotificationStreamOpened { - protocol: ProtocolName::from("/notif/1"), - fallback: None, - direction: Direction::Inbound, - peer: peer2, - handshake: vec![6, 6, 6, 6], - } - ); + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (notif_config1, mut handle1) = NotificationConfig::new( + ProtocolName::from("/notif/1"), + 1024usize, + vec![1, 2, 3, 4], + Vec::new(), + false, + 64, + 64, + true, + ); + let config1 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_notification_protocol(notif_config1); + + let config1 = match transport1 { + Transport::Tcp(config) => config1.with_tcp(config), + Transport::Quic(config) => config1.with_quic(config), + Transport::WebSocket(config) => config1.with_websocket(config), + } + .build(); + + let (notif_config2, mut handle2) = NotificationConfig::new( + ProtocolName::from("/notif/1"), + 1024usize, + vec![1, 2, 3, 4], + Vec::new(), + false, + 64, + 64, + true, + ); + let config2 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_notification_protocol(notif_config2); + + let config2 = match transport2 { + Transport::Tcp(config) => config2.with_tcp(config), + Transport::Quic(config) => config2.with_quic(config), + Transport::WebSocket(config) => config2.with_websocket(config), + } + .build(); + + let mut litep2p1 = Litep2p::new(config1).unwrap(); + let mut litep2p2 = Litep2p::new(config2).unwrap(); + + let peer1 = *litep2p1.local_peer_id(); + let peer2 = *litep2p2.local_peer_id(); + + // wait until peers have connected and spawn the litep2p objects in the background + connect_peers(&mut litep2p1, &mut litep2p2).await; + tokio::spawn(async move { + loop { + tokio::select! { + _ = litep2p1.next_event() => {}, + _ = litep2p2.next_event() => {}, + } + } + }); + + // open substream for `peer2` and accept it + handle1.open_substream(peer2).await.unwrap(); + + // accept the substreams + assert_eq!( + handle2.next().await.unwrap(), + NotificationEvent::ValidateSubstream { + protocol: ProtocolName::from("/notif/1"), + fallback: None, + peer: peer1, + handshake: vec![1, 2, 3, 4], + } + ); + handle2.send_validation_result(peer1, ValidationResult::Accept); + + // accept the substreams + assert_eq!( + handle1.next().await.unwrap(), + NotificationEvent::ValidateSubstream { + protocol: ProtocolName::from("/notif/1"), + fallback: None, + peer: peer2, + handshake: vec![1, 2, 3, 4], + } + ); + handle1.send_validation_result(peer2, ValidationResult::Accept); + + assert_eq!( + handle2.next().await.unwrap(), + NotificationEvent::NotificationStreamOpened { + protocol: ProtocolName::from("/notif/1"), + fallback: None, + direction: Direction::Inbound, + peer: peer1, + handshake: vec![1, 2, 3, 4], + } + ); + assert_eq!( + handle1.next().await.unwrap(), + NotificationEvent::NotificationStreamOpened { + protocol: ProtocolName::from("/notif/1"), + fallback: None, + direction: Direction::Outbound, + peer: peer2, + handshake: vec![1, 2, 3, 4], + } + ); + + // close the substream + handle2.close_substream(peer1).await; + + match handle2.next().await.unwrap() { + NotificationEvent::NotificationStreamClosed { peer } => assert_eq!(peer, peer1), + _ => panic!("invalid event received"), + } + + match handle1.next().await.unwrap() { + NotificationEvent::NotificationStreamClosed { peer } => assert_eq!(peer, peer2), + _ => panic!("invalid event received"), + } + + // set new handshakes and open the substream + handle1.set_handshake(vec![5, 5, 5, 5]); + handle2.set_handshake(vec![6, 6, 6, 6]); + handle2.open_substream(peer1).await.unwrap(); + + // accept the substreams + assert_eq!( + handle1.next().await.unwrap(), + NotificationEvent::ValidateSubstream { + protocol: ProtocolName::from("/notif/1"), + fallback: None, + peer: peer2, + handshake: vec![6, 6, 6, 6], + } + ); + handle1.send_validation_result(peer2, ValidationResult::Accept); + + // accept the substreams + assert_eq!( + handle2.next().await.unwrap(), + NotificationEvent::ValidateSubstream { + protocol: ProtocolName::from("/notif/1"), + fallback: None, + peer: peer1, + handshake: vec![5, 5, 5, 5], + } + ); + handle2.send_validation_result(peer1, ValidationResult::Accept); + + // verify that both peers get the open event + assert_eq!( + handle2.next().await.unwrap(), + NotificationEvent::NotificationStreamOpened { + protocol: ProtocolName::from("/notif/1"), + fallback: None, + direction: Direction::Outbound, + peer: peer1, + handshake: vec![5, 5, 5, 5], + } + ); + assert_eq!( + handle1.next().await.unwrap(), + NotificationEvent::NotificationStreamOpened { + protocol: ProtocolName::from("/notif/1"), + fallback: None, + direction: Direction::Inbound, + peer: peer2, + handshake: vec![6, 6, 6, 6], + } + ); } #[tokio::test] async fn both_nodes_open_substreams_tcp() { - both_nodes_open_substreams( - Transport::Tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }), - Transport::Tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }), - ) - .await + both_nodes_open_substreams( + Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }), + Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }), + ) + .await } #[tokio::test] async fn both_nodes_open_substreams_quic() { - both_nodes_open_substreams( - Transport::Quic(Default::default()), - Transport::Quic(Default::default()), - ) - .await; + both_nodes_open_substreams( + Transport::Quic(Default::default()), + Transport::Quic(Default::default()), + ) + .await; } #[tokio::test] async fn both_nodes_open_substreams_websocket() { - both_nodes_open_substreams( - Transport::WebSocket(WebSocketConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], - ..Default::default() - }), - Transport::WebSocket(WebSocketConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], - ..Default::default() - }), - ) - .await; + both_nodes_open_substreams( + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }), + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }), + ) + .await; } async fn both_nodes_open_substreams(transport1: Transport, transport2: Transport) { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (notif_config1, mut handle1) = NotificationConfig::new( - ProtocolName::from("/notif/1"), - 1024usize, - vec![1, 2, 3, 4], - Vec::new(), - false, - 64, - 64, - true, - ); - let config1 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_notification_protocol(notif_config1); - - let config1 = match transport1 { - Transport::Tcp(config) => config1.with_tcp(config), - Transport::Quic(config) => config1.with_quic(config), - Transport::WebSocket(config) => config1.with_websocket(config), - } - .build(); - - let (notif_config2, mut handle2) = NotificationConfig::new( - ProtocolName::from("/notif/1"), - 1024usize, - vec![1, 2, 3, 4], - Vec::new(), - false, - 64, - 64, - true, - ); - let config2 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_notification_protocol(notif_config2); - - let config2 = match transport2 { - Transport::Tcp(config) => config2.with_tcp(config), - Transport::Quic(config) => config2.with_quic(config), - Transport::WebSocket(config) => config2.with_websocket(config), - } - .build(); - - let mut litep2p1 = Litep2p::new(config1).unwrap(); - let mut litep2p2 = Litep2p::new(config2).unwrap(); - - let peer1 = *litep2p1.local_peer_id(); - let peer2 = *litep2p2.local_peer_id(); - - // wait until peers have connected and spawn the litep2p objects in the background - connect_peers(&mut litep2p1, &mut litep2p2).await; - tokio::spawn(async move { - loop { - tokio::select! { - _ = litep2p1.next_event() => {}, - _ = litep2p2.next_event() => {}, - } - } - }); - - // both nodes open a substream at the same time - handle1.open_substream(peer2).await.unwrap(); - handle2.open_substream(peer1).await.unwrap(); - - // accept the substreams - assert_eq!( - handle1.next().await.unwrap(), - NotificationEvent::ValidateSubstream { - protocol: ProtocolName::from("/notif/1"), - fallback: None, - peer: peer2, - handshake: vec![1, 2, 3, 4], - } - ); - handle1.send_validation_result(peer2, ValidationResult::Accept); - - // accept the substreams - assert_eq!( - handle2.next().await.unwrap(), - NotificationEvent::ValidateSubstream { - protocol: ProtocolName::from("/notif/1"), - fallback: None, - peer: peer1, - handshake: vec![1, 2, 3, 4], - } - ); - handle2.send_validation_result(peer1, ValidationResult::Accept); - - assert_eq!( - handle2.next().await.unwrap(), - NotificationEvent::NotificationStreamOpened { - protocol: ProtocolName::from("/notif/1"), - fallback: None, - direction: Direction::Outbound, - peer: peer1, - handshake: vec![1, 2, 3, 4], - } - ); - assert_eq!( - handle1.next().await.unwrap(), - NotificationEvent::NotificationStreamOpened { - protocol: ProtocolName::from("/notif/1"), - fallback: None, - direction: Direction::Outbound, - peer: peer2, - handshake: vec![1, 2, 3, 4], - } - ); - - handle1.send_sync_notification(peer2, vec![1, 3, 3, 7]).unwrap(); - handle2.send_sync_notification(peer1, vec![1, 3, 3, 8]).unwrap(); - - assert_eq!( - handle2.next().await.unwrap(), - NotificationEvent::NotificationReceived { - peer: peer1, - notification: BytesMut::from(&[1, 3, 3, 7][..]), - } - ); - assert_eq!( - handle1.next().await.unwrap(), - NotificationEvent::NotificationReceived { - peer: peer2, - notification: BytesMut::from(&[1, 3, 3, 8][..]), - } - ); + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (notif_config1, mut handle1) = NotificationConfig::new( + ProtocolName::from("/notif/1"), + 1024usize, + vec![1, 2, 3, 4], + Vec::new(), + false, + 64, + 64, + true, + ); + let config1 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_notification_protocol(notif_config1); + + let config1 = match transport1 { + Transport::Tcp(config) => config1.with_tcp(config), + Transport::Quic(config) => config1.with_quic(config), + Transport::WebSocket(config) => config1.with_websocket(config), + } + .build(); + + let (notif_config2, mut handle2) = NotificationConfig::new( + ProtocolName::from("/notif/1"), + 1024usize, + vec![1, 2, 3, 4], + Vec::new(), + false, + 64, + 64, + true, + ); + let config2 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_notification_protocol(notif_config2); + + let config2 = match transport2 { + Transport::Tcp(config) => config2.with_tcp(config), + Transport::Quic(config) => config2.with_quic(config), + Transport::WebSocket(config) => config2.with_websocket(config), + } + .build(); + + let mut litep2p1 = Litep2p::new(config1).unwrap(); + let mut litep2p2 = Litep2p::new(config2).unwrap(); + + let peer1 = *litep2p1.local_peer_id(); + let peer2 = *litep2p2.local_peer_id(); + + // wait until peers have connected and spawn the litep2p objects in the background + connect_peers(&mut litep2p1, &mut litep2p2).await; + tokio::spawn(async move { + loop { + tokio::select! { + _ = litep2p1.next_event() => {}, + _ = litep2p2.next_event() => {}, + } + } + }); + + // both nodes open a substream at the same time + handle1.open_substream(peer2).await.unwrap(); + handle2.open_substream(peer1).await.unwrap(); + + // accept the substreams + assert_eq!( + handle1.next().await.unwrap(), + NotificationEvent::ValidateSubstream { + protocol: ProtocolName::from("/notif/1"), + fallback: None, + peer: peer2, + handshake: vec![1, 2, 3, 4], + } + ); + handle1.send_validation_result(peer2, ValidationResult::Accept); + + // accept the substreams + assert_eq!( + handle2.next().await.unwrap(), + NotificationEvent::ValidateSubstream { + protocol: ProtocolName::from("/notif/1"), + fallback: None, + peer: peer1, + handshake: vec![1, 2, 3, 4], + } + ); + handle2.send_validation_result(peer1, ValidationResult::Accept); + + assert_eq!( + handle2.next().await.unwrap(), + NotificationEvent::NotificationStreamOpened { + protocol: ProtocolName::from("/notif/1"), + fallback: None, + direction: Direction::Outbound, + peer: peer1, + handshake: vec![1, 2, 3, 4], + } + ); + assert_eq!( + handle1.next().await.unwrap(), + NotificationEvent::NotificationStreamOpened { + protocol: ProtocolName::from("/notif/1"), + fallback: None, + direction: Direction::Outbound, + peer: peer2, + handshake: vec![1, 2, 3, 4], + } + ); + + handle1.send_sync_notification(peer2, vec![1, 3, 3, 7]).unwrap(); + handle2.send_sync_notification(peer1, vec![1, 3, 3, 8]).unwrap(); + + assert_eq!( + handle2.next().await.unwrap(), + NotificationEvent::NotificationReceived { + peer: peer1, + notification: BytesMut::from(&[1, 3, 3, 7][..]), + } + ); + assert_eq!( + handle1.next().await.unwrap(), + NotificationEvent::NotificationReceived { + peer: peer2, + notification: BytesMut::from(&[1, 3, 3, 8][..]), + } + ); } #[tokio::test] #[cfg(debug_assertions)] async fn both_nodes_open_substream_one_rejects_substreams_tcp() { - both_nodes_open_substream_one_rejects_substreams( - Transport::Tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }), - Transport::Tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }), - ) - .await + both_nodes_open_substream_one_rejects_substreams( + Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }), + Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }), + ) + .await } #[tokio::test] #[cfg(debug_assertions)] async fn both_nodes_open_substream_one_rejects_substreams_quic() { - both_nodes_open_substream_one_rejects_substreams( - Transport::Quic(Default::default()), - Transport::Quic(Default::default()), - ) - .await; + both_nodes_open_substream_one_rejects_substreams( + Transport::Quic(Default::default()), + Transport::Quic(Default::default()), + ) + .await; } #[tokio::test] #[cfg(debug_assertions)] async fn both_nodes_open_substream_one_rejects_substreams_websocket() { - both_nodes_open_substream_one_rejects_substreams( - Transport::WebSocket(WebSocketConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], - ..Default::default() - }), - Transport::WebSocket(WebSocketConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], - ..Default::default() - }), - ) - .await; + both_nodes_open_substream_one_rejects_substreams( + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }), + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }), + ) + .await; } async fn both_nodes_open_substream_one_rejects_substreams( - transport1: Transport, - transport2: Transport, + transport1: Transport, + transport2: Transport, ) { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (notif_config1, mut handle1) = NotificationConfig::new( - ProtocolName::from("/notif/1"), - 1024usize, - vec![1, 2, 3, 4], - Vec::new(), - false, - 64, - 64, - true, - ); - let config1 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_notification_protocol(notif_config1); - - let config1 = match transport1 { - Transport::Tcp(config) => config1.with_tcp(config), - Transport::Quic(config) => config1.with_quic(config), - Transport::WebSocket(config) => config1.with_websocket(config), - } - .build(); - - let (notif_config2, mut handle2) = NotificationConfig::new( - ProtocolName::from("/notif/1"), - 1024usize, - vec![1, 2, 3, 4], - Vec::new(), - false, - 64, - 64, - true, - ); - let config2 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_notification_protocol(notif_config2); - - let config2 = match transport2 { - Transport::Tcp(config) => config2.with_tcp(config), - Transport::Quic(config) => config2.with_quic(config), - Transport::WebSocket(config) => config2.with_websocket(config), - } - .build(); - - let mut litep2p1 = Litep2p::new(config1).unwrap(); - let mut litep2p2 = Litep2p::new(config2).unwrap(); - - let peer1 = *litep2p1.local_peer_id(); - let peer2 = *litep2p2.local_peer_id(); - - // wait until peers have connected and spawn the litep2p objects in the background - connect_peers(&mut litep2p1, &mut litep2p2).await; - tokio::spawn(async move { - loop { - tokio::select! { - _ = litep2p1.next_event() => {}, - _ = litep2p2.next_event() => {}, - } - } - }); - - // both nodes open a substream at the same time - handle1.open_substream(peer2).await.unwrap(); - handle2.open_substream(peer1).await.unwrap(); - - // first peer accepts the substream - assert_eq!( - handle1.next().await.unwrap(), - NotificationEvent::ValidateSubstream { - protocol: ProtocolName::from("/notif/1"), - fallback: None, - peer: peer2, - handshake: vec![1, 2, 3, 4], - } - ); - handle1.send_validation_result(peer2, ValidationResult::Accept); - - // the second peer rejects the substream - assert_eq!( - handle2.next().await.unwrap(), - NotificationEvent::ValidateSubstream { - protocol: ProtocolName::from("/notif/1"), - fallback: None, - peer: peer1, - handshake: vec![1, 2, 3, 4], - } - ); - handle2.send_validation_result(peer1, ValidationResult::Reject); - - assert_eq!( - handle1.next().await.unwrap(), - NotificationEvent::NotificationStreamOpenFailure { - peer: peer2, - error: NotificationError::Rejected - }, - ); - - assert!(tokio::time::timeout(Duration::from_secs(5), handle2.next()).await.is_err()); + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (notif_config1, mut handle1) = NotificationConfig::new( + ProtocolName::from("/notif/1"), + 1024usize, + vec![1, 2, 3, 4], + Vec::new(), + false, + 64, + 64, + true, + ); + let config1 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_notification_protocol(notif_config1); + + let config1 = match transport1 { + Transport::Tcp(config) => config1.with_tcp(config), + Transport::Quic(config) => config1.with_quic(config), + Transport::WebSocket(config) => config1.with_websocket(config), + } + .build(); + + let (notif_config2, mut handle2) = NotificationConfig::new( + ProtocolName::from("/notif/1"), + 1024usize, + vec![1, 2, 3, 4], + Vec::new(), + false, + 64, + 64, + true, + ); + let config2 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_notification_protocol(notif_config2); + + let config2 = match transport2 { + Transport::Tcp(config) => config2.with_tcp(config), + Transport::Quic(config) => config2.with_quic(config), + Transport::WebSocket(config) => config2.with_websocket(config), + } + .build(); + + let mut litep2p1 = Litep2p::new(config1).unwrap(); + let mut litep2p2 = Litep2p::new(config2).unwrap(); + + let peer1 = *litep2p1.local_peer_id(); + let peer2 = *litep2p2.local_peer_id(); + + // wait until peers have connected and spawn the litep2p objects in the background + connect_peers(&mut litep2p1, &mut litep2p2).await; + tokio::spawn(async move { + loop { + tokio::select! { + _ = litep2p1.next_event() => {}, + _ = litep2p2.next_event() => {}, + } + } + }); + + // both nodes open a substream at the same time + handle1.open_substream(peer2).await.unwrap(); + handle2.open_substream(peer1).await.unwrap(); + + // first peer accepts the substream + assert_eq!( + handle1.next().await.unwrap(), + NotificationEvent::ValidateSubstream { + protocol: ProtocolName::from("/notif/1"), + fallback: None, + peer: peer2, + handshake: vec![1, 2, 3, 4], + } + ); + handle1.send_validation_result(peer2, ValidationResult::Accept); + + // the second peer rejects the substream + assert_eq!( + handle2.next().await.unwrap(), + NotificationEvent::ValidateSubstream { + protocol: ProtocolName::from("/notif/1"), + fallback: None, + peer: peer1, + handshake: vec![1, 2, 3, 4], + } + ); + handle2.send_validation_result(peer1, ValidationResult::Reject); + + assert_eq!( + handle1.next().await.unwrap(), + NotificationEvent::NotificationStreamOpenFailure { + peer: peer2, + error: NotificationError::Rejected + }, + ); + + assert!(tokio::time::timeout(Duration::from_secs(5), handle2.next()).await.is_err()); } #[tokio::test] async fn send_sync_notification_to_non_existent_peer_tcp() { - send_sync_notification_to_non_existent_peer(Transport::Tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - })) - .await + send_sync_notification_to_non_existent_peer(Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + })) + .await } #[tokio::test] async fn send_sync_notification_to_non_existent_peer_quic() { - send_sync_notification_to_non_existent_peer(Transport::Quic(Default::default())).await; + send_sync_notification_to_non_existent_peer(Transport::Quic(Default::default())).await; } #[tokio::test] async fn send_sync_notification_to_non_existent_peer_websocket() { - send_sync_notification_to_non_existent_peer(Transport::WebSocket(WebSocketConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], - ..Default::default() - })) - .await; + send_sync_notification_to_non_existent_peer(Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + })) + .await; } async fn send_sync_notification_to_non_existent_peer(transport1: Transport) { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (notif_config1, mut handle1) = NotificationConfig::new( - ProtocolName::from("/notif/1"), - 1024usize, - vec![1, 2, 3, 4], - Vec::new(), - false, - 64, - 64, - true, - ); - let config1 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_notification_protocol(notif_config1); - - let config1 = match transport1 { - Transport::Tcp(config) => config1.with_tcp(config), - Transport::Quic(config) => config1.with_quic(config), - Transport::WebSocket(config) => config1.with_websocket(config), - } - .build(); - - let mut litep2p1 = Litep2p::new(config1).unwrap(); - - tokio::spawn(async move { - loop { - tokio::select! { - _ = litep2p1.next_event() => {}, - } - } - }); - - handle1.send_sync_notification(PeerId::random(), vec![1, 3, 3, 7]).unwrap(); + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (notif_config1, mut handle1) = NotificationConfig::new( + ProtocolName::from("/notif/1"), + 1024usize, + vec![1, 2, 3, 4], + Vec::new(), + false, + 64, + 64, + true, + ); + let config1 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_notification_protocol(notif_config1); + + let config1 = match transport1 { + Transport::Tcp(config) => config1.with_tcp(config), + Transport::Quic(config) => config1.with_quic(config), + Transport::WebSocket(config) => config1.with_websocket(config), + } + .build(); + + let mut litep2p1 = Litep2p::new(config1).unwrap(); + + tokio::spawn(async move { + loop { + tokio::select! { + _ = litep2p1.next_event() => {}, + } + } + }); + + handle1.send_sync_notification(PeerId::random(), vec![1, 3, 3, 7]).unwrap(); } #[tokio::test] async fn send_async_notification_to_non_existent_peer_tcp() { - send_async_notification_to_non_existent_peer(Transport::Tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - })) - .await + send_async_notification_to_non_existent_peer(Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + })) + .await } #[tokio::test] async fn send_async_notification_to_non_existent_peer_quic() { - send_async_notification_to_non_existent_peer(Transport::Quic(Default::default())).await; + send_async_notification_to_non_existent_peer(Transport::Quic(Default::default())).await; } #[tokio::test] async fn send_async_notification_to_non_existent_peer_websocket() { - send_async_notification_to_non_existent_peer(Transport::WebSocket(WebSocketConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], - ..Default::default() - })) - .await; + send_async_notification_to_non_existent_peer(Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + })) + .await; } async fn send_async_notification_to_non_existent_peer(transport1: Transport) { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (notif_config1, mut handle1) = NotificationConfig::new( - ProtocolName::from("/notif/1"), - 1024usize, - vec![1, 2, 3, 4], - Vec::new(), - false, - 64, - 64, - true, - ); - let config1 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_notification_protocol(notif_config1); - - let config1 = match transport1 { - Transport::Tcp(config) => config1.with_tcp(config), - Transport::Quic(config) => config1.with_quic(config), - Transport::WebSocket(config) => config1.with_websocket(config), - } - .build(); - - let mut litep2p1 = Litep2p::new(config1).unwrap(); - - tokio::spawn(async move { - loop { - tokio::select! { - _ = litep2p1.next_event() => {}, - } - } - }); - - assert!(handle1 - .send_async_notification(PeerId::random(), vec![1, 3, 3, 7]) - .await - .is_err()); + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (notif_config1, mut handle1) = NotificationConfig::new( + ProtocolName::from("/notif/1"), + 1024usize, + vec![1, 2, 3, 4], + Vec::new(), + false, + 64, + 64, + true, + ); + let config1 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_notification_protocol(notif_config1); + + let config1 = match transport1 { + Transport::Tcp(config) => config1.with_tcp(config), + Transport::Quic(config) => config1.with_quic(config), + Transport::WebSocket(config) => config1.with_websocket(config), + } + .build(); + + let mut litep2p1 = Litep2p::new(config1).unwrap(); + + tokio::spawn(async move { + loop { + tokio::select! { + _ = litep2p1.next_event() => {}, + } + } + }); + + assert!(handle1 + .send_async_notification(PeerId::random(), vec![1, 3, 3, 7]) + .await + .is_err()); } #[tokio::test] async fn try_to_connect_to_non_existent_peer_tcp() { - try_to_connect_to_non_existent_peer(Transport::Tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - })) - .await + try_to_connect_to_non_existent_peer(Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + })) + .await } #[tokio::test] async fn try_to_connect_to_non_existent_peer_quic() { - try_to_connect_to_non_existent_peer(Transport::Quic(Default::default())).await; + try_to_connect_to_non_existent_peer(Transport::Quic(Default::default())).await; } #[tokio::test] async fn try_to_connect_to_non_existent_peer_websocket() { - try_to_connect_to_non_existent_peer(Transport::WebSocket(WebSocketConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], - ..Default::default() - })) - .await; + try_to_connect_to_non_existent_peer(Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + })) + .await; } async fn try_to_connect_to_non_existent_peer(transport1: Transport) { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (notif_config1, mut handle1) = NotificationConfig::new( - ProtocolName::from("/notif/1"), - 1024usize, - vec![1, 2, 3, 4], - Vec::new(), - false, - 64, - 64, - true, - ); - let config1 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_notification_protocol(notif_config1); - - let config1 = match transport1 { - Transport::Tcp(config) => config1.with_tcp(config), - Transport::Quic(config) => config1.with_quic(config), - Transport::WebSocket(config) => config1.with_websocket(config), - } - .build(); - - let mut litep2p1 = Litep2p::new(config1).unwrap(); - - tokio::spawn(async move { - loop { - tokio::select! { - _ = litep2p1.next_event() => {}, - } - } - }); - - let peer = PeerId::random(); - handle1.open_substream(peer).await.unwrap(); - assert_eq!( - handle1.next().await.unwrap(), - NotificationEvent::NotificationStreamOpenFailure { - peer, - error: NotificationError::DialFailure, - } - ); + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (notif_config1, mut handle1) = NotificationConfig::new( + ProtocolName::from("/notif/1"), + 1024usize, + vec![1, 2, 3, 4], + Vec::new(), + false, + 64, + 64, + true, + ); + let config1 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_notification_protocol(notif_config1); + + let config1 = match transport1 { + Transport::Tcp(config) => config1.with_tcp(config), + Transport::Quic(config) => config1.with_quic(config), + Transport::WebSocket(config) => config1.with_websocket(config), + } + .build(); + + let mut litep2p1 = Litep2p::new(config1).unwrap(); + + tokio::spawn(async move { + loop { + tokio::select! { + _ = litep2p1.next_event() => {}, + } + } + }); + + let peer = PeerId::random(); + handle1.open_substream(peer).await.unwrap(); + assert_eq!( + handle1.next().await.unwrap(), + NotificationEvent::NotificationStreamOpenFailure { + peer, + error: NotificationError::DialFailure, + } + ); } #[tokio::test] async fn try_to_disconnect_non_existent_peer_tcp() { - try_to_disconnect_non_existent_peer(Transport::Tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - })) - .await + try_to_disconnect_non_existent_peer(Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + })) + .await } #[tokio::test] async fn try_to_disconnect_non_existent_peer_quic() { - try_to_disconnect_non_existent_peer(Transport::Quic(Default::default())).await; + try_to_disconnect_non_existent_peer(Transport::Quic(Default::default())).await; } #[tokio::test] async fn try_to_disconnect_non_existent_peer_websocket() { - try_to_disconnect_non_existent_peer(Transport::WebSocket(WebSocketConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], - ..Default::default() - })) - .await; + try_to_disconnect_non_existent_peer(Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + })) + .await; } async fn try_to_disconnect_non_existent_peer(transport1: Transport) { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (notif_config1, handle1) = NotificationConfig::new( - ProtocolName::from("/notif/1"), - 1024usize, - vec![1, 2, 3, 4], - Vec::new(), - false, - 64, - 64, - true, - ); - let config1 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_notification_protocol(notif_config1); - - let config1 = match transport1 { - Transport::Tcp(config) => config1.with_tcp(config), - Transport::Quic(config) => config1.with_quic(config), - Transport::WebSocket(config) => config1.with_websocket(config), - } - .build(); - - let mut litep2p1 = Litep2p::new(config1).unwrap(); - - tokio::spawn(async move { - loop { - tokio::select! { - _ = litep2p1.next_event() => {}, - } - } - }); - - handle1.close_substream(PeerId::random()).await; + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (notif_config1, handle1) = NotificationConfig::new( + ProtocolName::from("/notif/1"), + 1024usize, + vec![1, 2, 3, 4], + Vec::new(), + false, + 64, + 64, + true, + ); + let config1 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_notification_protocol(notif_config1); + + let config1 = match transport1 { + Transport::Tcp(config) => config1.with_tcp(config), + Transport::Quic(config) => config1.with_quic(config), + Transport::WebSocket(config) => config1.with_websocket(config), + } + .build(); + + let mut litep2p1 = Litep2p::new(config1).unwrap(); + + tokio::spawn(async move { + loop { + tokio::select! { + _ = litep2p1.next_event() => {}, + } + } + }); + + handle1.close_substream(PeerId::random()).await; } #[tokio::test] async fn try_to_reopen_substream_tcp() { - try_to_reopen_substream( - Transport::Tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }), - Transport::Tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }), - ) - .await; + try_to_reopen_substream( + Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }), + Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }), + ) + .await; } #[tokio::test] async fn try_to_reopen_substream_quic() { - try_to_reopen_substream( - Transport::Quic(Default::default()), - Transport::Quic(Default::default()), - ) - .await; + try_to_reopen_substream( + Transport::Quic(Default::default()), + Transport::Quic(Default::default()), + ) + .await; } #[tokio::test] async fn try_to_reopen_substream_websocket() { - try_to_reopen_substream( - Transport::WebSocket(WebSocketConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], - ..Default::default() - }), - Transport::WebSocket(WebSocketConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], - ..Default::default() - }), - ) - .await; + try_to_reopen_substream( + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }), + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }), + ) + .await; } async fn try_to_reopen_substream(transport1: Transport, transport2: Transport) { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (notif_config1, mut handle1) = NotificationConfig::new( - ProtocolName::from("/notif/1"), - 1024usize, - vec![1, 2, 3, 4], - Vec::new(), - false, - 64, - 64, - true, - ); - let config1 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_notification_protocol(notif_config1); - - let config1 = match transport1 { - Transport::Tcp(config) => config1.with_tcp(config), - Transport::Quic(config) => config1.with_quic(config), - Transport::WebSocket(config) => config1.with_websocket(config), - } - .build(); - - let (notif_config2, mut handle2) = NotificationConfig::new( - ProtocolName::from("/notif/1"), - 1024usize, - vec![1, 2, 3, 4], - Vec::new(), - false, - 64, - 64, - true, - ); - let config2 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_notification_protocol(notif_config2); - - let config2 = match transport2 { - Transport::Tcp(config) => config2.with_tcp(config), - Transport::Quic(config) => config2.with_quic(config), - Transport::WebSocket(config) => config2.with_websocket(config), - } - .build(); - - let mut litep2p1 = Litep2p::new(config1).unwrap(); - let mut litep2p2 = Litep2p::new(config2).unwrap(); - - let peer1 = *litep2p1.local_peer_id(); - let peer2 = *litep2p2.local_peer_id(); - - // wait until peers have connected and spawn the litep2p objects in the background - connect_peers(&mut litep2p1, &mut litep2p2).await; - tokio::spawn(async move { - loop { - tokio::select! { - _ = litep2p1.next_event() => {}, - _ = litep2p2.next_event() => {}, - } - } - }); - - // open substream for `peer2` and accept it - handle1.open_substream(peer2).await.unwrap(); - - assert_eq!( - handle2.next().await.unwrap(), - NotificationEvent::ValidateSubstream { - protocol: ProtocolName::from("/notif/1"), - fallback: None, - peer: peer1, - handshake: vec![1, 2, 3, 4], - } - ); - handle2.send_validation_result(peer1, ValidationResult::Accept); - - assert_eq!( - handle1.next().await.unwrap(), - NotificationEvent::ValidateSubstream { - protocol: ProtocolName::from("/notif/1"), - fallback: None, - peer: peer2, - handshake: vec![1, 2, 3, 4], - } - ); - handle1.send_validation_result(peer2, ValidationResult::Accept); - - assert_eq!( - handle2.next().await.unwrap(), - NotificationEvent::NotificationStreamOpened { - protocol: ProtocolName::from("/notif/1"), - fallback: None, - direction: Direction::Inbound, - peer: peer1, - handshake: vec![1, 2, 3, 4], - } - ); - assert_eq!( - handle1.next().await.unwrap(), - NotificationEvent::NotificationStreamOpened { - protocol: ProtocolName::from("/notif/1"), - fallback: None, - direction: Direction::Outbound, - peer: peer2, - handshake: vec![1, 2, 3, 4], - } - ); - - // open substream for `peer2` and accept it - match handle1.open_substream(peer2).await { - Err(Error::PeerAlreadyExists(peer)) => assert_eq!(peer, peer2), - result => panic!("invalid event received: {result:?}"), - } + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (notif_config1, mut handle1) = NotificationConfig::new( + ProtocolName::from("/notif/1"), + 1024usize, + vec![1, 2, 3, 4], + Vec::new(), + false, + 64, + 64, + true, + ); + let config1 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_notification_protocol(notif_config1); + + let config1 = match transport1 { + Transport::Tcp(config) => config1.with_tcp(config), + Transport::Quic(config) => config1.with_quic(config), + Transport::WebSocket(config) => config1.with_websocket(config), + } + .build(); + + let (notif_config2, mut handle2) = NotificationConfig::new( + ProtocolName::from("/notif/1"), + 1024usize, + vec![1, 2, 3, 4], + Vec::new(), + false, + 64, + 64, + true, + ); + let config2 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_notification_protocol(notif_config2); + + let config2 = match transport2 { + Transport::Tcp(config) => config2.with_tcp(config), + Transport::Quic(config) => config2.with_quic(config), + Transport::WebSocket(config) => config2.with_websocket(config), + } + .build(); + + let mut litep2p1 = Litep2p::new(config1).unwrap(); + let mut litep2p2 = Litep2p::new(config2).unwrap(); + + let peer1 = *litep2p1.local_peer_id(); + let peer2 = *litep2p2.local_peer_id(); + + // wait until peers have connected and spawn the litep2p objects in the background + connect_peers(&mut litep2p1, &mut litep2p2).await; + tokio::spawn(async move { + loop { + tokio::select! { + _ = litep2p1.next_event() => {}, + _ = litep2p2.next_event() => {}, + } + } + }); + + // open substream for `peer2` and accept it + handle1.open_substream(peer2).await.unwrap(); + + assert_eq!( + handle2.next().await.unwrap(), + NotificationEvent::ValidateSubstream { + protocol: ProtocolName::from("/notif/1"), + fallback: None, + peer: peer1, + handshake: vec![1, 2, 3, 4], + } + ); + handle2.send_validation_result(peer1, ValidationResult::Accept); + + assert_eq!( + handle1.next().await.unwrap(), + NotificationEvent::ValidateSubstream { + protocol: ProtocolName::from("/notif/1"), + fallback: None, + peer: peer2, + handshake: vec![1, 2, 3, 4], + } + ); + handle1.send_validation_result(peer2, ValidationResult::Accept); + + assert_eq!( + handle2.next().await.unwrap(), + NotificationEvent::NotificationStreamOpened { + protocol: ProtocolName::from("/notif/1"), + fallback: None, + direction: Direction::Inbound, + peer: peer1, + handshake: vec![1, 2, 3, 4], + } + ); + assert_eq!( + handle1.next().await.unwrap(), + NotificationEvent::NotificationStreamOpened { + protocol: ProtocolName::from("/notif/1"), + fallback: None, + direction: Direction::Outbound, + peer: peer2, + handshake: vec![1, 2, 3, 4], + } + ); + + // open substream for `peer2` and accept it + match handle1.open_substream(peer2).await { + Err(Error::PeerAlreadyExists(peer)) => assert_eq!(peer, peer2), + result => panic!("invalid event received: {result:?}"), + } } #[tokio::test] async fn substream_validation_timeout_tcp() { - substream_validation_timeout( - Transport::Tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }), - Transport::Tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }), - ) - .await; + substream_validation_timeout( + Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }), + Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }), + ) + .await; } #[tokio::test] async fn substream_validation_timeout_quic() { - substream_validation_timeout( - Transport::Quic(Default::default()), - Transport::Quic(Default::default()), - ) - .await; + substream_validation_timeout( + Transport::Quic(Default::default()), + Transport::Quic(Default::default()), + ) + .await; } #[tokio::test] async fn substream_validation_timeout_websocket() { - substream_validation_timeout( - Transport::WebSocket(WebSocketConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], - ..Default::default() - }), - Transport::WebSocket(WebSocketConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], - ..Default::default() - }), - ) - .await; + substream_validation_timeout( + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }), + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }), + ) + .await; } async fn substream_validation_timeout(transport1: Transport, transport2: Transport) { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (notif_config1, mut handle1) = NotificationConfig::new( - ProtocolName::from("/notif/1"), - 1024usize, - vec![1, 2, 3, 4], - Vec::new(), - false, - 64, - 64, - true, - ); - let config1 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_notification_protocol(notif_config1); - - let config1 = match transport1 { - Transport::Tcp(config) => config1.with_tcp(config), - Transport::Quic(config) => config1.with_quic(config), - Transport::WebSocket(config) => config1.with_websocket(config), - } - .build(); - - let (notif_config2, mut handle2) = NotificationConfig::new( - ProtocolName::from("/notif/1"), - 1024usize, - vec![1, 2, 3, 4], - Vec::new(), - false, - 64, - 64, - true, - ); - let config2 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_notification_protocol(notif_config2); - - let config2 = match transport2 { - Transport::Tcp(config) => config2.with_tcp(config), - Transport::Quic(config) => config2.with_quic(config), - Transport::WebSocket(config) => config2.with_websocket(config), - } - .build(); - - let mut litep2p1 = Litep2p::new(config1).unwrap(); - let mut litep2p2 = Litep2p::new(config2).unwrap(); - - let peer1 = *litep2p1.local_peer_id(); - let peer2 = *litep2p2.local_peer_id(); - - // wait until peers have connected and spawn the litep2p objects in the background - connect_peers(&mut litep2p1, &mut litep2p2).await; - tokio::spawn(async move { - loop { - tokio::select! { - _ = litep2p1.next_event() => {}, - _ = litep2p2.next_event() => {}, - } - } - }); - - // open substream for `peer2` and accept it - handle1.open_substream(peer2).await.unwrap(); - assert_eq!( - handle2.next().await.unwrap(), - NotificationEvent::ValidateSubstream { - protocol: ProtocolName::from("/notif/1"), - fallback: None, - peer: peer1, - handshake: vec![1, 2, 3, 4], - } - ); - - // don't reject the substream but let it timeout - assert_eq!( - handle1.next().await.unwrap(), - NotificationEvent::NotificationStreamOpenFailure { - peer: peer2, - error: NotificationError::Rejected, - } - ); + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (notif_config1, mut handle1) = NotificationConfig::new( + ProtocolName::from("/notif/1"), + 1024usize, + vec![1, 2, 3, 4], + Vec::new(), + false, + 64, + 64, + true, + ); + let config1 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_notification_protocol(notif_config1); + + let config1 = match transport1 { + Transport::Tcp(config) => config1.with_tcp(config), + Transport::Quic(config) => config1.with_quic(config), + Transport::WebSocket(config) => config1.with_websocket(config), + } + .build(); + + let (notif_config2, mut handle2) = NotificationConfig::new( + ProtocolName::from("/notif/1"), + 1024usize, + vec![1, 2, 3, 4], + Vec::new(), + false, + 64, + 64, + true, + ); + let config2 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_notification_protocol(notif_config2); + + let config2 = match transport2 { + Transport::Tcp(config) => config2.with_tcp(config), + Transport::Quic(config) => config2.with_quic(config), + Transport::WebSocket(config) => config2.with_websocket(config), + } + .build(); + + let mut litep2p1 = Litep2p::new(config1).unwrap(); + let mut litep2p2 = Litep2p::new(config2).unwrap(); + + let peer1 = *litep2p1.local_peer_id(); + let peer2 = *litep2p2.local_peer_id(); + + // wait until peers have connected and spawn the litep2p objects in the background + connect_peers(&mut litep2p1, &mut litep2p2).await; + tokio::spawn(async move { + loop { + tokio::select! { + _ = litep2p1.next_event() => {}, + _ = litep2p2.next_event() => {}, + } + } + }); + + // open substream for `peer2` and accept it + handle1.open_substream(peer2).await.unwrap(); + assert_eq!( + handle2.next().await.unwrap(), + NotificationEvent::ValidateSubstream { + protocol: ProtocolName::from("/notif/1"), + fallback: None, + peer: peer1, + handshake: vec![1, 2, 3, 4], + } + ); + + // don't reject the substream but let it timeout + assert_eq!( + handle1.next().await.unwrap(), + NotificationEvent::NotificationStreamOpenFailure { + peer: peer2, + error: NotificationError::Rejected, + } + ); } #[tokio::test] async fn unsupported_protocol_tcp() { - unsupported_protocol( - Transport::Tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }), - Transport::Tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }), - ) - .await; + unsupported_protocol( + Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }), + Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }), + ) + .await; } #[tokio::test] async fn unsupported_protocol_quic() { - unsupported_protocol(Transport::Quic(Default::default()), Transport::Quic(Default::default())) - .await; + unsupported_protocol( + Transport::Quic(Default::default()), + Transport::Quic(Default::default()), + ) + .await; } #[tokio::test] async fn unsupported_protocol_websocket() { - unsupported_protocol( - Transport::WebSocket(WebSocketConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], - ..Default::default() - }), - Transport::WebSocket(WebSocketConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], - ..Default::default() - }), - ) - .await; + unsupported_protocol( + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }), + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }), + ) + .await; } async fn unsupported_protocol(transport1: Transport, transport2: Transport) { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (notif_config1, mut handle1) = ConfigBuilder::new(ProtocolName::from("/notif/1")) - .with_max_size(1024usize) - .with_handshake(vec![1, 2, 3, 4]) - .build(); - - let config1 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_notification_protocol(notif_config1); - - let config1 = match transport1 { - Transport::Tcp(config) => config1.with_tcp(config), - Transport::Quic(config) => config1.with_quic(config), - Transport::WebSocket(config) => config1.with_websocket(config), - } - .build(); - - let (notif_config2, _handle2) = ConfigBuilder::new(ProtocolName::from("/notif/2")) - .with_max_size(1024usize) - .with_handshake(vec![1, 2, 3, 4]) - .build(); - let config2 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_notification_protocol(notif_config2); - - let config2 = match transport2 { - Transport::Tcp(config) => config2.with_tcp(config), - Transport::Quic(config) => config2.with_quic(config), - Transport::WebSocket(config) => config2.with_websocket(config), - } - .build(); - - let mut litep2p1 = Litep2p::new(config1).unwrap(); - let mut litep2p2 = Litep2p::new(config2).unwrap(); - - let peer2 = *litep2p2.local_peer_id(); - - // wait until peers have connected and spawn the litep2p objects in the background - connect_peers(&mut litep2p1, &mut litep2p2).await; - tokio::spawn(async move { - loop { - tokio::select! { - _ = litep2p1.next_event() => {}, - _ = litep2p2.next_event() => {}, - } - } - }); - - // open substream for `peer2` and accept it - handle1.open_substream(peer2).await.unwrap(); - assert_eq!( - handle1.next().await.unwrap(), - NotificationEvent::NotificationStreamOpenFailure { - peer: peer2, - error: NotificationError::Rejected - } - ); + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (notif_config1, mut handle1) = ConfigBuilder::new(ProtocolName::from("/notif/1")) + .with_max_size(1024usize) + .with_handshake(vec![1, 2, 3, 4]) + .build(); + + let config1 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_notification_protocol(notif_config1); + + let config1 = match transport1 { + Transport::Tcp(config) => config1.with_tcp(config), + Transport::Quic(config) => config1.with_quic(config), + Transport::WebSocket(config) => config1.with_websocket(config), + } + .build(); + + let (notif_config2, _handle2) = ConfigBuilder::new(ProtocolName::from("/notif/2")) + .with_max_size(1024usize) + .with_handshake(vec![1, 2, 3, 4]) + .build(); + let config2 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_notification_protocol(notif_config2); + + let config2 = match transport2 { + Transport::Tcp(config) => config2.with_tcp(config), + Transport::Quic(config) => config2.with_quic(config), + Transport::WebSocket(config) => config2.with_websocket(config), + } + .build(); + + let mut litep2p1 = Litep2p::new(config1).unwrap(); + let mut litep2p2 = Litep2p::new(config2).unwrap(); + + let peer2 = *litep2p2.local_peer_id(); + + // wait until peers have connected and spawn the litep2p objects in the background + connect_peers(&mut litep2p1, &mut litep2p2).await; + tokio::spawn(async move { + loop { + tokio::select! { + _ = litep2p1.next_event() => {}, + _ = litep2p2.next_event() => {}, + } + } + }); + + // open substream for `peer2` and accept it + handle1.open_substream(peer2).await.unwrap(); + assert_eq!( + handle1.next().await.unwrap(), + NotificationEvent::NotificationStreamOpenFailure { + peer: peer2, + error: NotificationError::Rejected + } + ); } #[tokio::test] async fn dialer_fallback_protocol_works_tcp() { - dialer_fallback_protocol_works( - Transport::Tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }), - Transport::Tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }), - ) - .await; + dialer_fallback_protocol_works( + Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }), + Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }), + ) + .await; } #[tokio::test] async fn dialer_fallback_protocol_works_quic() { - dialer_fallback_protocol_works( - Transport::Quic(Default::default()), - Transport::Quic(Default::default()), - ) - .await; + dialer_fallback_protocol_works( + Transport::Quic(Default::default()), + Transport::Quic(Default::default()), + ) + .await; } #[tokio::test] async fn dialer_fallback_protocol_works_websocket() { - dialer_fallback_protocol_works( - Transport::WebSocket(WebSocketConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], - ..Default::default() - }), - Transport::WebSocket(WebSocketConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], - ..Default::default() - }), - ) - .await; + dialer_fallback_protocol_works( + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }), + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }), + ) + .await; } async fn dialer_fallback_protocol_works(transport1: Transport, transport2: Transport) { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (notif_config1, mut handle1) = ConfigBuilder::new(ProtocolName::from("/notif/2")) - .with_max_size(1024usize) - .with_handshake(vec![1, 2, 3, 4]) - .with_fallback_names(vec![ProtocolName::from("/notif/1")]) - .build(); - - let config1 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_notification_protocol(notif_config1); - - let config1 = match transport1 { - Transport::Tcp(config) => config1.with_tcp(config), - Transport::Quic(config) => config1.with_quic(config), - Transport::WebSocket(config) => config1.with_websocket(config), - } - .build(); - - let (notif_config2, mut handle2) = ConfigBuilder::new(ProtocolName::from("/notif/1")) - .with_max_size(1024usize) - .with_handshake(vec![1, 2, 3, 4]) - .build(); - let config2 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_notification_protocol(notif_config2); - - let config2 = match transport2 { - Transport::Tcp(config) => config2.with_tcp(config), - Transport::Quic(config) => config2.with_quic(config), - Transport::WebSocket(config) => config2.with_websocket(config), - } - .build(); - - let mut litep2p1 = Litep2p::new(config1).unwrap(); - let mut litep2p2 = Litep2p::new(config2).unwrap(); - - let peer1 = *litep2p1.local_peer_id(); - let peer2 = *litep2p2.local_peer_id(); - - // wait until peers have connected and spawn the litep2p objects in the background - connect_peers(&mut litep2p1, &mut litep2p2).await; - tokio::spawn(async move { - loop { - tokio::select! { - _ = litep2p1.next_event() => {}, - _ = litep2p2.next_event() => {}, - } - } - }); - - // open substream for `peer2` and accept it - handle1.open_substream(peer2).await.unwrap(); - assert_eq!( - handle2.next().await.unwrap(), - NotificationEvent::ValidateSubstream { - protocol: ProtocolName::from("/notif/1"), - fallback: None, - peer: peer1, - handshake: vec![1, 2, 3, 4], - } - ); - handle2.send_validation_result(peer1, ValidationResult::Accept); - assert_eq!( - handle1.next().await.unwrap(), - NotificationEvent::ValidateSubstream { - protocol: ProtocolName::from("/notif/2"), - fallback: Some(ProtocolName::from("/notif/1")), - peer: peer2, - handshake: vec![1, 2, 3, 4], - } - ); - handle1.send_validation_result(peer2, ValidationResult::Accept); - - assert_eq!( - handle2.next().await.unwrap(), - NotificationEvent::NotificationStreamOpened { - protocol: ProtocolName::from("/notif/1"), - fallback: None, - direction: Direction::Inbound, - peer: peer1, - handshake: vec![1, 2, 3, 4], - } - ); - assert_eq!( - handle1.next().await.unwrap(), - NotificationEvent::NotificationStreamOpened { - protocol: ProtocolName::from("/notif/2"), - fallback: Some(ProtocolName::from("/notif/1")), - direction: Direction::Outbound, - peer: peer2, - handshake: vec![1, 2, 3, 4], - } - ); + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (notif_config1, mut handle1) = ConfigBuilder::new(ProtocolName::from("/notif/2")) + .with_max_size(1024usize) + .with_handshake(vec![1, 2, 3, 4]) + .with_fallback_names(vec![ProtocolName::from("/notif/1")]) + .build(); + + let config1 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_notification_protocol(notif_config1); + + let config1 = match transport1 { + Transport::Tcp(config) => config1.with_tcp(config), + Transport::Quic(config) => config1.with_quic(config), + Transport::WebSocket(config) => config1.with_websocket(config), + } + .build(); + + let (notif_config2, mut handle2) = ConfigBuilder::new(ProtocolName::from("/notif/1")) + .with_max_size(1024usize) + .with_handshake(vec![1, 2, 3, 4]) + .build(); + let config2 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_notification_protocol(notif_config2); + + let config2 = match transport2 { + Transport::Tcp(config) => config2.with_tcp(config), + Transport::Quic(config) => config2.with_quic(config), + Transport::WebSocket(config) => config2.with_websocket(config), + } + .build(); + + let mut litep2p1 = Litep2p::new(config1).unwrap(); + let mut litep2p2 = Litep2p::new(config2).unwrap(); + + let peer1 = *litep2p1.local_peer_id(); + let peer2 = *litep2p2.local_peer_id(); + + // wait until peers have connected and spawn the litep2p objects in the background + connect_peers(&mut litep2p1, &mut litep2p2).await; + tokio::spawn(async move { + loop { + tokio::select! { + _ = litep2p1.next_event() => {}, + _ = litep2p2.next_event() => {}, + } + } + }); + + // open substream for `peer2` and accept it + handle1.open_substream(peer2).await.unwrap(); + assert_eq!( + handle2.next().await.unwrap(), + NotificationEvent::ValidateSubstream { + protocol: ProtocolName::from("/notif/1"), + fallback: None, + peer: peer1, + handshake: vec![1, 2, 3, 4], + } + ); + handle2.send_validation_result(peer1, ValidationResult::Accept); + assert_eq!( + handle1.next().await.unwrap(), + NotificationEvent::ValidateSubstream { + protocol: ProtocolName::from("/notif/2"), + fallback: Some(ProtocolName::from("/notif/1")), + peer: peer2, + handshake: vec![1, 2, 3, 4], + } + ); + handle1.send_validation_result(peer2, ValidationResult::Accept); + + assert_eq!( + handle2.next().await.unwrap(), + NotificationEvent::NotificationStreamOpened { + protocol: ProtocolName::from("/notif/1"), + fallback: None, + direction: Direction::Inbound, + peer: peer1, + handshake: vec![1, 2, 3, 4], + } + ); + assert_eq!( + handle1.next().await.unwrap(), + NotificationEvent::NotificationStreamOpened { + protocol: ProtocolName::from("/notif/2"), + fallback: Some(ProtocolName::from("/notif/1")), + direction: Direction::Outbound, + peer: peer2, + handshake: vec![1, 2, 3, 4], + } + ); } #[tokio::test] async fn listener_fallback_protocol_works_tcp() { - listener_fallback_protocol_works( - Transport::Tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }), - Transport::Tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }), - ) - .await; + listener_fallback_protocol_works( + Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }), + Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }), + ) + .await; } #[tokio::test] async fn listener_fallback_protocol_works_quic() { - listener_fallback_protocol_works( - Transport::Quic(Default::default()), - Transport::Quic(Default::default()), - ) - .await; + listener_fallback_protocol_works( + Transport::Quic(Default::default()), + Transport::Quic(Default::default()), + ) + .await; } #[tokio::test] async fn listener_fallback_protocol_works_websocket() { - listener_fallback_protocol_works( - Transport::WebSocket(WebSocketConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], - ..Default::default() - }), - Transport::WebSocket(WebSocketConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], - ..Default::default() - }), - ) - .await; + listener_fallback_protocol_works( + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }), + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }), + ) + .await; } async fn listener_fallback_protocol_works(transport1: Transport, transport2: Transport) { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (notif_config1, mut handle1) = ConfigBuilder::new(ProtocolName::from("/notif/1")) - .with_max_size(1024usize) - .with_handshake(vec![1, 2, 3, 4]) - .build(); - - let config1 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_notification_protocol(notif_config1); - - let config1 = match transport1 { - Transport::Tcp(config) => config1.with_tcp(config), - Transport::Quic(config) => config1.with_quic(config), - Transport::WebSocket(config) => config1.with_websocket(config), - } - .build(); - - let (notif_config2, mut handle2) = ConfigBuilder::new(ProtocolName::from("/notif/2")) - .with_max_size(1024usize) - .with_handshake(vec![1, 2, 3, 4]) - .with_fallback_names(vec![ProtocolName::from("/notif/1")]) - .build(); - let config2 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_notification_protocol(notif_config2); - - let config2 = match transport2 { - Transport::Tcp(config) => config2.with_tcp(config), - Transport::Quic(config) => config2.with_quic(config), - Transport::WebSocket(config) => config2.with_websocket(config), - } - .build(); - - let mut litep2p1 = Litep2p::new(config1).unwrap(); - let mut litep2p2 = Litep2p::new(config2).unwrap(); - - let peer1 = *litep2p1.local_peer_id(); - let peer2 = *litep2p2.local_peer_id(); - - // wait until peers have connected and spawn the litep2p objects in the background - connect_peers(&mut litep2p1, &mut litep2p2).await; - tokio::spawn(async move { - loop { - tokio::select! { - _ = litep2p1.next_event() => {}, - _ = litep2p2.next_event() => {}, - } - } - }); - - // open substream for `peer2` and accept it - handle1.open_substream(peer2).await.unwrap(); - assert_eq!( - handle2.next().await.unwrap(), - NotificationEvent::ValidateSubstream { - protocol: ProtocolName::from("/notif/2"), - fallback: Some(ProtocolName::from("/notif/1")), - peer: peer1, - handshake: vec![1, 2, 3, 4], - } - ); - handle2.send_validation_result(peer1, ValidationResult::Accept); - assert_eq!( - handle1.next().await.unwrap(), - NotificationEvent::ValidateSubstream { - protocol: ProtocolName::from("/notif/1"), - fallback: None, - peer: peer2, - handshake: vec![1, 2, 3, 4], - } - ); - handle1.send_validation_result(peer2, ValidationResult::Accept); - - assert_eq!( - handle2.next().await.unwrap(), - NotificationEvent::NotificationStreamOpened { - protocol: ProtocolName::from("/notif/2"), - fallback: Some(ProtocolName::from("/notif/1")), - direction: Direction::Inbound, - peer: peer1, - handshake: vec![1, 2, 3, 4], - } - ); - assert_eq!( - handle1.next().await.unwrap(), - NotificationEvent::NotificationStreamOpened { - protocol: ProtocolName::from("/notif/1"), - direction: Direction::Outbound, - fallback: None, - peer: peer2, - handshake: vec![1, 2, 3, 4], - } - ); + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (notif_config1, mut handle1) = ConfigBuilder::new(ProtocolName::from("/notif/1")) + .with_max_size(1024usize) + .with_handshake(vec![1, 2, 3, 4]) + .build(); + + let config1 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_notification_protocol(notif_config1); + + let config1 = match transport1 { + Transport::Tcp(config) => config1.with_tcp(config), + Transport::Quic(config) => config1.with_quic(config), + Transport::WebSocket(config) => config1.with_websocket(config), + } + .build(); + + let (notif_config2, mut handle2) = ConfigBuilder::new(ProtocolName::from("/notif/2")) + .with_max_size(1024usize) + .with_handshake(vec![1, 2, 3, 4]) + .with_fallback_names(vec![ProtocolName::from("/notif/1")]) + .build(); + let config2 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_notification_protocol(notif_config2); + + let config2 = match transport2 { + Transport::Tcp(config) => config2.with_tcp(config), + Transport::Quic(config) => config2.with_quic(config), + Transport::WebSocket(config) => config2.with_websocket(config), + } + .build(); + + let mut litep2p1 = Litep2p::new(config1).unwrap(); + let mut litep2p2 = Litep2p::new(config2).unwrap(); + + let peer1 = *litep2p1.local_peer_id(); + let peer2 = *litep2p2.local_peer_id(); + + // wait until peers have connected and spawn the litep2p objects in the background + connect_peers(&mut litep2p1, &mut litep2p2).await; + tokio::spawn(async move { + loop { + tokio::select! { + _ = litep2p1.next_event() => {}, + _ = litep2p2.next_event() => {}, + } + } + }); + + // open substream for `peer2` and accept it + handle1.open_substream(peer2).await.unwrap(); + assert_eq!( + handle2.next().await.unwrap(), + NotificationEvent::ValidateSubstream { + protocol: ProtocolName::from("/notif/2"), + fallback: Some(ProtocolName::from("/notif/1")), + peer: peer1, + handshake: vec![1, 2, 3, 4], + } + ); + handle2.send_validation_result(peer1, ValidationResult::Accept); + assert_eq!( + handle1.next().await.unwrap(), + NotificationEvent::ValidateSubstream { + protocol: ProtocolName::from("/notif/1"), + fallback: None, + peer: peer2, + handshake: vec![1, 2, 3, 4], + } + ); + handle1.send_validation_result(peer2, ValidationResult::Accept); + + assert_eq!( + handle2.next().await.unwrap(), + NotificationEvent::NotificationStreamOpened { + protocol: ProtocolName::from("/notif/2"), + fallback: Some(ProtocolName::from("/notif/1")), + direction: Direction::Inbound, + peer: peer1, + handshake: vec![1, 2, 3, 4], + } + ); + assert_eq!( + handle1.next().await.unwrap(), + NotificationEvent::NotificationStreamOpened { + protocol: ProtocolName::from("/notif/1"), + direction: Direction::Outbound, + fallback: None, + peer: peer2, + handshake: vec![1, 2, 3, 4], + } + ); } #[tokio::test] async fn enable_auto_accept_tcp() { - enable_auto_accept( - Transport::Tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }), - Transport::Tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }), - ) - .await + enable_auto_accept( + Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }), + Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }), + ) + .await } #[tokio::test] async fn enable_auto_accept_quic() { - enable_auto_accept(Transport::Quic(Default::default()), Transport::Quic(Default::default())) - .await; + enable_auto_accept( + Transport::Quic(Default::default()), + Transport::Quic(Default::default()), + ) + .await; } #[tokio::test] async fn enable_auto_accept_websocket() { - enable_auto_accept( - Transport::WebSocket(WebSocketConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], - ..Default::default() - }), - Transport::WebSocket(WebSocketConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], - ..Default::default() - }), - ) - .await; + enable_auto_accept( + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }), + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }), + ) + .await; } async fn enable_auto_accept(transport1: Transport, transport2: Transport) { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (notif_config1, mut handle1) = NotificationConfig::new( - ProtocolName::from("/notif/1"), - 1024usize, - vec![1, 2, 3, 4], - Vec::new(), - true, - 64, - 64, - true, - ); - let config1 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_notification_protocol(notif_config1); - - let config1 = match transport1 { - Transport::Tcp(config) => config1.with_tcp(config), - Transport::Quic(config) => config1.with_quic(config), - Transport::WebSocket(config) => config1.with_websocket(config), - } - .build(); - - let (notif_config2, mut handle2) = NotificationConfig::new( - ProtocolName::from("/notif/1"), - 1024usize, - vec![1, 2, 3, 4], - Vec::new(), - false, - 64, - 64, - true, - ); - let config2 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_notification_protocol(notif_config2); - - let config2 = match transport2 { - Transport::Tcp(config) => config2.with_tcp(config), - Transport::Quic(config) => config2.with_quic(config), - Transport::WebSocket(config) => config2.with_websocket(config), - } - .build(); - - let mut litep2p1 = Litep2p::new(config1).unwrap(); - let mut litep2p2 = Litep2p::new(config2).unwrap(); - - let peer1 = *litep2p1.local_peer_id(); - let peer2 = *litep2p2.local_peer_id(); - - // wait until peers have connected and spawn the litep2p objects in the background - connect_peers(&mut litep2p1, &mut litep2p2).await; - tokio::spawn(async move { - loop { - tokio::select! { - _ = litep2p1.next_event() => {}, - _ = litep2p2.next_event() => {}, - } - } - }); - - // open substream for `peer2` and accept it - handle1.open_substream(peer2).await.unwrap(); - assert_eq!( - handle2.next().await.unwrap(), - NotificationEvent::ValidateSubstream { - protocol: ProtocolName::from("/notif/1"), - fallback: None, - peer: peer1, - handshake: vec![1, 2, 3, 4], - } - ); - handle2.send_validation_result(peer1, ValidationResult::Accept); - - assert_eq!( - handle2.next().await.unwrap(), - NotificationEvent::NotificationStreamOpened { - protocol: ProtocolName::from("/notif/1"), - direction: Direction::Inbound, - fallback: None, - peer: peer1, - handshake: vec![1, 2, 3, 4], - } - ); - assert_eq!( - handle1.next().await.unwrap(), - NotificationEvent::NotificationStreamOpened { - protocol: ProtocolName::from("/notif/1"), - direction: Direction::Outbound, - fallback: None, - peer: peer2, - handshake: vec![1, 2, 3, 4], - } - ); - - handle1.send_sync_notification(peer2, vec![1, 3, 3, 7]).unwrap(); - handle2.send_sync_notification(peer1, vec![1, 3, 3, 8]).unwrap(); - - assert_eq!( - handle2.next().await.unwrap(), - NotificationEvent::NotificationReceived { - peer: peer1, - notification: BytesMut::from(&[1, 3, 3, 7][..]), - } - ); - assert_eq!( - handle1.next().await.unwrap(), - NotificationEvent::NotificationReceived { - peer: peer2, - notification: BytesMut::from(&[1, 3, 3, 8][..]), - } - ); + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (notif_config1, mut handle1) = NotificationConfig::new( + ProtocolName::from("/notif/1"), + 1024usize, + vec![1, 2, 3, 4], + Vec::new(), + true, + 64, + 64, + true, + ); + let config1 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_notification_protocol(notif_config1); + + let config1 = match transport1 { + Transport::Tcp(config) => config1.with_tcp(config), + Transport::Quic(config) => config1.with_quic(config), + Transport::WebSocket(config) => config1.with_websocket(config), + } + .build(); + + let (notif_config2, mut handle2) = NotificationConfig::new( + ProtocolName::from("/notif/1"), + 1024usize, + vec![1, 2, 3, 4], + Vec::new(), + false, + 64, + 64, + true, + ); + let config2 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_notification_protocol(notif_config2); + + let config2 = match transport2 { + Transport::Tcp(config) => config2.with_tcp(config), + Transport::Quic(config) => config2.with_quic(config), + Transport::WebSocket(config) => config2.with_websocket(config), + } + .build(); + + let mut litep2p1 = Litep2p::new(config1).unwrap(); + let mut litep2p2 = Litep2p::new(config2).unwrap(); + + let peer1 = *litep2p1.local_peer_id(); + let peer2 = *litep2p2.local_peer_id(); + + // wait until peers have connected and spawn the litep2p objects in the background + connect_peers(&mut litep2p1, &mut litep2p2).await; + tokio::spawn(async move { + loop { + tokio::select! { + _ = litep2p1.next_event() => {}, + _ = litep2p2.next_event() => {}, + } + } + }); + + // open substream for `peer2` and accept it + handle1.open_substream(peer2).await.unwrap(); + assert_eq!( + handle2.next().await.unwrap(), + NotificationEvent::ValidateSubstream { + protocol: ProtocolName::from("/notif/1"), + fallback: None, + peer: peer1, + handshake: vec![1, 2, 3, 4], + } + ); + handle2.send_validation_result(peer1, ValidationResult::Accept); + + assert_eq!( + handle2.next().await.unwrap(), + NotificationEvent::NotificationStreamOpened { + protocol: ProtocolName::from("/notif/1"), + direction: Direction::Inbound, + fallback: None, + peer: peer1, + handshake: vec![1, 2, 3, 4], + } + ); + assert_eq!( + handle1.next().await.unwrap(), + NotificationEvent::NotificationStreamOpened { + protocol: ProtocolName::from("/notif/1"), + direction: Direction::Outbound, + fallback: None, + peer: peer2, + handshake: vec![1, 2, 3, 4], + } + ); + + handle1.send_sync_notification(peer2, vec![1, 3, 3, 7]).unwrap(); + handle2.send_sync_notification(peer1, vec![1, 3, 3, 8]).unwrap(); + + assert_eq!( + handle2.next().await.unwrap(), + NotificationEvent::NotificationReceived { + peer: peer1, + notification: BytesMut::from(&[1, 3, 3, 7][..]), + } + ); + assert_eq!( + handle1.next().await.unwrap(), + NotificationEvent::NotificationReceived { + peer: peer2, + notification: BytesMut::from(&[1, 3, 3, 8][..]), + } + ); } #[tokio::test] async fn send_using_notification_sink_tcp() { - send_using_notification_sink( - Transport::Tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }), - Transport::Tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }), - ) - .await + send_using_notification_sink( + Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }), + Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }), + ) + .await } #[tokio::test] async fn send_using_notification_sink_quic() { - send_using_notification_sink( - Transport::Quic(Default::default()), - Transport::Quic(Default::default()), - ) - .await; + send_using_notification_sink( + Transport::Quic(Default::default()), + Transport::Quic(Default::default()), + ) + .await; } #[tokio::test] async fn send_using_notification_sink_websocket() { - send_using_notification_sink( - Transport::WebSocket(WebSocketConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], - ..Default::default() - }), - Transport::WebSocket(WebSocketConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], - ..Default::default() - }), - ) - .await; + send_using_notification_sink( + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }), + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }), + ) + .await; } async fn send_using_notification_sink(transport1: Transport, transport2: Transport) { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (notif_config1, mut handle1) = NotificationConfig::new( - ProtocolName::from("/notif/1"), - 1024usize, - vec![1, 2, 3, 4], - Vec::new(), - false, - 64, - 64, - true, - ); - let config1 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_notification_protocol(notif_config1); - - let config1 = match transport1 { - Transport::Tcp(config) => config1.with_tcp(config), - Transport::Quic(config) => config1.with_quic(config), - Transport::WebSocket(config) => config1.with_websocket(config), - } - .build(); - - let (notif_config2, mut handle2) = NotificationConfig::new( - ProtocolName::from("/notif/1"), - 1024usize, - vec![1, 2, 3, 4], - Vec::new(), - false, - 64, - 64, - true, - ); - let config2 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_notification_protocol(notif_config2); - - let config2 = match transport2 { - Transport::Tcp(config) => config2.with_tcp(config), - Transport::Quic(config) => config2.with_quic(config), - Transport::WebSocket(config) => config2.with_websocket(config), - } - .build(); - - let mut litep2p1 = Litep2p::new(config1).unwrap(); - let mut litep2p2 = Litep2p::new(config2).unwrap(); - - let peer1 = *litep2p1.local_peer_id(); - let peer2 = *litep2p2.local_peer_id(); - - // wait until peers have connected and spawn the litep2p objects in the background - connect_peers(&mut litep2p1, &mut litep2p2).await; - tokio::spawn(async move { - loop { - tokio::select! { - _ = litep2p1.next_event() => {}, - _ = litep2p2.next_event() => {}, - } - } - }); - - // open substream for `peer2` and accept it - handle1.open_substream(peer2).await.unwrap(); - assert_eq!( - handle2.next().await.unwrap(), - NotificationEvent::ValidateSubstream { - protocol: ProtocolName::from("/notif/1"), - fallback: None, - peer: peer1, - handshake: vec![1, 2, 3, 4], - } - ); - handle2.send_validation_result(peer1, ValidationResult::Accept); - - assert_eq!( - handle1.next().await.unwrap(), - NotificationEvent::ValidateSubstream { - protocol: ProtocolName::from("/notif/1"), - fallback: None, - peer: peer2, - handshake: vec![1, 2, 3, 4], - } - ); - handle1.send_validation_result(peer2, ValidationResult::Accept); - - assert_eq!( - handle2.next().await.unwrap(), - NotificationEvent::NotificationStreamOpened { - protocol: ProtocolName::from("/notif/1"), - direction: Direction::Inbound, - fallback: None, - peer: peer1, - handshake: vec![1, 2, 3, 4], - } - ); - assert_eq!( - handle1.next().await.unwrap(), - NotificationEvent::NotificationStreamOpened { - protocol: ProtocolName::from("/notif/1"), - direction: Direction::Outbound, - fallback: None, - peer: peer2, - handshake: vec![1, 2, 3, 4], - } - ); - - let sink1 = handle1.notification_sink(peer2).unwrap(); - let sink2 = handle2.notification_sink(peer1).unwrap(); - - sink1.send_sync_notification(vec![1, 3, 3, 7]).unwrap(); - sink2.send_sync_notification(vec![1, 3, 3, 8]).unwrap(); - - assert_eq!( - handle2.next().await.unwrap(), - NotificationEvent::NotificationReceived { - peer: peer1, - notification: BytesMut::from(&[1, 3, 3, 7][..]), - } - ); - assert_eq!( - handle1.next().await.unwrap(), - NotificationEvent::NotificationReceived { - peer: peer2, - notification: BytesMut::from(&[1, 3, 3, 8][..]), - } - ); - - // close the substream to `peer1` and try to send notification using `sink1` - handle2.close_substream(peer1).await; - - // allow `peer1` to detect that the substream has been closed - tokio::time::sleep(std::time::Duration::from_secs(2)).await; - - assert_eq!( - sink1.send_sync_notification(vec![1, 3, 3, 7]), - Err(NotificationError::NoConnection), - ); + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (notif_config1, mut handle1) = NotificationConfig::new( + ProtocolName::from("/notif/1"), + 1024usize, + vec![1, 2, 3, 4], + Vec::new(), + false, + 64, + 64, + true, + ); + let config1 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_notification_protocol(notif_config1); + + let config1 = match transport1 { + Transport::Tcp(config) => config1.with_tcp(config), + Transport::Quic(config) => config1.with_quic(config), + Transport::WebSocket(config) => config1.with_websocket(config), + } + .build(); + + let (notif_config2, mut handle2) = NotificationConfig::new( + ProtocolName::from("/notif/1"), + 1024usize, + vec![1, 2, 3, 4], + Vec::new(), + false, + 64, + 64, + true, + ); + let config2 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_notification_protocol(notif_config2); + + let config2 = match transport2 { + Transport::Tcp(config) => config2.with_tcp(config), + Transport::Quic(config) => config2.with_quic(config), + Transport::WebSocket(config) => config2.with_websocket(config), + } + .build(); + + let mut litep2p1 = Litep2p::new(config1).unwrap(); + let mut litep2p2 = Litep2p::new(config2).unwrap(); + + let peer1 = *litep2p1.local_peer_id(); + let peer2 = *litep2p2.local_peer_id(); + + // wait until peers have connected and spawn the litep2p objects in the background + connect_peers(&mut litep2p1, &mut litep2p2).await; + tokio::spawn(async move { + loop { + tokio::select! { + _ = litep2p1.next_event() => {}, + _ = litep2p2.next_event() => {}, + } + } + }); + + // open substream for `peer2` and accept it + handle1.open_substream(peer2).await.unwrap(); + assert_eq!( + handle2.next().await.unwrap(), + NotificationEvent::ValidateSubstream { + protocol: ProtocolName::from("/notif/1"), + fallback: None, + peer: peer1, + handshake: vec![1, 2, 3, 4], + } + ); + handle2.send_validation_result(peer1, ValidationResult::Accept); + + assert_eq!( + handle1.next().await.unwrap(), + NotificationEvent::ValidateSubstream { + protocol: ProtocolName::from("/notif/1"), + fallback: None, + peer: peer2, + handshake: vec![1, 2, 3, 4], + } + ); + handle1.send_validation_result(peer2, ValidationResult::Accept); + + assert_eq!( + handle2.next().await.unwrap(), + NotificationEvent::NotificationStreamOpened { + protocol: ProtocolName::from("/notif/1"), + direction: Direction::Inbound, + fallback: None, + peer: peer1, + handshake: vec![1, 2, 3, 4], + } + ); + assert_eq!( + handle1.next().await.unwrap(), + NotificationEvent::NotificationStreamOpened { + protocol: ProtocolName::from("/notif/1"), + direction: Direction::Outbound, + fallback: None, + peer: peer2, + handshake: vec![1, 2, 3, 4], + } + ); + + let sink1 = handle1.notification_sink(peer2).unwrap(); + let sink2 = handle2.notification_sink(peer1).unwrap(); + + sink1.send_sync_notification(vec![1, 3, 3, 7]).unwrap(); + sink2.send_sync_notification(vec![1, 3, 3, 8]).unwrap(); + + assert_eq!( + handle2.next().await.unwrap(), + NotificationEvent::NotificationReceived { + peer: peer1, + notification: BytesMut::from(&[1, 3, 3, 7][..]), + } + ); + assert_eq!( + handle1.next().await.unwrap(), + NotificationEvent::NotificationReceived { + peer: peer2, + notification: BytesMut::from(&[1, 3, 3, 8][..]), + } + ); + + // close the substream to `peer1` and try to send notification using `sink1` + handle2.close_substream(peer1).await; + + // allow `peer1` to detect that the substream has been closed + tokio::time::sleep(std::time::Duration::from_secs(2)).await; + + assert_eq!( + sink1.send_sync_notification(vec![1, 3, 3, 7]), + Err(NotificationError::NoConnection), + ); } #[tokio::test] async fn dial_peer_when_opening_substream_tcp() { - dial_peer_when_opening_substream( - Transport::Tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }), - Transport::Tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }), - ) - .await + dial_peer_when_opening_substream( + Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }), + Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }), + ) + .await } #[tokio::test] async fn dial_peer_when_opening_substream_quic() { - dial_peer_when_opening_substream( - Transport::Quic(Default::default()), - Transport::Quic(Default::default()), - ) - .await; + dial_peer_when_opening_substream( + Transport::Quic(Default::default()), + Transport::Quic(Default::default()), + ) + .await; } #[tokio::test] async fn dial_peer_when_opening_substream_websocket() { - dial_peer_when_opening_substream( - Transport::WebSocket(WebSocketConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], - ..Default::default() - }), - Transport::WebSocket(WebSocketConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], - ..Default::default() - }), - ) - .await; + dial_peer_when_opening_substream( + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }), + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }), + ) + .await; } async fn dial_peer_when_opening_substream(transport1: Transport, transport2: Transport) { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (notif_config1, mut handle1) = NotificationConfig::new( - ProtocolName::from("/notif/1"), - 1024usize, - vec![1, 2, 3, 4], - Vec::new(), - false, - 64, - 64, - true, - ); - let config1 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_notification_protocol(notif_config1); - - let config1 = match transport1 { - Transport::Tcp(config) => config1.with_tcp(config), - Transport::Quic(config) => config1.with_quic(config), - Transport::WebSocket(config) => config1.with_websocket(config), - } - .build(); - - let (notif_config2, mut handle2) = NotificationConfig::new( - ProtocolName::from("/notif/1"), - 1024usize, - vec![1, 2, 3, 4], - Vec::new(), - false, - 64, - 64, - true, - ); - let config2 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_notification_protocol(notif_config2); - - let config2 = match transport2 { - Transport::Tcp(config) => config2.with_tcp(config), - Transport::Quic(config) => config2.with_quic(config), - Transport::WebSocket(config) => config2.with_websocket(config), - } - .build(); - - let mut litep2p1 = Litep2p::new(config1).unwrap(); - let mut litep2p2 = Litep2p::new(config2).unwrap(); - - let peer1 = *litep2p1.local_peer_id(); - let peer2 = *litep2p2.local_peer_id(); - - let address = litep2p2.listen_addresses().next().unwrap().clone(); - litep2p1.add_known_address(peer2, std::iter::once(address)); - - // add `peer2` known address for `peer1` and spawn the litep2p objects in the background - tokio::spawn(async move { - loop { - tokio::select! { - _ = litep2p1.next_event() => {}, - _ = litep2p2.next_event() => {}, - } - } - }); - - // open substream for `peer2` and accept it - handle1.open_substream(peer2).await.unwrap(); - assert_eq!( - handle2.next().await.unwrap(), - NotificationEvent::ValidateSubstream { - protocol: ProtocolName::from("/notif/1"), - fallback: None, - peer: peer1, - handshake: vec![1, 2, 3, 4], - } - ); - handle2.send_validation_result(peer1, ValidationResult::Accept); - - assert_eq!( - handle1.next().await.unwrap(), - NotificationEvent::ValidateSubstream { - protocol: ProtocolName::from("/notif/1"), - fallback: None, - peer: peer2, - handshake: vec![1, 2, 3, 4], - } - ); - handle1.send_validation_result(peer2, ValidationResult::Accept); - - assert_eq!( - handle2.next().await.unwrap(), - NotificationEvent::NotificationStreamOpened { - protocol: ProtocolName::from("/notif/1"), - direction: Direction::Inbound, - fallback: None, - peer: peer1, - handshake: vec![1, 2, 3, 4], - } - ); - assert_eq!( - handle1.next().await.unwrap(), - NotificationEvent::NotificationStreamOpened { - protocol: ProtocolName::from("/notif/1"), - direction: Direction::Outbound, - fallback: None, - peer: peer2, - handshake: vec![1, 2, 3, 4], - } - ); - - let sink1 = handle1.notification_sink(peer2).unwrap(); - let sink2 = handle2.notification_sink(peer1).unwrap(); - - sink1.send_sync_notification(vec![1, 3, 3, 7]).unwrap(); - sink2.send_sync_notification(vec![1, 3, 3, 8]).unwrap(); - - assert_eq!( - handle2.next().await.unwrap(), - NotificationEvent::NotificationReceived { - peer: peer1, - notification: BytesMut::from(&[1, 3, 3, 7][..]), - } - ); - assert_eq!( - handle1.next().await.unwrap(), - NotificationEvent::NotificationReceived { - peer: peer2, - notification: BytesMut::from(&[1, 3, 3, 8][..]), - } - ); - - // close the substream to `peer1` and try to send notification using `sink1` - handle2.close_substream(peer1).await; - - // allow `peer1` to detect that the substream has been closed - tokio::time::sleep(std::time::Duration::from_secs(2)).await; - - assert_eq!( - sink1.send_sync_notification(vec![1, 3, 3, 7]), - Err(NotificationError::NoConnection), - ); + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (notif_config1, mut handle1) = NotificationConfig::new( + ProtocolName::from("/notif/1"), + 1024usize, + vec![1, 2, 3, 4], + Vec::new(), + false, + 64, + 64, + true, + ); + let config1 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_notification_protocol(notif_config1); + + let config1 = match transport1 { + Transport::Tcp(config) => config1.with_tcp(config), + Transport::Quic(config) => config1.with_quic(config), + Transport::WebSocket(config) => config1.with_websocket(config), + } + .build(); + + let (notif_config2, mut handle2) = NotificationConfig::new( + ProtocolName::from("/notif/1"), + 1024usize, + vec![1, 2, 3, 4], + Vec::new(), + false, + 64, + 64, + true, + ); + let config2 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_notification_protocol(notif_config2); + + let config2 = match transport2 { + Transport::Tcp(config) => config2.with_tcp(config), + Transport::Quic(config) => config2.with_quic(config), + Transport::WebSocket(config) => config2.with_websocket(config), + } + .build(); + + let mut litep2p1 = Litep2p::new(config1).unwrap(); + let mut litep2p2 = Litep2p::new(config2).unwrap(); + + let peer1 = *litep2p1.local_peer_id(); + let peer2 = *litep2p2.local_peer_id(); + + let address = litep2p2.listen_addresses().next().unwrap().clone(); + litep2p1.add_known_address(peer2, std::iter::once(address)); + + // add `peer2` known address for `peer1` and spawn the litep2p objects in the background + tokio::spawn(async move { + loop { + tokio::select! { + _ = litep2p1.next_event() => {}, + _ = litep2p2.next_event() => {}, + } + } + }); + + // open substream for `peer2` and accept it + handle1.open_substream(peer2).await.unwrap(); + assert_eq!( + handle2.next().await.unwrap(), + NotificationEvent::ValidateSubstream { + protocol: ProtocolName::from("/notif/1"), + fallback: None, + peer: peer1, + handshake: vec![1, 2, 3, 4], + } + ); + handle2.send_validation_result(peer1, ValidationResult::Accept); + + assert_eq!( + handle1.next().await.unwrap(), + NotificationEvent::ValidateSubstream { + protocol: ProtocolName::from("/notif/1"), + fallback: None, + peer: peer2, + handshake: vec![1, 2, 3, 4], + } + ); + handle1.send_validation_result(peer2, ValidationResult::Accept); + + assert_eq!( + handle2.next().await.unwrap(), + NotificationEvent::NotificationStreamOpened { + protocol: ProtocolName::from("/notif/1"), + direction: Direction::Inbound, + fallback: None, + peer: peer1, + handshake: vec![1, 2, 3, 4], + } + ); + assert_eq!( + handle1.next().await.unwrap(), + NotificationEvent::NotificationStreamOpened { + protocol: ProtocolName::from("/notif/1"), + direction: Direction::Outbound, + fallback: None, + peer: peer2, + handshake: vec![1, 2, 3, 4], + } + ); + + let sink1 = handle1.notification_sink(peer2).unwrap(); + let sink2 = handle2.notification_sink(peer1).unwrap(); + + sink1.send_sync_notification(vec![1, 3, 3, 7]).unwrap(); + sink2.send_sync_notification(vec![1, 3, 3, 8]).unwrap(); + + assert_eq!( + handle2.next().await.unwrap(), + NotificationEvent::NotificationReceived { + peer: peer1, + notification: BytesMut::from(&[1, 3, 3, 7][..]), + } + ); + assert_eq!( + handle1.next().await.unwrap(), + NotificationEvent::NotificationReceived { + peer: peer2, + notification: BytesMut::from(&[1, 3, 3, 8][..]), + } + ); + + // close the substream to `peer1` and try to send notification using `sink1` + handle2.close_substream(peer1).await; + + // allow `peer1` to detect that the substream has been closed + tokio::time::sleep(std::time::Duration::from_secs(2)).await; + + assert_eq!( + sink1.send_sync_notification(vec![1, 3, 3, 7]), + Err(NotificationError::NoConnection), + ); } #[tokio::test] async fn open_and_close_batched_tcp() { - open_and_close_batched( - Transport::Tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }), - Transport::Tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }), - Transport::Tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }), - ) - .await + open_and_close_batched( + Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }), + Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }), + Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }), + ) + .await } #[tokio::test] async fn open_and_close_batched_quic() { - open_and_close_batched( - Transport::Quic(Default::default()), - Transport::Quic(Default::default()), - Transport::Quic(Default::default()), - ) - .await; + open_and_close_batched( + Transport::Quic(Default::default()), + Transport::Quic(Default::default()), + Transport::Quic(Default::default()), + ) + .await; } #[tokio::test] async fn open_and_close_batched_websocket() { - open_and_close_batched( - Transport::WebSocket(WebSocketConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], - ..Default::default() - }), - Transport::WebSocket(WebSocketConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], - ..Default::default() - }), - Transport::WebSocket(WebSocketConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], - ..Default::default() - }), - ) - .await; + open_and_close_batched( + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }), + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }), + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }), + ) + .await; } async fn open_and_close_batched( - transport1: Transport, - transport2: Transport, - transport3: Transport, + transport1: Transport, + transport2: Transport, + transport3: Transport, ) { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (mut litep2p1, mut handle1) = make_default_litep2p(transport1).await; - let (mut litep2p2, mut handle2) = make_default_litep2p(transport2).await; - let (mut litep2p3, mut handle3) = make_default_litep2p(transport3).await; - - let peer1 = *litep2p1.local_peer_id(); - let peer2 = *litep2p2.local_peer_id(); - let peer3 = *litep2p3.local_peer_id(); - - let address2 = litep2p2.listen_addresses().next().unwrap().clone(); - let address3 = litep2p3.listen_addresses().next().unwrap().clone(); - litep2p1.add_known_address(peer2, std::iter::once(address2)); - litep2p1.add_known_address(peer3, std::iter::once(address3)); - - tokio::spawn(async move { - loop { - tokio::select! { - _ = litep2p1.next_event() => {}, - _ = litep2p2.next_event() => {}, - _ = litep2p3.next_event() => {}, - } - } - }); - - // open substreams to `peer2` and `peer3` - handle1.open_substream_batch(vec![peer3, peer2].into_iter()).await.unwrap(); - - // accept for `peer2` - assert_eq!( - handle2.next().await.unwrap(), - NotificationEvent::ValidateSubstream { - protocol: ProtocolName::from("/notif/1"), - fallback: None, - peer: peer1, - handshake: vec![1, 2, 3, 4], - } - ); - handle2.send_validation_result(peer1, ValidationResult::Accept); - - // accept for `peer3` - assert_eq!( - handle3.next().await.unwrap(), - NotificationEvent::ValidateSubstream { - protocol: ProtocolName::from("/notif/1"), - fallback: None, - peer: peer1, - handshake: vec![1, 2, 3, 4], - } - ); - handle3.send_validation_result(peer1, ValidationResult::Accept); - - // accept inbound substream for `peer2` and `peer3` - let mut peer2_validated = false; - let mut peer3_validated = false; - let mut peer2_opened = false; - let mut peer3_opened = false; - - while !peer2_validated || !peer3_validated || !peer2_opened || !peer3_opened { - match handle1.next().await.unwrap() { - NotificationEvent::ValidateSubstream { protocol, fallback, peer, handshake } => { - assert_eq!(protocol, ProtocolName::from("/notif/1")); - assert_eq!(handshake, vec![1, 2, 3, 4]); - assert_eq!(fallback, None); - - if peer == peer2 && !peer2_validated { - peer2_validated = true; - } else if peer == peer3 && !peer3_validated { - peer3_validated = true; - } else { - panic!("received an event from an unexpected peer"); - } - - handle1.send_validation_result(peer, ValidationResult::Accept); - }, - NotificationEvent::NotificationStreamOpened { peer, .. } => { - if peer == peer2 && !peer2_opened { - peer2_opened = true; - } else if peer == peer3 && !peer3_opened { - peer3_opened = true; - } else { - panic!("received an event from an unexpected peer"); - } - }, - _ => panic!("invalid event"), - } - } - - // verify the substream is opened for `peer2` and `peer3` - assert_eq!( - handle2.next().await.unwrap(), - NotificationEvent::NotificationStreamOpened { - protocol: ProtocolName::from("/notif/1"), - direction: Direction::Inbound, - fallback: None, - peer: peer1, - handshake: vec![1, 2, 3, 4], - } - ); - assert_eq!( - handle3.next().await.unwrap(), - NotificationEvent::NotificationStreamOpened { - protocol: ProtocolName::from("/notif/1"), - direction: Direction::Inbound, - fallback: None, - peer: peer1, - handshake: vec![1, 2, 3, 4], - } - ); - - // close substreams to `peer2` and `peer3` - handle1.close_substream_batch(vec![peer2, peer3].into_iter()).await; - - // verify the substream is closed for `peer2` and `peer3` - assert_eq!( - handle2.next().await.unwrap(), - NotificationEvent::NotificationStreamClosed { peer: peer1 } - ); - assert_eq!( - handle3.next().await.unwrap(), - NotificationEvent::NotificationStreamClosed { peer: peer1 } - ); - - // verify `peer1` receives close events for both peers - let mut peer2_closed = false; - let mut peer3_closed = false; - - while !peer2_closed || !peer3_closed { - match handle1.next().await.unwrap() { - NotificationEvent::NotificationStreamClosed { peer } => { - if peer == peer2 && !peer2_closed { - peer2_closed = true; - } else if peer == peer3 && !peer3_closed { - peer3_closed = true; - } else { - panic!("received an event from an unexpected peer"); - } - }, - _ => panic!("invalid event"), - } - } + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (mut litep2p1, mut handle1) = make_default_litep2p(transport1).await; + let (mut litep2p2, mut handle2) = make_default_litep2p(transport2).await; + let (mut litep2p3, mut handle3) = make_default_litep2p(transport3).await; + + let peer1 = *litep2p1.local_peer_id(); + let peer2 = *litep2p2.local_peer_id(); + let peer3 = *litep2p3.local_peer_id(); + + let address2 = litep2p2.listen_addresses().next().unwrap().clone(); + let address3 = litep2p3.listen_addresses().next().unwrap().clone(); + litep2p1.add_known_address(peer2, std::iter::once(address2)); + litep2p1.add_known_address(peer3, std::iter::once(address3)); + + tokio::spawn(async move { + loop { + tokio::select! { + _ = litep2p1.next_event() => {}, + _ = litep2p2.next_event() => {}, + _ = litep2p3.next_event() => {}, + } + } + }); + + // open substreams to `peer2` and `peer3` + handle1.open_substream_batch(vec![peer3, peer2].into_iter()).await.unwrap(); + + // accept for `peer2` + assert_eq!( + handle2.next().await.unwrap(), + NotificationEvent::ValidateSubstream { + protocol: ProtocolName::from("/notif/1"), + fallback: None, + peer: peer1, + handshake: vec![1, 2, 3, 4], + } + ); + handle2.send_validation_result(peer1, ValidationResult::Accept); + + // accept for `peer3` + assert_eq!( + handle3.next().await.unwrap(), + NotificationEvent::ValidateSubstream { + protocol: ProtocolName::from("/notif/1"), + fallback: None, + peer: peer1, + handshake: vec![1, 2, 3, 4], + } + ); + handle3.send_validation_result(peer1, ValidationResult::Accept); + + // accept inbound substream for `peer2` and `peer3` + let mut peer2_validated = false; + let mut peer3_validated = false; + let mut peer2_opened = false; + let mut peer3_opened = false; + + while !peer2_validated || !peer3_validated || !peer2_opened || !peer3_opened { + match handle1.next().await.unwrap() { + NotificationEvent::ValidateSubstream { + protocol, + fallback, + peer, + handshake, + } => { + assert_eq!(protocol, ProtocolName::from("/notif/1")); + assert_eq!(handshake, vec![1, 2, 3, 4]); + assert_eq!(fallback, None); + + if peer == peer2 && !peer2_validated { + peer2_validated = true; + } else if peer == peer3 && !peer3_validated { + peer3_validated = true; + } else { + panic!("received an event from an unexpected peer"); + } + + handle1.send_validation_result(peer, ValidationResult::Accept); + } + NotificationEvent::NotificationStreamOpened { peer, .. } => { + if peer == peer2 && !peer2_opened { + peer2_opened = true; + } else if peer == peer3 && !peer3_opened { + peer3_opened = true; + } else { + panic!("received an event from an unexpected peer"); + } + } + _ => panic!("invalid event"), + } + } + + // verify the substream is opened for `peer2` and `peer3` + assert_eq!( + handle2.next().await.unwrap(), + NotificationEvent::NotificationStreamOpened { + protocol: ProtocolName::from("/notif/1"), + direction: Direction::Inbound, + fallback: None, + peer: peer1, + handshake: vec![1, 2, 3, 4], + } + ); + assert_eq!( + handle3.next().await.unwrap(), + NotificationEvent::NotificationStreamOpened { + protocol: ProtocolName::from("/notif/1"), + direction: Direction::Inbound, + fallback: None, + peer: peer1, + handshake: vec![1, 2, 3, 4], + } + ); + + // close substreams to `peer2` and `peer3` + handle1.close_substream_batch(vec![peer2, peer3].into_iter()).await; + + // verify the substream is closed for `peer2` and `peer3` + assert_eq!( + handle2.next().await.unwrap(), + NotificationEvent::NotificationStreamClosed { peer: peer1 } + ); + assert_eq!( + handle3.next().await.unwrap(), + NotificationEvent::NotificationStreamClosed { peer: peer1 } + ); + + // verify `peer1` receives close events for both peers + let mut peer2_closed = false; + let mut peer3_closed = false; + + while !peer2_closed || !peer3_closed { + match handle1.next().await.unwrap() { + NotificationEvent::NotificationStreamClosed { peer } => { + if peer == peer2 && !peer2_closed { + peer2_closed = true; + } else if peer == peer3 && !peer3_closed { + peer3_closed = true; + } else { + panic!("received an event from an unexpected peer"); + } + } + _ => panic!("invalid event"), + } + } } #[tokio::test] async fn open_and_close_batched_duplicate_peer_tcp() { - open_and_close_batched_duplicate_peer( - Transport::Tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }), - Transport::Tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }), - Transport::Tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }), - ) - .await + open_and_close_batched_duplicate_peer( + Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }), + Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }), + Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }), + ) + .await } #[tokio::test] async fn open_and_close_batched_duplicate_peer_quic() { - open_and_close_batched_duplicate_peer( - Transport::Quic(Default::default()), - Transport::Quic(Default::default()), - Transport::Quic(Default::default()), - ) - .await; + open_and_close_batched_duplicate_peer( + Transport::Quic(Default::default()), + Transport::Quic(Default::default()), + Transport::Quic(Default::default()), + ) + .await; } #[tokio::test] async fn open_and_close_batched_duplicate_peer_websocket() { - open_and_close_batched_duplicate_peer( - Transport::WebSocket(WebSocketConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], - ..Default::default() - }), - Transport::WebSocket(WebSocketConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], - ..Default::default() - }), - Transport::WebSocket(WebSocketConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], - ..Default::default() - }), - ) - .await; + open_and_close_batched_duplicate_peer( + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }), + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }), + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }), + ) + .await; } async fn open_and_close_batched_duplicate_peer( - transport1: Transport, - transport2: Transport, - transport3: Transport, + transport1: Transport, + transport2: Transport, + transport3: Transport, ) { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (mut litep2p1, mut handle1) = make_default_litep2p(transport1).await; - let (mut litep2p2, mut handle2) = make_default_litep2p(transport2).await; - let (mut litep2p3, mut handle3) = make_default_litep2p(transport3).await; - - let peer1 = *litep2p1.local_peer_id(); - let peer2 = *litep2p2.local_peer_id(); - let peer3 = *litep2p3.local_peer_id(); - - let address2 = litep2p2.listen_addresses().next().unwrap().clone(); - let address3 = litep2p3.listen_addresses().next().unwrap().clone(); - litep2p1.add_known_address(peer2, std::iter::once(address2)); - litep2p1.add_known_address(peer3, std::iter::once(address3)); - - tokio::spawn(async move { - loop { - tokio::select! { - _ = litep2p1.next_event() => {}, - _ = litep2p2.next_event() => {}, - _ = litep2p3.next_event() => {}, - } - } - }); - - // open substream to `peer2`. - handle1.open_substream_batch(vec![peer2].into_iter()).await.unwrap(); - - // accept for `peer2` - assert_eq!( - handle2.next().await.unwrap(), - NotificationEvent::ValidateSubstream { - protocol: ProtocolName::from("/notif/1"), - fallback: None, - peer: peer1, - handshake: vec![1, 2, 3, 4], - } - ); - handle2.send_validation_result(peer1, ValidationResult::Accept); - - // accept inbound substream for `peer2` - let mut peer2_validated = false; - let mut peer2_opened = false; - - while !peer2_validated || !peer2_opened { - match handle1.next().await.unwrap() { - NotificationEvent::ValidateSubstream { protocol, fallback, peer, handshake } => { - assert_eq!(protocol, ProtocolName::from("/notif/1")); - assert_eq!(handshake, vec![1, 2, 3, 4]); - assert_eq!(fallback, None); - assert_eq!(peer, peer2); - - if !peer2_validated { - peer2_validated = true; - } else { - panic!("received an event from an unexpected peer"); - } - - handle1.send_validation_result(peer, ValidationResult::Accept); - }, - NotificationEvent::NotificationStreamOpened { peer, .. } => { - assert_eq!(peer, peer2); - - if !peer2_opened { - peer2_opened = true; - } else { - panic!("received an event from an unexpected peer"); - } - }, - _ => panic!("invalid event"), - } - } - - // verify the substream is opened for `peer2` - assert_eq!( - handle2.next().await.unwrap(), - NotificationEvent::NotificationStreamOpened { - protocol: ProtocolName::from("/notif/1"), - direction: Direction::Inbound, - fallback: None, - peer: peer1, - handshake: vec![1, 2, 3, 4], - } - ); - - // batch another substream open command but this time include `peer2` for which - // a connection is already open - match handle1.open_substream_batch(vec![peer2, peer3].into_iter()).await { - Err(ignored) => { - assert_eq!(ignored.len(), 1); - assert!(ignored.contains(&peer2)); - }, - _ => panic!("call was supposed to fail"), - } - - // accept for `peer3` - assert_eq!( - handle3.next().await.unwrap(), - NotificationEvent::ValidateSubstream { - protocol: ProtocolName::from("/notif/1"), - fallback: None, - peer: peer1, - handshake: vec![1, 2, 3, 4], - } - ); - handle3.send_validation_result(peer1, ValidationResult::Accept); - - // accept inbound substream for `peer3` - let mut peer3_validated = false; - let mut peer3_opened = false; - - while !peer3_validated || !peer3_opened { - match handle1.next().await.unwrap() { - NotificationEvent::ValidateSubstream { protocol, fallback, peer, handshake } => { - assert_eq!(protocol, ProtocolName::from("/notif/1")); - assert_eq!(handshake, vec![1, 2, 3, 4]); - assert_eq!(fallback, None); - assert_eq!(peer, peer3); - - if !peer3_validated { - peer3_validated = true; - } else { - panic!("received an event from an unexpected peer"); - } - - handle1.send_validation_result(peer, ValidationResult::Accept); - }, - NotificationEvent::NotificationStreamOpened { peer, .. } => { - assert_eq!(peer, peer3); - - if !peer3_opened { - peer3_opened = true; - } else { - panic!("received an event from an unexpected peer"); - } - }, - _ => panic!("invalid event"), - } - } - - // verify the substream is opened for `peer2` and `peer3` - assert_eq!( - handle3.next().await.unwrap(), - NotificationEvent::NotificationStreamOpened { - protocol: ProtocolName::from("/notif/1"), - direction: Direction::Inbound, - fallback: None, - peer: peer1, - handshake: vec![1, 2, 3, 4], - } - ); - - // close substreams to `peer2` and `peer3` - handle1.close_substream_batch(vec![peer2, peer3].into_iter()).await; - - // verify the substream is closed for `peer2` and `peer3` - assert_eq!( - handle2.next().await.unwrap(), - NotificationEvent::NotificationStreamClosed { peer: peer1 } - ); - assert_eq!( - handle3.next().await.unwrap(), - NotificationEvent::NotificationStreamClosed { peer: peer1 } - ); - - // verify `peer1` receives close events for both peers - let mut peer2_closed = false; - let mut peer3_closed = false; - - while !peer2_closed || !peer3_closed { - match handle1.next().await.unwrap() { - NotificationEvent::NotificationStreamClosed { peer } => { - if peer == peer2 && !peer2_closed { - peer2_closed = true; - } else if peer == peer3 && !peer3_closed { - peer3_closed = true; - } else { - panic!("received an event from an unexpected peer"); - } - }, - _ => panic!("invalid event"), - } - } + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (mut litep2p1, mut handle1) = make_default_litep2p(transport1).await; + let (mut litep2p2, mut handle2) = make_default_litep2p(transport2).await; + let (mut litep2p3, mut handle3) = make_default_litep2p(transport3).await; + + let peer1 = *litep2p1.local_peer_id(); + let peer2 = *litep2p2.local_peer_id(); + let peer3 = *litep2p3.local_peer_id(); + + let address2 = litep2p2.listen_addresses().next().unwrap().clone(); + let address3 = litep2p3.listen_addresses().next().unwrap().clone(); + litep2p1.add_known_address(peer2, std::iter::once(address2)); + litep2p1.add_known_address(peer3, std::iter::once(address3)); + + tokio::spawn(async move { + loop { + tokio::select! { + _ = litep2p1.next_event() => {}, + _ = litep2p2.next_event() => {}, + _ = litep2p3.next_event() => {}, + } + } + }); + + // open substream to `peer2`. + handle1.open_substream_batch(vec![peer2].into_iter()).await.unwrap(); + + // accept for `peer2` + assert_eq!( + handle2.next().await.unwrap(), + NotificationEvent::ValidateSubstream { + protocol: ProtocolName::from("/notif/1"), + fallback: None, + peer: peer1, + handshake: vec![1, 2, 3, 4], + } + ); + handle2.send_validation_result(peer1, ValidationResult::Accept); + + // accept inbound substream for `peer2` + let mut peer2_validated = false; + let mut peer2_opened = false; + + while !peer2_validated || !peer2_opened { + match handle1.next().await.unwrap() { + NotificationEvent::ValidateSubstream { + protocol, + fallback, + peer, + handshake, + } => { + assert_eq!(protocol, ProtocolName::from("/notif/1")); + assert_eq!(handshake, vec![1, 2, 3, 4]); + assert_eq!(fallback, None); + assert_eq!(peer, peer2); + + if !peer2_validated { + peer2_validated = true; + } else { + panic!("received an event from an unexpected peer"); + } + + handle1.send_validation_result(peer, ValidationResult::Accept); + } + NotificationEvent::NotificationStreamOpened { peer, .. } => { + assert_eq!(peer, peer2); + + if !peer2_opened { + peer2_opened = true; + } else { + panic!("received an event from an unexpected peer"); + } + } + _ => panic!("invalid event"), + } + } + + // verify the substream is opened for `peer2` + assert_eq!( + handle2.next().await.unwrap(), + NotificationEvent::NotificationStreamOpened { + protocol: ProtocolName::from("/notif/1"), + direction: Direction::Inbound, + fallback: None, + peer: peer1, + handshake: vec![1, 2, 3, 4], + } + ); + + // batch another substream open command but this time include `peer2` for which + // a connection is already open + match handle1.open_substream_batch(vec![peer2, peer3].into_iter()).await { + Err(ignored) => { + assert_eq!(ignored.len(), 1); + assert!(ignored.contains(&peer2)); + } + _ => panic!("call was supposed to fail"), + } + + // accept for `peer3` + assert_eq!( + handle3.next().await.unwrap(), + NotificationEvent::ValidateSubstream { + protocol: ProtocolName::from("/notif/1"), + fallback: None, + peer: peer1, + handshake: vec![1, 2, 3, 4], + } + ); + handle3.send_validation_result(peer1, ValidationResult::Accept); + + // accept inbound substream for `peer3` + let mut peer3_validated = false; + let mut peer3_opened = false; + + while !peer3_validated || !peer3_opened { + match handle1.next().await.unwrap() { + NotificationEvent::ValidateSubstream { + protocol, + fallback, + peer, + handshake, + } => { + assert_eq!(protocol, ProtocolName::from("/notif/1")); + assert_eq!(handshake, vec![1, 2, 3, 4]); + assert_eq!(fallback, None); + assert_eq!(peer, peer3); + + if !peer3_validated { + peer3_validated = true; + } else { + panic!("received an event from an unexpected peer"); + } + + handle1.send_validation_result(peer, ValidationResult::Accept); + } + NotificationEvent::NotificationStreamOpened { peer, .. } => { + assert_eq!(peer, peer3); + + if !peer3_opened { + peer3_opened = true; + } else { + panic!("received an event from an unexpected peer"); + } + } + _ => panic!("invalid event"), + } + } + + // verify the substream is opened for `peer2` and `peer3` + assert_eq!( + handle3.next().await.unwrap(), + NotificationEvent::NotificationStreamOpened { + protocol: ProtocolName::from("/notif/1"), + direction: Direction::Inbound, + fallback: None, + peer: peer1, + handshake: vec![1, 2, 3, 4], + } + ); + + // close substreams to `peer2` and `peer3` + handle1.close_substream_batch(vec![peer2, peer3].into_iter()).await; + + // verify the substream is closed for `peer2` and `peer3` + assert_eq!( + handle2.next().await.unwrap(), + NotificationEvent::NotificationStreamClosed { peer: peer1 } + ); + assert_eq!( + handle3.next().await.unwrap(), + NotificationEvent::NotificationStreamClosed { peer: peer1 } + ); + + // verify `peer1` receives close events for both peers + let mut peer2_closed = false; + let mut peer3_closed = false; + + while !peer2_closed || !peer3_closed { + match handle1.next().await.unwrap() { + NotificationEvent::NotificationStreamClosed { peer } => { + if peer == peer2 && !peer2_closed { + peer2_closed = true; + } else if peer == peer3 && !peer3_closed { + peer3_closed = true; + } else { + panic!("received an event from an unexpected peer"); + } + } + _ => panic!("invalid event"), + } + } } #[tokio::test] async fn no_listener_address_for_one_peer_tcp() { - no_listener_address_for_one_peer( - Transport::Tcp(TcpConfig { listen_addresses: vec![], ..Default::default() }), - Transport::Tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }), - ) - .await + no_listener_address_for_one_peer( + Transport::Tcp(TcpConfig { + listen_addresses: vec![], + ..Default::default() + }), + Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }), + ) + .await } #[tokio::test] async fn no_listener_address_for_one_peer_quic() { - no_listener_address_for_one_peer( - Transport::Quic(Default::default()), - Transport::Quic(Default::default()), - ) - .await; + no_listener_address_for_one_peer( + Transport::Quic(Default::default()), + Transport::Quic(Default::default()), + ) + .await; } #[tokio::test] async fn no_listener_address_for_one_peer_websocket() { - no_listener_address_for_one_peer( - Transport::WebSocket(WebSocketConfig { listen_addresses: vec![], ..Default::default() }), - Transport::WebSocket(WebSocketConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], - ..Default::default() - }), - ) - .await; + no_listener_address_for_one_peer( + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec![], + ..Default::default() + }), + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }), + ) + .await; } async fn no_listener_address_for_one_peer(transport1: Transport, transport2: Transport) { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (mut litep2p1, mut handle1) = make_default_litep2p(transport1).await; - let (mut litep2p2, mut handle2) = make_default_litep2p(transport2).await; - - let peer1 = *litep2p1.local_peer_id(); - let peer2 = *litep2p2.local_peer_id(); - - let address2 = litep2p2.listen_addresses().next().unwrap().clone(); - litep2p1.add_known_address(peer2, std::iter::once(address2)); - - tokio::spawn(async move { - loop { - tokio::select! { - _ = litep2p1.next_event() => {}, - _ = litep2p2.next_event() => {}, - } - } - }); - - handle1.open_substream(peer2).await.unwrap(); - - assert_eq!( - handle2.next().await.unwrap(), - NotificationEvent::ValidateSubstream { - protocol: ProtocolName::from("/notif/1"), - fallback: None, - peer: peer1, - handshake: vec![1, 2, 3, 4], - } - ); - handle2.send_validation_result(peer1, ValidationResult::Accept); - - assert_eq!( - handle1.next().await.unwrap(), - NotificationEvent::ValidateSubstream { - protocol: ProtocolName::from("/notif/1"), - fallback: None, - peer: peer2, - handshake: vec![1, 2, 3, 4], - } - ); - handle1.send_validation_result(peer2, ValidationResult::Accept); - - assert_eq!( - handle2.next().await.unwrap(), - NotificationEvent::NotificationStreamOpened { - protocol: ProtocolName::from("/notif/1"), - direction: Direction::Inbound, - fallback: None, - peer: peer1, - handshake: vec![1, 2, 3, 4], - } - ); - assert_eq!( - handle1.next().await.unwrap(), - NotificationEvent::NotificationStreamOpened { - protocol: ProtocolName::from("/notif/1"), - fallback: None, - direction: Direction::Outbound, - peer: peer2, - handshake: vec![1, 2, 3, 4], - } - ); - - handle1.send_sync_notification(peer2, vec![1, 3, 3, 7]).unwrap(); - handle2.send_sync_notification(peer1, vec![1, 3, 3, 8]).unwrap(); - - assert_eq!( - handle2.next().await.unwrap(), - NotificationEvent::NotificationReceived { - peer: peer1, - notification: BytesMut::from(&[1, 3, 3, 7][..]), - } - ); - assert_eq!( - handle1.next().await.unwrap(), - NotificationEvent::NotificationReceived { - peer: peer2, - notification: BytesMut::from(&[1, 3, 3, 8][..]), - } - ); + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (mut litep2p1, mut handle1) = make_default_litep2p(transport1).await; + let (mut litep2p2, mut handle2) = make_default_litep2p(transport2).await; + + let peer1 = *litep2p1.local_peer_id(); + let peer2 = *litep2p2.local_peer_id(); + + let address2 = litep2p2.listen_addresses().next().unwrap().clone(); + litep2p1.add_known_address(peer2, std::iter::once(address2)); + + tokio::spawn(async move { + loop { + tokio::select! { + _ = litep2p1.next_event() => {}, + _ = litep2p2.next_event() => {}, + } + } + }); + + handle1.open_substream(peer2).await.unwrap(); + + assert_eq!( + handle2.next().await.unwrap(), + NotificationEvent::ValidateSubstream { + protocol: ProtocolName::from("/notif/1"), + fallback: None, + peer: peer1, + handshake: vec![1, 2, 3, 4], + } + ); + handle2.send_validation_result(peer1, ValidationResult::Accept); + + assert_eq!( + handle1.next().await.unwrap(), + NotificationEvent::ValidateSubstream { + protocol: ProtocolName::from("/notif/1"), + fallback: None, + peer: peer2, + handshake: vec![1, 2, 3, 4], + } + ); + handle1.send_validation_result(peer2, ValidationResult::Accept); + + assert_eq!( + handle2.next().await.unwrap(), + NotificationEvent::NotificationStreamOpened { + protocol: ProtocolName::from("/notif/1"), + direction: Direction::Inbound, + fallback: None, + peer: peer1, + handshake: vec![1, 2, 3, 4], + } + ); + assert_eq!( + handle1.next().await.unwrap(), + NotificationEvent::NotificationStreamOpened { + protocol: ProtocolName::from("/notif/1"), + fallback: None, + direction: Direction::Outbound, + peer: peer2, + handshake: vec![1, 2, 3, 4], + } + ); + + handle1.send_sync_notification(peer2, vec![1, 3, 3, 7]).unwrap(); + handle2.send_sync_notification(peer1, vec![1, 3, 3, 8]).unwrap(); + + assert_eq!( + handle2.next().await.unwrap(), + NotificationEvent::NotificationReceived { + peer: peer1, + notification: BytesMut::from(&[1, 3, 3, 7][..]), + } + ); + assert_eq!( + handle1.next().await.unwrap(), + NotificationEvent::NotificationReceived { + peer: peer2, + notification: BytesMut::from(&[1, 3, 3, 8][..]), + } + ); } #[tokio::test] async fn auto_accept_inbound_tcp() { - auto_accept_inbound(Transport::Tcp(Default::default()), Transport::Tcp(Default::default())) - .await + auto_accept_inbound( + Transport::Tcp(Default::default()), + Transport::Tcp(Default::default()), + ) + .await } #[tokio::test] async fn auto_accept_inbound_quic() { - auto_accept_inbound(Transport::Quic(Default::default()), Transport::Quic(Default::default())) - .await; + auto_accept_inbound( + Transport::Quic(Default::default()), + Transport::Quic(Default::default()), + ) + .await; } #[tokio::test] async fn auto_accept_inbound_websocket() { - auto_accept_inbound( - Transport::WebSocket(Default::default()), - Transport::WebSocket(Default::default()), - ) - .await; + auto_accept_inbound( + Transport::WebSocket(Default::default()), + Transport::WebSocket(Default::default()), + ) + .await; } async fn auto_accept_inbound(transport1: Transport, transport2: Transport) { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (notif_config1, mut handle1) = ConfigBuilder::new(ProtocolName::from("/notif/1")) - .with_max_size(1024usize) - .with_handshake(vec![1, 2, 3, 4]) - .with_auto_accept_inbound(true) - .with_sync_channel_size(1024usize) - .with_async_channel_size(1024usize) - .build(); - - let config1 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_notification_protocol(notif_config1); - - let config1 = match transport1 { - Transport::Tcp(config) => config1.with_tcp(config), - Transport::Quic(config) => config1.with_quic(config), - Transport::WebSocket(config) => config1.with_websocket(config), - } - .build(); - - let (mut notif_config2, mut handle2) = ConfigBuilder::new(ProtocolName::from("/notif/1")) - .with_max_size(1024usize) - .with_handshake(vec![1, 2, 3, 4]) - .with_auto_accept_inbound(true) - .with_sync_channel_size(1024usize) - .with_async_channel_size(1024usize) - .build(); - - // set new handshake for the config - notif_config2.set_handshake(vec![1, 3, 3, 7]); - - let config2 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_notification_protocol(notif_config2); - - let config2 = match transport2 { - Transport::Tcp(config) => config2.with_tcp(config), - Transport::Quic(config) => config2.with_quic(config), - Transport::WebSocket(config) => config2.with_websocket(config), - } - .build(); - - let mut litep2p1 = Litep2p::new(config1).unwrap(); - let mut litep2p2 = Litep2p::new(config2).unwrap(); - - let peer1 = *litep2p1.local_peer_id(); - let peer2 = *litep2p2.local_peer_id(); - - // wait until peers have connected and spawn the litep2p objects in the background - connect_peers(&mut litep2p1, &mut litep2p2).await; - tokio::spawn(async move { - loop { - tokio::select! { - _ = litep2p1.next_event() => {}, - _ = litep2p2.next_event() => {}, - } - } - }); - - // open substream for `peer2` and accept it - handle1.open_substream(peer2).await.unwrap(); - assert_eq!( - handle2.next().await.unwrap(), - NotificationEvent::ValidateSubstream { - protocol: ProtocolName::from("/notif/1"), - fallback: None, - peer: peer1, - handshake: vec![1, 2, 3, 4], - } - ); - handle2.send_validation_result(peer1, ValidationResult::Accept); - - assert_eq!( - handle2.next().await.unwrap(), - NotificationEvent::NotificationStreamOpened { - protocol: ProtocolName::from("/notif/1"), - direction: Direction::Inbound, - fallback: None, - peer: peer1, - handshake: vec![1, 2, 3, 4], - } - ); - assert_eq!( - handle1.next().await.unwrap(), - NotificationEvent::NotificationStreamOpened { - protocol: ProtocolName::from("/notif/1"), - fallback: None, - direction: Direction::Outbound, - peer: peer2, - handshake: vec![1, 3, 3, 7], - } - ); - - handle1.send_sync_notification(peer2, vec![1, 3, 3, 7]).unwrap(); - handle2.send_sync_notification(peer1, vec![1, 3, 3, 8]).unwrap(); - - assert_eq!( - handle2.next().await.unwrap(), - NotificationEvent::NotificationReceived { - peer: peer1, - notification: BytesMut::from(&[1, 3, 3, 7][..]), - } - ); - assert_eq!( - handle1.next().await.unwrap(), - NotificationEvent::NotificationReceived { - peer: peer2, - notification: BytesMut::from(&[1, 3, 3, 8][..]), - } - ); + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (notif_config1, mut handle1) = ConfigBuilder::new(ProtocolName::from("/notif/1")) + .with_max_size(1024usize) + .with_handshake(vec![1, 2, 3, 4]) + .with_auto_accept_inbound(true) + .with_sync_channel_size(1024usize) + .with_async_channel_size(1024usize) + .build(); + + let config1 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_notification_protocol(notif_config1); + + let config1 = match transport1 { + Transport::Tcp(config) => config1.with_tcp(config), + Transport::Quic(config) => config1.with_quic(config), + Transport::WebSocket(config) => config1.with_websocket(config), + } + .build(); + + let (mut notif_config2, mut handle2) = ConfigBuilder::new(ProtocolName::from("/notif/1")) + .with_max_size(1024usize) + .with_handshake(vec![1, 2, 3, 4]) + .with_auto_accept_inbound(true) + .with_sync_channel_size(1024usize) + .with_async_channel_size(1024usize) + .build(); + + // set new handshake for the config + notif_config2.set_handshake(vec![1, 3, 3, 7]); + + let config2 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_notification_protocol(notif_config2); + + let config2 = match transport2 { + Transport::Tcp(config) => config2.with_tcp(config), + Transport::Quic(config) => config2.with_quic(config), + Transport::WebSocket(config) => config2.with_websocket(config), + } + .build(); + + let mut litep2p1 = Litep2p::new(config1).unwrap(); + let mut litep2p2 = Litep2p::new(config2).unwrap(); + + let peer1 = *litep2p1.local_peer_id(); + let peer2 = *litep2p2.local_peer_id(); + + // wait until peers have connected and spawn the litep2p objects in the background + connect_peers(&mut litep2p1, &mut litep2p2).await; + tokio::spawn(async move { + loop { + tokio::select! { + _ = litep2p1.next_event() => {}, + _ = litep2p2.next_event() => {}, + } + } + }); + + // open substream for `peer2` and accept it + handle1.open_substream(peer2).await.unwrap(); + assert_eq!( + handle2.next().await.unwrap(), + NotificationEvent::ValidateSubstream { + protocol: ProtocolName::from("/notif/1"), + fallback: None, + peer: peer1, + handshake: vec![1, 2, 3, 4], + } + ); + handle2.send_validation_result(peer1, ValidationResult::Accept); + + assert_eq!( + handle2.next().await.unwrap(), + NotificationEvent::NotificationStreamOpened { + protocol: ProtocolName::from("/notif/1"), + direction: Direction::Inbound, + fallback: None, + peer: peer1, + handshake: vec![1, 2, 3, 4], + } + ); + assert_eq!( + handle1.next().await.unwrap(), + NotificationEvent::NotificationStreamOpened { + protocol: ProtocolName::from("/notif/1"), + fallback: None, + direction: Direction::Outbound, + peer: peer2, + handshake: vec![1, 3, 3, 7], + } + ); + + handle1.send_sync_notification(peer2, vec![1, 3, 3, 7]).unwrap(); + handle2.send_sync_notification(peer1, vec![1, 3, 3, 8]).unwrap(); + + assert_eq!( + handle2.next().await.unwrap(), + NotificationEvent::NotificationReceived { + peer: peer1, + notification: BytesMut::from(&[1, 3, 3, 7][..]), + } + ); + assert_eq!( + handle1.next().await.unwrap(), + NotificationEvent::NotificationReceived { + peer: peer2, + notification: BytesMut::from(&[1, 3, 3, 8][..]), + } + ); } #[tokio::test] async fn dial_failure_tcp() { - dial_failure(Transport::Tcp(Default::default()), Transport::Tcp(Default::default())).await + dial_failure( + Transport::Tcp(Default::default()), + Transport::Tcp(Default::default()), + ) + .await } #[tokio::test] async fn dial_failure_quic() { - dial_failure(Transport::Quic(Default::default()), Transport::Quic(Default::default())).await; + dial_failure( + Transport::Quic(Default::default()), + Transport::Quic(Default::default()), + ) + .await; } #[tokio::test] async fn dial_failure_websocket() { - dial_failure( - Transport::WebSocket(Default::default()), - Transport::WebSocket(Default::default()), - ) - .await; + dial_failure( + Transport::WebSocket(Default::default()), + Transport::WebSocket(Default::default()), + ) + .await; } async fn dial_failure(transport1: Transport, transport2: Transport) { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (notif_config1, mut handle1) = ConfigBuilder::new(ProtocolName::from("/notif/1")) - .with_max_size(1024usize) - .with_handshake(vec![1, 2, 3, 4]) - .with_auto_accept_inbound(true) - .with_sync_channel_size(1024usize) - .with_async_channel_size(1024usize) - .build(); - let (notif_config2, mut handle2) = ConfigBuilder::new(ProtocolName::from("/notif/2")) - .with_max_size(1024usize) - .with_handshake(vec![7, 7, 7, 7]) - .build(); - - let config1 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_notification_protocol(notif_config1) - .with_notification_protocol(notif_config2); - - let config1 = match transport1 { - Transport::Tcp(config) => config1.with_tcp(config), - Transport::Quic(config) => config1.with_quic(config), - Transport::WebSocket(config) => config1.with_websocket(config), - } - .build(); - - let (notif_config3, _handle3) = ConfigBuilder::new(ProtocolName::from("/notif/1")) - .with_max_size(1024usize) - .with_handshake(vec![1, 2, 3, 4]) - .with_auto_accept_inbound(true) - .with_sync_channel_size(1024usize) - .with_async_channel_size(1024usize) - .build(); - - let config2 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_notification_protocol(notif_config3); - - let known_address = match &transport2 { - Transport::Tcp(_) => Multiaddr::empty() - .with(Protocol::Ip6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1))) - .with(Protocol::Tcp(5)), - Transport::Quic(_) => Multiaddr::empty() - .with(Protocol::Ip4(Ipv4Addr::new(127, 0, 0, 1))) - .with(Protocol::Udp(5)) - .with(Protocol::QuicV1), - Transport::WebSocket(_) => Multiaddr::empty() - .with(Protocol::Ip6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1))) - .with(Protocol::Tcp(5)) - .with(Protocol::Ws(std::borrow::Cow::Owned("/".to_string()))), - }; - - let config2 = match transport2 { - Transport::Tcp(config) => config2.with_tcp(config), - Transport::Quic(config) => config2.with_quic(config), - Transport::WebSocket(config) => config2.with_websocket(config), - } - .build(); - - let mut litep2p1 = Litep2p::new(config1).unwrap(); - let mut litep2p2 = Litep2p::new(config2).unwrap(); - - let peer2 = *litep2p2.local_peer_id(); - let known_address = known_address.with(Protocol::P2p(Multihash::from(peer2))); - - litep2p1.add_known_address(peer2, vec![known_address].into_iter()); - - tokio::spawn(async move { - loop { - tokio::select! { - _ = litep2p1.next_event() => {}, - _ = litep2p2.next_event() => {}, - } - } - }); - - // open substream for `peer2` and accept it - handle1.open_substream(peer2).await.unwrap(); - assert_eq!( - handle1.next().await.unwrap(), - NotificationEvent::NotificationStreamOpenFailure { - peer: peer2, - error: NotificationError::DialFailure, - } - ); - - futures::future::poll_fn(|cx| match handle2.poll_next_unpin(cx) { - Poll::Pending => Poll::Ready(()), - _ => panic!("invalid event"), - }) - .await; + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (notif_config1, mut handle1) = ConfigBuilder::new(ProtocolName::from("/notif/1")) + .with_max_size(1024usize) + .with_handshake(vec![1, 2, 3, 4]) + .with_auto_accept_inbound(true) + .with_sync_channel_size(1024usize) + .with_async_channel_size(1024usize) + .build(); + let (notif_config2, mut handle2) = ConfigBuilder::new(ProtocolName::from("/notif/2")) + .with_max_size(1024usize) + .with_handshake(vec![7, 7, 7, 7]) + .build(); + + let config1 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_notification_protocol(notif_config1) + .with_notification_protocol(notif_config2); + + let config1 = match transport1 { + Transport::Tcp(config) => config1.with_tcp(config), + Transport::Quic(config) => config1.with_quic(config), + Transport::WebSocket(config) => config1.with_websocket(config), + } + .build(); + + let (notif_config3, _handle3) = ConfigBuilder::new(ProtocolName::from("/notif/1")) + .with_max_size(1024usize) + .with_handshake(vec![1, 2, 3, 4]) + .with_auto_accept_inbound(true) + .with_sync_channel_size(1024usize) + .with_async_channel_size(1024usize) + .build(); + + let config2 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_notification_protocol(notif_config3); + + let known_address = match &transport2 { + Transport::Tcp(_) => Multiaddr::empty() + .with(Protocol::Ip6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1))) + .with(Protocol::Tcp(5)), + Transport::Quic(_) => Multiaddr::empty() + .with(Protocol::Ip4(Ipv4Addr::new(127, 0, 0, 1))) + .with(Protocol::Udp(5)) + .with(Protocol::QuicV1), + Transport::WebSocket(_) => Multiaddr::empty() + .with(Protocol::Ip6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1))) + .with(Protocol::Tcp(5)) + .with(Protocol::Ws(std::borrow::Cow::Owned("/".to_string()))), + }; + + let config2 = match transport2 { + Transport::Tcp(config) => config2.with_tcp(config), + Transport::Quic(config) => config2.with_quic(config), + Transport::WebSocket(config) => config2.with_websocket(config), + } + .build(); + + let mut litep2p1 = Litep2p::new(config1).unwrap(); + let mut litep2p2 = Litep2p::new(config2).unwrap(); + + let peer2 = *litep2p2.local_peer_id(); + let known_address = known_address.with(Protocol::P2p(Multihash::from(peer2))); + + litep2p1.add_known_address(peer2, vec![known_address].into_iter()); + + tokio::spawn(async move { + loop { + tokio::select! { + _ = litep2p1.next_event() => {}, + _ = litep2p2.next_event() => {}, + } + } + }); + + // open substream for `peer2` and accept it + handle1.open_substream(peer2).await.unwrap(); + assert_eq!( + handle1.next().await.unwrap(), + NotificationEvent::NotificationStreamOpenFailure { + peer: peer2, + error: NotificationError::DialFailure, + } + ); + + futures::future::poll_fn(|cx| match handle2.poll_next_unpin(cx) { + Poll::Pending => Poll::Ready(()), + _ => panic!("invalid event"), + }) + .await; } #[tokio::test] async fn dialing_disabled_tcp() { - dialing_disabled(Transport::Tcp(Default::default()), Transport::Tcp(Default::default())).await + dialing_disabled( + Transport::Tcp(Default::default()), + Transport::Tcp(Default::default()), + ) + .await } #[tokio::test] async fn dialing_disabled_quic() { - dialing_disabled(Transport::Quic(Default::default()), Transport::Quic(Default::default())) - .await; + dialing_disabled( + Transport::Quic(Default::default()), + Transport::Quic(Default::default()), + ) + .await; } #[tokio::test] async fn dialing_disabled_websocket() { - dialing_disabled( - Transport::WebSocket(Default::default()), - Transport::WebSocket(Default::default()), - ) - .await; + dialing_disabled( + Transport::WebSocket(Default::default()), + Transport::WebSocket(Default::default()), + ) + .await; } async fn dialing_disabled(transport1: Transport, transport2: Transport) { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (notif_config1, mut handle1) = ConfigBuilder::new(ProtocolName::from("/notif/1")) - .with_max_size(1024usize) - .with_handshake(vec![1, 2, 3, 4]) - .with_auto_accept_inbound(true) - .with_sync_channel_size(1024usize) - .with_async_channel_size(1024usize) - .with_dialing_enabled(false) - .build(); - let (notif_config2, mut handle2) = ConfigBuilder::new(ProtocolName::from("/notif/2")) - .with_max_size(1024usize) - .with_handshake(vec![7, 7, 7, 7]) - .with_dialing_enabled(false) - .build(); - - let config1 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_notification_protocol(notif_config1) - .with_notification_protocol(notif_config2); - - let config1 = match transport1 { - Transport::Tcp(config) => config1.with_tcp(config), - Transport::Quic(config) => config1.with_quic(config), - Transport::WebSocket(config) => config1.with_websocket(config), - } - .build(); - - let (notif_config3, _handle3) = ConfigBuilder::new(ProtocolName::from("/notif/1")) - .with_max_size(1024usize) - .with_handshake(vec![1, 2, 3, 4]) - .with_auto_accept_inbound(true) - .with_sync_channel_size(1024usize) - .with_async_channel_size(1024usize) - .build(); - - let config2 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_notification_protocol(notif_config3); - - let config2 = match transport2 { - Transport::Tcp(config) => config2.with_tcp(config), - Transport::Quic(config) => config2.with_quic(config), - Transport::WebSocket(config) => config2.with_websocket(config), - } - .build(); - - let mut litep2p1 = Litep2p::new(config1).unwrap(); - let mut litep2p2 = Litep2p::new(config2).unwrap(); - - let peer2 = *litep2p2.local_peer_id(); - let listen_address = litep2p2.listen_addresses().next().unwrap().clone(); - - litep2p1.add_known_address(peer2, vec![listen_address].into_iter()); - - tokio::spawn(async move { - loop { - tokio::select! { - _ = litep2p1.next_event() => {}, - _ = litep2p2.next_event() => {}, - } - } - }); - - // open substream for `peer2` and accept it - handle1.open_substream(peer2).await.unwrap(); - assert_eq!( - handle1.next().await.unwrap(), - NotificationEvent::NotificationStreamOpenFailure { - peer: peer2, - error: NotificationError::DialFailure, - } - ); - - futures::future::poll_fn(|cx| match handle2.poll_next_unpin(cx) { - Poll::Pending => Poll::Ready(()), - _ => panic!("invalid event"), - }) - .await; + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (notif_config1, mut handle1) = ConfigBuilder::new(ProtocolName::from("/notif/1")) + .with_max_size(1024usize) + .with_handshake(vec![1, 2, 3, 4]) + .with_auto_accept_inbound(true) + .with_sync_channel_size(1024usize) + .with_async_channel_size(1024usize) + .with_dialing_enabled(false) + .build(); + let (notif_config2, mut handle2) = ConfigBuilder::new(ProtocolName::from("/notif/2")) + .with_max_size(1024usize) + .with_handshake(vec![7, 7, 7, 7]) + .with_dialing_enabled(false) + .build(); + + let config1 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_notification_protocol(notif_config1) + .with_notification_protocol(notif_config2); + + let config1 = match transport1 { + Transport::Tcp(config) => config1.with_tcp(config), + Transport::Quic(config) => config1.with_quic(config), + Transport::WebSocket(config) => config1.with_websocket(config), + } + .build(); + + let (notif_config3, _handle3) = ConfigBuilder::new(ProtocolName::from("/notif/1")) + .with_max_size(1024usize) + .with_handshake(vec![1, 2, 3, 4]) + .with_auto_accept_inbound(true) + .with_sync_channel_size(1024usize) + .with_async_channel_size(1024usize) + .build(); + + let config2 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_notification_protocol(notif_config3); + + let config2 = match transport2 { + Transport::Tcp(config) => config2.with_tcp(config), + Transport::Quic(config) => config2.with_quic(config), + Transport::WebSocket(config) => config2.with_websocket(config), + } + .build(); + + let mut litep2p1 = Litep2p::new(config1).unwrap(); + let mut litep2p2 = Litep2p::new(config2).unwrap(); + + let peer2 = *litep2p2.local_peer_id(); + let listen_address = litep2p2.listen_addresses().next().unwrap().clone(); + + litep2p1.add_known_address(peer2, vec![listen_address].into_iter()); + + tokio::spawn(async move { + loop { + tokio::select! { + _ = litep2p1.next_event() => {}, + _ = litep2p2.next_event() => {}, + } + } + }); + + // open substream for `peer2` and accept it + handle1.open_substream(peer2).await.unwrap(); + assert_eq!( + handle1.next().await.unwrap(), + NotificationEvent::NotificationStreamOpenFailure { + peer: peer2, + error: NotificationError::DialFailure, + } + ); + + futures::future::poll_fn(|cx| match handle2.poll_next_unpin(cx) { + Poll::Pending => Poll::Ready(()), + _ => panic!("invalid event"), + }) + .await; } #[tokio::test] async fn validation_takes_too_long_tcp() { - validation_takes_too_long( - Transport::Tcp(Default::default()), - Transport::Tcp(Default::default()), - ) - .await + validation_takes_too_long( + Transport::Tcp(Default::default()), + Transport::Tcp(Default::default()), + ) + .await } #[tokio::test] async fn validation_takes_too_long_quic() { - validation_takes_too_long( - Transport::Quic(Default::default()), - Transport::Quic(Default::default()), - ) - .await; + validation_takes_too_long( + Transport::Quic(Default::default()), + Transport::Quic(Default::default()), + ) + .await; } #[tokio::test] async fn validation_takes_too_long_websocket() { - validation_takes_too_long( - Transport::WebSocket(Default::default()), - Transport::WebSocket(Default::default()), - ) - .await; + validation_takes_too_long( + Transport::WebSocket(Default::default()), + Transport::WebSocket(Default::default()), + ) + .await; } async fn validation_takes_too_long(transport1: Transport, transport2: Transport) { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (notif_config1, mut handle1) = ConfigBuilder::new(ProtocolName::from("/notif/1")) - .with_max_size(1024usize) - .with_handshake(vec![1, 2, 3, 4]) - .build(); - - let config1 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_notification_protocol(notif_config1); - - let config1 = match transport1 { - Transport::Tcp(config) => config1.with_tcp(config), - Transport::Quic(config) => config1.with_quic(config), - Transport::WebSocket(config) => config1.with_websocket(config), - } - .build(); - - let (notif_config3, mut handle2) = ConfigBuilder::new(ProtocolName::from("/notif/1")) - .with_max_size(1024usize) - .with_handshake(vec![1, 2, 3, 4]) - .build(); - - let config2 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_notification_protocol(notif_config3); - - let config2 = match transport2 { - Transport::Tcp(config) => config2.with_tcp(config), - Transport::Quic(config) => config2.with_quic(config), - Transport::WebSocket(config) => config2.with_websocket(config), - } - .build(); - - let mut litep2p1 = Litep2p::new(config1).unwrap(); - let mut litep2p2 = Litep2p::new(config2).unwrap(); - - let peer1 = *litep2p1.local_peer_id(); - let peer2 = *litep2p2.local_peer_id(); - let listen_address = litep2p2.listen_addresses().next().unwrap().clone(); - - litep2p1.add_known_address(peer2, vec![listen_address].into_iter()); - - tokio::spawn(async move { - loop { - tokio::select! { - _ = litep2p1.next_event() => {}, - _ = litep2p2.next_event() => {}, - } - } - }); - - // open substream for `peer2` and accept it - handle1.open_substream(peer2).await.unwrap(); - assert_eq!( - handle2.next().await.unwrap(), - NotificationEvent::ValidateSubstream { - protocol: ProtocolName::from("/notif/1"), - fallback: None, - peer: peer1, - handshake: vec![1, 2, 3, 4], - } - ); - - assert_eq!( - handle1.next().await.unwrap(), - NotificationEvent::NotificationStreamOpenFailure { - peer: peer2, - error: NotificationError::Rejected, - } - ); - - // give theh connection a moment to close - tokio::time::sleep(Duration::from_secs(5)).await; - - handle2.send_validation_result(peer1, ValidationResult::Accept); - assert_eq!( - handle2.next().await.unwrap(), - NotificationEvent::NotificationStreamOpenFailure { - peer: peer1, - error: NotificationError::NoConnection, - } - ); + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (notif_config1, mut handle1) = ConfigBuilder::new(ProtocolName::from("/notif/1")) + .with_max_size(1024usize) + .with_handshake(vec![1, 2, 3, 4]) + .build(); + + let config1 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_notification_protocol(notif_config1); + + let config1 = match transport1 { + Transport::Tcp(config) => config1.with_tcp(config), + Transport::Quic(config) => config1.with_quic(config), + Transport::WebSocket(config) => config1.with_websocket(config), + } + .build(); + + let (notif_config3, mut handle2) = ConfigBuilder::new(ProtocolName::from("/notif/1")) + .with_max_size(1024usize) + .with_handshake(vec![1, 2, 3, 4]) + .build(); + + let config2 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_notification_protocol(notif_config3); + + let config2 = match transport2 { + Transport::Tcp(config) => config2.with_tcp(config), + Transport::Quic(config) => config2.with_quic(config), + Transport::WebSocket(config) => config2.with_websocket(config), + } + .build(); + + let mut litep2p1 = Litep2p::new(config1).unwrap(); + let mut litep2p2 = Litep2p::new(config2).unwrap(); + + let peer1 = *litep2p1.local_peer_id(); + let peer2 = *litep2p2.local_peer_id(); + let listen_address = litep2p2.listen_addresses().next().unwrap().clone(); + + litep2p1.add_known_address(peer2, vec![listen_address].into_iter()); + + tokio::spawn(async move { + loop { + tokio::select! { + _ = litep2p1.next_event() => {}, + _ = litep2p2.next_event() => {}, + } + } + }); + + // open substream for `peer2` and accept it + handle1.open_substream(peer2).await.unwrap(); + assert_eq!( + handle2.next().await.unwrap(), + NotificationEvent::ValidateSubstream { + protocol: ProtocolName::from("/notif/1"), + fallback: None, + peer: peer1, + handshake: vec![1, 2, 3, 4], + } + ); + + assert_eq!( + handle1.next().await.unwrap(), + NotificationEvent::NotificationStreamOpenFailure { + peer: peer2, + error: NotificationError::Rejected, + } + ); + + // give theh connection a moment to close + tokio::time::sleep(Duration::from_secs(5)).await; + + handle2.send_validation_result(peer1, ValidationResult::Accept); + assert_eq!( + handle2.next().await.unwrap(), + NotificationEvent::NotificationStreamOpenFailure { + peer: peer1, + error: NotificationError::NoConnection, + } + ); } #[tokio::test] async fn ignored_validation_open_substream_tcp() { - ignored_validation_open_substream( - Transport::Tcp(Default::default()), - Transport::Tcp(Default::default()), - ) - .await + ignored_validation_open_substream( + Transport::Tcp(Default::default()), + Transport::Tcp(Default::default()), + ) + .await } #[tokio::test] async fn ignored_validation_open_substream_quic() { - ignored_validation_open_substream( - Transport::Quic(Default::default()), - Transport::Quic(Default::default()), - ) - .await; + ignored_validation_open_substream( + Transport::Quic(Default::default()), + Transport::Quic(Default::default()), + ) + .await; } #[tokio::test] async fn ignored_validation_open_substream_websocket() { - ignored_validation_open_substream( - Transport::WebSocket(Default::default()), - Transport::WebSocket(Default::default()), - ) - .await; + ignored_validation_open_substream( + Transport::WebSocket(Default::default()), + Transport::WebSocket(Default::default()), + ) + .await; } async fn ignored_validation_open_substream(transport1: Transport, transport2: Transport) { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (notif_config1, mut handle1) = ConfigBuilder::new(ProtocolName::from("/notif/1")) - .with_max_size(1024usize) - .with_handshake(vec![1, 2, 3, 4]) - .build(); - - let config1 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_notification_protocol(notif_config1); - - let config1 = match transport1 { - Transport::Tcp(config) => config1.with_tcp(config), - Transport::Quic(config) => config1.with_quic(config), - Transport::WebSocket(config) => config1.with_websocket(config), - } - .build(); - - let (notif_config3, mut handle2) = ConfigBuilder::new(ProtocolName::from("/notif/1")) - .with_max_size(1024usize) - .with_handshake(vec![1, 2, 3, 4]) - .build(); - - let config2 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_notification_protocol(notif_config3); - - let config2 = match transport2 { - Transport::Tcp(config) => config2.with_tcp(config), - Transport::Quic(config) => config2.with_quic(config), - Transport::WebSocket(config) => config2.with_websocket(config), - } - .build(); - - let mut litep2p1 = Litep2p::new(config1).unwrap(); - let mut litep2p2 = Litep2p::new(config2).unwrap(); - - let peer1 = *litep2p1.local_peer_id(); - let peer2 = *litep2p2.local_peer_id(); - let listen_address = litep2p2.listen_addresses().next().unwrap().clone(); - - litep2p1.add_known_address(peer2, vec![listen_address].into_iter()); - - tokio::spawn(async move { - loop { - tokio::select! { - _ = litep2p1.next_event() => {}, - _ = litep2p2.next_event() => {}, - } - } - }); - - // open substream for `peer2` and accept it - handle1.open_substream(peer2).await.unwrap(); - assert_eq!( - handle2.next().await.unwrap(), - NotificationEvent::ValidateSubstream { - protocol: ProtocolName::from("/notif/1"), - fallback: None, - peer: peer1, - handshake: vec![1, 2, 3, 4], - } - ); - assert_eq!( - handle1.next().await.unwrap(), - NotificationEvent::NotificationStreamOpenFailure { - peer: peer2, - error: NotificationError::Rejected, - } - ); - - // wait a moment to allow the connection to close - tokio::time::sleep(Duration::from_secs(2)).await; - - // verify that there are no events pending - futures::future::poll_fn(|cx| match handle2.poll_next_unpin(cx) { - Poll::Pending => Poll::Ready(()), - event => panic!("invalid event: {event:?}"), - }) - .await; - - // try to open a substream while the previous validation is still in progress - // and verify that the substream is rejected with `ValidationPending` - handle2.open_substream(peer1).await.unwrap(); - assert_eq!( - handle2.next().await.unwrap(), - NotificationEvent::NotificationStreamOpenFailure { - peer: peer1, - error: NotificationError::ValidationPending, - } - ); - - // try to open substream as `peer1` and verify the inbound substream gets rejected - // because the previous substream is still pending - handle1.open_substream(peer2).await.unwrap(); - assert_eq!( - handle1.next().await.unwrap(), - NotificationEvent::NotificationStreamOpenFailure { - peer: peer2, - error: NotificationError::Rejected, - } - ); - - // verify `peer2` is not notified of the new substream - futures::future::poll_fn(|cx| match handle2.poll_next_unpin(cx) { - Poll::Pending => Poll::Ready(()), - event => panic!("invalid event: {event:?}"), - }) - .await; - - // finally try to accept the original substream and verify it fails to open with `NoConnection` - handle2.send_validation_result(peer1, ValidationResult::Accept); - assert_eq!( - handle2.next().await.unwrap(), - NotificationEvent::NotificationStreamOpenFailure { - peer: peer1, - error: NotificationError::Rejected, - } - ); + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (notif_config1, mut handle1) = ConfigBuilder::new(ProtocolName::from("/notif/1")) + .with_max_size(1024usize) + .with_handshake(vec![1, 2, 3, 4]) + .build(); + + let config1 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_notification_protocol(notif_config1); + + let config1 = match transport1 { + Transport::Tcp(config) => config1.with_tcp(config), + Transport::Quic(config) => config1.with_quic(config), + Transport::WebSocket(config) => config1.with_websocket(config), + } + .build(); + + let (notif_config3, mut handle2) = ConfigBuilder::new(ProtocolName::from("/notif/1")) + .with_max_size(1024usize) + .with_handshake(vec![1, 2, 3, 4]) + .build(); + + let config2 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_notification_protocol(notif_config3); + + let config2 = match transport2 { + Transport::Tcp(config) => config2.with_tcp(config), + Transport::Quic(config) => config2.with_quic(config), + Transport::WebSocket(config) => config2.with_websocket(config), + } + .build(); + + let mut litep2p1 = Litep2p::new(config1).unwrap(); + let mut litep2p2 = Litep2p::new(config2).unwrap(); + + let peer1 = *litep2p1.local_peer_id(); + let peer2 = *litep2p2.local_peer_id(); + let listen_address = litep2p2.listen_addresses().next().unwrap().clone(); + + litep2p1.add_known_address(peer2, vec![listen_address].into_iter()); + + tokio::spawn(async move { + loop { + tokio::select! { + _ = litep2p1.next_event() => {}, + _ = litep2p2.next_event() => {}, + } + } + }); + + // open substream for `peer2` and accept it + handle1.open_substream(peer2).await.unwrap(); + assert_eq!( + handle2.next().await.unwrap(), + NotificationEvent::ValidateSubstream { + protocol: ProtocolName::from("/notif/1"), + fallback: None, + peer: peer1, + handshake: vec![1, 2, 3, 4], + } + ); + assert_eq!( + handle1.next().await.unwrap(), + NotificationEvent::NotificationStreamOpenFailure { + peer: peer2, + error: NotificationError::Rejected, + } + ); + + // wait a moment to allow the connection to close + tokio::time::sleep(Duration::from_secs(2)).await; + + // verify that there are no events pending + futures::future::poll_fn(|cx| match handle2.poll_next_unpin(cx) { + Poll::Pending => Poll::Ready(()), + event => panic!("invalid event: {event:?}"), + }) + .await; + + // try to open a substream while the previous validation is still in progress + // and verify that the substream is rejected with `ValidationPending` + handle2.open_substream(peer1).await.unwrap(); + assert_eq!( + handle2.next().await.unwrap(), + NotificationEvent::NotificationStreamOpenFailure { + peer: peer1, + error: NotificationError::ValidationPending, + } + ); + + // try to open substream as `peer1` and verify the inbound substream gets rejected + // because the previous substream is still pending + handle1.open_substream(peer2).await.unwrap(); + assert_eq!( + handle1.next().await.unwrap(), + NotificationEvent::NotificationStreamOpenFailure { + peer: peer2, + error: NotificationError::Rejected, + } + ); + + // verify `peer2` is not notified of the new substream + futures::future::poll_fn(|cx| match handle2.poll_next_unpin(cx) { + Poll::Pending => Poll::Ready(()), + event => panic!("invalid event: {event:?}"), + }) + .await; + + // finally try to accept the original substream and verify it fails to open with `NoConnection` + handle2.send_validation_result(peer1, ValidationResult::Accept); + assert_eq!( + handle2.next().await.unwrap(), + NotificationEvent::NotificationStreamOpenFailure { + peer: peer1, + error: NotificationError::Rejected, + } + ); } #[tokio::test] async fn clogged_channel_disconnects_peer_tcp() { - clogged_channel_disconnects_peer( - Transport::Tcp(Default::default()), - Transport::Tcp(Default::default()), - ) - .await + clogged_channel_disconnects_peer( + Transport::Tcp(Default::default()), + Transport::Tcp(Default::default()), + ) + .await } #[tokio::test] async fn clogged_channel_disconnects_peer_quic() { - clogged_channel_disconnects_peer( - Transport::Quic(Default::default()), - Transport::Quic(Default::default()), - ) - .await; + clogged_channel_disconnects_peer( + Transport::Quic(Default::default()), + Transport::Quic(Default::default()), + ) + .await; } #[tokio::test] async fn clogged_channel_disconnects_peer_websocket() { - clogged_channel_disconnects_peer( - Transport::WebSocket(Default::default()), - Transport::WebSocket(Default::default()), - ) - .await; + clogged_channel_disconnects_peer( + Transport::WebSocket(Default::default()), + Transport::WebSocket(Default::default()), + ) + .await; } async fn clogged_channel_disconnects_peer(transport1: Transport, transport2: Transport) { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (notif_config1, mut handle1) = ConfigBuilder::new(ProtocolName::from("/notif/1")) - .with_max_size(100 * 1024) - .with_handshake(vec![1, 2, 3, 4]) - .with_auto_accept_inbound(true) - .build(); - - let config1 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_notification_protocol(notif_config1); - - let config1 = match transport1 { - Transport::Tcp(config) => config1.with_tcp(config), - Transport::Quic(config) => config1.with_quic(config), - Transport::WebSocket(config) => config1.with_websocket(config), - } - .build(); - - let (notif_config3, mut handle2) = ConfigBuilder::new(ProtocolName::from("/notif/1")) - .with_max_size(100 * 1024) - .with_handshake(vec![1, 2, 3, 4]) - .with_auto_accept_inbound(true) - .build(); - - let config2 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_notification_protocol(notif_config3); - - let config2 = match transport2 { - Transport::Tcp(config) => config2.with_tcp(config), - Transport::Quic(config) => config2.with_quic(config), - Transport::WebSocket(config) => config2.with_websocket(config), - } - .build(); - - let mut litep2p1 = Litep2p::new(config1).unwrap(); - let mut litep2p2 = Litep2p::new(config2).unwrap(); - - let peer1 = *litep2p1.local_peer_id(); - let peer2 = *litep2p2.local_peer_id(); - let listen_address = litep2p2.listen_addresses().next().unwrap().clone(); - - litep2p1.add_known_address(peer2, vec![listen_address].into_iter()); - - tokio::spawn(async move { - loop { - tokio::select! { - _ = litep2p1.next_event() => {}, - _ = litep2p2.next_event() => {}, - } - } - }); - - // open substream for `peer2` and accept it - handle1.open_substream(peer2).await.unwrap(); - assert_eq!( - handle2.next().await.unwrap(), - NotificationEvent::ValidateSubstream { - protocol: ProtocolName::from("/notif/1"), - fallback: None, - peer: peer1, - handshake: vec![1, 2, 3, 4], - } - ); - handle2.send_validation_result(peer1, ValidationResult::Accept); - - // verify both peers have the substream open - assert_eq!( - handle1.next().await.unwrap(), - NotificationEvent::NotificationStreamOpened { - protocol: ProtocolName::from("/notif/1"), - fallback: None, - peer: peer2, - handshake: vec![1, 2, 3, 4], - direction: Direction::Outbound, - } - ); - assert_eq!( - handle2.next().await.unwrap(), - NotificationEvent::NotificationStreamOpened { - protocol: ProtocolName::from("/notif/1"), - fallback: None, - peer: peer1, - handshake: vec![1, 2, 3, 4], - direction: Direction::Inbound, - } - ); - - // start sending notifications to `peer2` which never reads them, - // causing `peer1` to consume all available credit - loop { - match handle1.send_sync_notification(peer2, vec![0u8; 99 * 1024]) { - Ok(()) => {}, - Err(NotificationError::ChannelClogged) => break, - error => panic!("invalid error: {error:?}"), - } - } - - // stream closed from `peer1`'s PoV - assert_eq!( - handle1.next().await.unwrap(), - NotificationEvent::NotificationStreamClosed { peer: peer2 }, - ); - - // `peer2` is also reported that the substream is closed - match tokio::time::timeout(Duration::from_secs(5), async move { - loop { - if let Some(NotificationEvent::NotificationStreamClosed { peer }) = handle2.next().await - { - assert_eq!(peer, peer1); - break; - } - } - }) - .await - { - Err(_) => panic!("timeout"), - Ok(()) => {}, - } + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (notif_config1, mut handle1) = ConfigBuilder::new(ProtocolName::from("/notif/1")) + .with_max_size(100 * 1024) + .with_handshake(vec![1, 2, 3, 4]) + .with_auto_accept_inbound(true) + .build(); + + let config1 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_notification_protocol(notif_config1); + + let config1 = match transport1 { + Transport::Tcp(config) => config1.with_tcp(config), + Transport::Quic(config) => config1.with_quic(config), + Transport::WebSocket(config) => config1.with_websocket(config), + } + .build(); + + let (notif_config3, mut handle2) = ConfigBuilder::new(ProtocolName::from("/notif/1")) + .with_max_size(100 * 1024) + .with_handshake(vec![1, 2, 3, 4]) + .with_auto_accept_inbound(true) + .build(); + + let config2 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_notification_protocol(notif_config3); + + let config2 = match transport2 { + Transport::Tcp(config) => config2.with_tcp(config), + Transport::Quic(config) => config2.with_quic(config), + Transport::WebSocket(config) => config2.with_websocket(config), + } + .build(); + + let mut litep2p1 = Litep2p::new(config1).unwrap(); + let mut litep2p2 = Litep2p::new(config2).unwrap(); + + let peer1 = *litep2p1.local_peer_id(); + let peer2 = *litep2p2.local_peer_id(); + let listen_address = litep2p2.listen_addresses().next().unwrap().clone(); + + litep2p1.add_known_address(peer2, vec![listen_address].into_iter()); + + tokio::spawn(async move { + loop { + tokio::select! { + _ = litep2p1.next_event() => {}, + _ = litep2p2.next_event() => {}, + } + } + }); + + // open substream for `peer2` and accept it + handle1.open_substream(peer2).await.unwrap(); + assert_eq!( + handle2.next().await.unwrap(), + NotificationEvent::ValidateSubstream { + protocol: ProtocolName::from("/notif/1"), + fallback: None, + peer: peer1, + handshake: vec![1, 2, 3, 4], + } + ); + handle2.send_validation_result(peer1, ValidationResult::Accept); + + // verify both peers have the substream open + assert_eq!( + handle1.next().await.unwrap(), + NotificationEvent::NotificationStreamOpened { + protocol: ProtocolName::from("/notif/1"), + fallback: None, + peer: peer2, + handshake: vec![1, 2, 3, 4], + direction: Direction::Outbound, + } + ); + assert_eq!( + handle2.next().await.unwrap(), + NotificationEvent::NotificationStreamOpened { + protocol: ProtocolName::from("/notif/1"), + fallback: None, + peer: peer1, + handshake: vec![1, 2, 3, 4], + direction: Direction::Inbound, + } + ); + + // start sending notifications to `peer2` which never reads them, + // causing `peer1` to consume all available credit + loop { + match handle1.send_sync_notification(peer2, vec![0u8; 99 * 1024]) { + Ok(()) => {} + Err(NotificationError::ChannelClogged) => break, + error => panic!("invalid error: {error:?}"), + } + } + + // stream closed from `peer1`'s PoV + assert_eq!( + handle1.next().await.unwrap(), + NotificationEvent::NotificationStreamClosed { peer: peer2 }, + ); + + // `peer2` is also reported that the substream is closed + match tokio::time::timeout(Duration::from_secs(5), async move { + loop { + if let Some(NotificationEvent::NotificationStreamClosed { peer }) = handle2.next().await + { + assert_eq!(peer, peer1); + break; + } + } + }) + .await + { + Err(_) => panic!("timeout"), + Ok(()) => {} + } } diff --git a/tests/protocol/ping.rs b/tests/protocol/ping.rs index 9c96a4ee..32beb41d 100644 --- a/tests/protocol/ping.rs +++ b/tests/protocol/ping.rs @@ -20,93 +20,101 @@ use futures::StreamExt; use litep2p::{ - config::ConfigBuilder, - protocol::libp2p::ping::ConfigBuilder as PingConfigBuilder, - transport::{ - quic::config::Config as QuicConfig, tcp::config::Config as TcpConfig, - websocket::config::Config as WebSocketConfig, - }, - Litep2p, + config::ConfigBuilder, + protocol::libp2p::ping::ConfigBuilder as PingConfigBuilder, + transport::{ + quic::config::Config as QuicConfig, tcp::config::Config as TcpConfig, + websocket::config::Config as WebSocketConfig, + }, + Litep2p, }; enum Transport { - Tcp(TcpConfig), - Quic(QuicConfig), - WebSocket(WebSocketConfig), + Tcp(TcpConfig), + Quic(QuicConfig), + WebSocket(WebSocketConfig), } #[tokio::test] async fn ping_supported_tcp() { - ping_supported(Transport::Tcp(Default::default()), Transport::Tcp(Default::default())).await; + ping_supported( + Transport::Tcp(Default::default()), + Transport::Tcp(Default::default()), + ) + .await; } #[tokio::test] async fn ping_supported_websocket() { - ping_supported( - Transport::WebSocket(Default::default()), - Transport::WebSocket(Default::default()), - ) - .await; + ping_supported( + Transport::WebSocket(Default::default()), + Transport::WebSocket(Default::default()), + ) + .await; } #[tokio::test] async fn ping_supported_quic() { - ping_supported(Transport::Quic(Default::default()), Transport::Quic(Default::default())).await; + ping_supported( + Transport::Quic(Default::default()), + Transport::Quic(Default::default()), + ) + .await; } async fn ping_supported(transport1: Transport, transport2: Transport) { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); - let (ping_config1, mut ping_event_stream1) = - PingConfigBuilder::new().with_max_failure(3usize).build(); - let config1 = match transport1 { - Transport::Tcp(config) => ConfigBuilder::new().with_tcp(config), - Transport::Quic(config) => ConfigBuilder::new().with_quic(config), - Transport::WebSocket(config) => ConfigBuilder::new().with_websocket(config), - } - .with_libp2p_ping(ping_config1) - .build(); + let (ping_config1, mut ping_event_stream1) = + PingConfigBuilder::new().with_max_failure(3usize).build(); + let config1 = match transport1 { + Transport::Tcp(config) => ConfigBuilder::new().with_tcp(config), + Transport::Quic(config) => ConfigBuilder::new().with_quic(config), + Transport::WebSocket(config) => ConfigBuilder::new().with_websocket(config), + } + .with_libp2p_ping(ping_config1) + .build(); - let (ping_config2, mut ping_event_stream2) = PingConfigBuilder::new().build(); - let config2 = match transport2 { - Transport::Tcp(config) => ConfigBuilder::new().with_tcp(config), - Transport::Quic(config) => ConfigBuilder::new().with_quic(config), - Transport::WebSocket(config) => ConfigBuilder::new().with_websocket(config), - } - .with_libp2p_ping(ping_config2) - .build(); + let (ping_config2, mut ping_event_stream2) = PingConfigBuilder::new().build(); + let config2 = match transport2 { + Transport::Tcp(config) => ConfigBuilder::new().with_tcp(config), + Transport::Quic(config) => ConfigBuilder::new().with_quic(config), + Transport::WebSocket(config) => ConfigBuilder::new().with_websocket(config), + } + .with_libp2p_ping(ping_config2) + .build(); - let mut litep2p1 = Litep2p::new(config1).unwrap(); - let mut litep2p2 = Litep2p::new(config2).unwrap(); - let address = litep2p2.listen_addresses().next().unwrap().clone(); + let mut litep2p1 = Litep2p::new(config1).unwrap(); + let mut litep2p2 = Litep2p::new(config2).unwrap(); + let address = litep2p2.listen_addresses().next().unwrap().clone(); - litep2p1.dial_address(address).await.unwrap(); + litep2p1.dial_address(address).await.unwrap(); - let mut litep2p1_done = false; - let mut litep2p2_done = false; + let mut litep2p1_done = false; + let mut litep2p2_done = false; - loop { - tokio::select! { - _event = litep2p1.next_event() => {} - _event = litep2p2.next_event() => {} - event = ping_event_stream1.next() => { - tracing::trace!("ping event for litep2p1: {event:?}"); + loop { + tokio::select! { + _event = litep2p1.next_event() => {} + _event = litep2p2.next_event() => {} + event = ping_event_stream1.next() => { + tracing::trace!("ping event for litep2p1: {event:?}"); - litep2p1_done = true; - if litep2p1_done && litep2p2_done { - break - } - } - event = ping_event_stream2.next() => { - tracing::trace!("ping event for litep2p2: {event:?}"); + litep2p1_done = true; + if litep2p1_done && litep2p2_done { + break + } + } + event = ping_event_stream2.next() => { + tracing::trace!("ping event for litep2p2: {event:?}"); - litep2p2_done = true; - if litep2p1_done && litep2p2_done { - break - } - } - } - } + litep2p2_done = true; + if litep2p1_done && litep2p2_done { + break + } + } + } + } } diff --git a/tests/protocol/request_response.rs b/tests/protocol/request_response.rs index 64067d23..4d15e42e 100644 --- a/tests/protocol/request_response.rs +++ b/tests/protocol/request_response.rs @@ -19,18 +19,18 @@ // DEALINGS IN THE SOFTWARE. use litep2p::{ - config::ConfigBuilder as Litep2pConfigBuilder, - crypto::ed25519::Keypair, - protocol::request_response::{ - Config as RequestResponseConfig, ConfigBuilder, DialOptions, RequestResponseError, - RequestResponseEvent, - }, - transport::{ - quic::config::Config as QuicConfig, tcp::config::Config as TcpConfig, - websocket::config::Config as WebSocketConfig, - }, - types::{protocol::ProtocolName, RequestId}, - Litep2p, Litep2pEvent, PeerId, + config::ConfigBuilder as Litep2pConfigBuilder, + crypto::ed25519::Keypair, + protocol::request_response::{ + Config as RequestResponseConfig, ConfigBuilder, DialOptions, RequestResponseError, + RequestResponseEvent, + }, + transport::{ + quic::config::Config as QuicConfig, tcp::config::Config as TcpConfig, + websocket::config::Config as WebSocketConfig, + }, + types::{protocol::ProtocolName, RequestId}, + Litep2p, Litep2pEvent, PeerId, }; use futures::{channel, StreamExt}; @@ -41,2096 +41,2127 @@ use rand_xorshift::XorShiftRng; use tokio::time::sleep; use std::{ - collections::{HashMap, HashSet}, - net::{Ipv4Addr, Ipv6Addr}, - task::Poll, - time::Duration, + collections::{HashMap, HashSet}, + net::{Ipv4Addr, Ipv6Addr}, + task::Poll, + time::Duration, }; enum Transport { - Tcp(TcpConfig), - Quic(QuicConfig), - WebSocket(WebSocketConfig), + Tcp(TcpConfig), + Quic(QuicConfig), + WebSocket(WebSocketConfig), } async fn connect_peers(litep2p1: &mut Litep2p, litep2p2: &mut Litep2p) { - let address = litep2p2.listen_addresses().next().unwrap().clone(); - tracing::info!("address: {address}"); - litep2p1.dial_address(address).await.unwrap(); - - let mut litep2p1_connected = false; - let mut litep2p2_connected = false; - - loop { - tokio::select! { - event = litep2p1.next_event() => match event.unwrap() { - Litep2pEvent::ConnectionEstablished { .. } => { - litep2p1_connected = true; - } - _ => {}, - }, - event = litep2p2.next_event() => match event.unwrap() { - Litep2pEvent::ConnectionEstablished { .. } => { - litep2p2_connected = true; - } - _ => {}, - } - } - - if litep2p1_connected && litep2p2_connected { - break; - } - } - - sleep(Duration::from_millis(100)).await; + let address = litep2p2.listen_addresses().next().unwrap().clone(); + tracing::info!("address: {address}"); + litep2p1.dial_address(address).await.unwrap(); + + let mut litep2p1_connected = false; + let mut litep2p2_connected = false; + + loop { + tokio::select! { + event = litep2p1.next_event() => match event.unwrap() { + Litep2pEvent::ConnectionEstablished { .. } => { + litep2p1_connected = true; + } + _ => {}, + }, + event = litep2p2.next_event() => match event.unwrap() { + Litep2pEvent::ConnectionEstablished { .. } => { + litep2p2_connected = true; + } + _ => {}, + } + } + + if litep2p1_connected && litep2p2_connected { + break; + } + } + + sleep(Duration::from_millis(100)).await; } #[tokio::test] async fn send_request_receive_response_tcp() { - send_request_receive_response( - Transport::Tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }), - Transport::Tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }), - ) - .await; + send_request_receive_response( + Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }), + Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }), + ) + .await; } #[tokio::test] async fn send_request_receive_response_quic() { - send_request_receive_response( - Transport::Quic(Default::default()), - Transport::Quic(Default::default()), - ) - .await; + send_request_receive_response( + Transport::Quic(Default::default()), + Transport::Quic(Default::default()), + ) + .await; } #[tokio::test] async fn send_request_receive_response_websocket() { - send_request_receive_response( - Transport::WebSocket(WebSocketConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], - ..Default::default() - }), - Transport::WebSocket(WebSocketConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], - ..Default::default() - }), - ) - .await; + send_request_receive_response( + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }), + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }), + ) + .await; } async fn send_request_receive_response(transport1: Transport, transport2: Transport) { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (req_resp_config1, mut handle1) = RequestResponseConfig::new( - ProtocolName::from("/protocol/1"), - Vec::new(), - 1024, - Duration::from_secs(5), - None, - ); - let config1 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_request_response_protocol(req_resp_config1); - - let config1 = match transport1 { - Transport::Tcp(config) => config1.with_tcp(config), - Transport::Quic(config) => config1.with_quic(config), - Transport::WebSocket(config) => config1.with_websocket(config), - } - .build(); - - let (req_resp_config2, mut handle2) = RequestResponseConfig::new( - ProtocolName::from("/protocol/1"), - Vec::new(), - 1024, - Duration::from_secs(5), - None, - ); - let config2 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_request_response_protocol(req_resp_config2); - - let config2 = match transport2 { - Transport::Tcp(config) => config2.with_tcp(config), - Transport::Quic(config) => config2.with_quic(config), - Transport::WebSocket(config) => config2.with_websocket(config), - } - .build(); - - let mut litep2p1 = Litep2p::new(config1).unwrap(); - let mut litep2p2 = Litep2p::new(config2).unwrap(); - - let peer1 = *litep2p1.local_peer_id(); - let peer2 = *litep2p2.local_peer_id(); - - // wait until peers have connected - connect_peers(&mut litep2p1, &mut litep2p2).await; - tokio::spawn(async move { - loop { - tokio::select! { - _ = litep2p1.next_event() => {}, - _ = litep2p2.next_event() => {}, - } - } - }); - - // send request to remote peer - let request_id = handle1 - .send_request(peer2, vec![1, 3, 3, 7], DialOptions::Reject) - .await - .unwrap(); - assert_eq!( - handle2.next().await.unwrap(), - RequestResponseEvent::RequestReceived { - peer: peer1, - fallback: None, - request_id, - request: vec![1, 3, 3, 7], - } - ); - - // send response to the received request - handle2.send_response(request_id, vec![1, 3, 3, 8]); - assert_eq!( - handle1.next().await.unwrap(), - RequestResponseEvent::ResponseReceived { - peer: peer2, - request_id, - response: vec![1, 3, 3, 8], - fallback: None, - } - ); + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (req_resp_config1, mut handle1) = RequestResponseConfig::new( + ProtocolName::from("/protocol/1"), + Vec::new(), + 1024, + Duration::from_secs(5), + None, + ); + let config1 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_request_response_protocol(req_resp_config1); + + let config1 = match transport1 { + Transport::Tcp(config) => config1.with_tcp(config), + Transport::Quic(config) => config1.with_quic(config), + Transport::WebSocket(config) => config1.with_websocket(config), + } + .build(); + + let (req_resp_config2, mut handle2) = RequestResponseConfig::new( + ProtocolName::from("/protocol/1"), + Vec::new(), + 1024, + Duration::from_secs(5), + None, + ); + let config2 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_request_response_protocol(req_resp_config2); + + let config2 = match transport2 { + Transport::Tcp(config) => config2.with_tcp(config), + Transport::Quic(config) => config2.with_quic(config), + Transport::WebSocket(config) => config2.with_websocket(config), + } + .build(); + + let mut litep2p1 = Litep2p::new(config1).unwrap(); + let mut litep2p2 = Litep2p::new(config2).unwrap(); + + let peer1 = *litep2p1.local_peer_id(); + let peer2 = *litep2p2.local_peer_id(); + + // wait until peers have connected + connect_peers(&mut litep2p1, &mut litep2p2).await; + tokio::spawn(async move { + loop { + tokio::select! { + _ = litep2p1.next_event() => {}, + _ = litep2p2.next_event() => {}, + } + } + }); + + // send request to remote peer + let request_id = handle1 + .send_request(peer2, vec![1, 3, 3, 7], DialOptions::Reject) + .await + .unwrap(); + assert_eq!( + handle2.next().await.unwrap(), + RequestResponseEvent::RequestReceived { + peer: peer1, + fallback: None, + request_id, + request: vec![1, 3, 3, 7], + } + ); + + // send response to the received request + handle2.send_response(request_id, vec![1, 3, 3, 8]); + assert_eq!( + handle1.next().await.unwrap(), + RequestResponseEvent::ResponseReceived { + peer: peer2, + request_id, + response: vec![1, 3, 3, 8], + fallback: None, + } + ); } #[tokio::test] async fn reject_request_tcp() { - reject_request( - Transport::Tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }), - Transport::Tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }), - ) - .await; + reject_request( + Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }), + Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }), + ) + .await; } #[tokio::test] async fn reject_request_quic() { - reject_request(Transport::Quic(Default::default()), Transport::Quic(Default::default())).await; + reject_request( + Transport::Quic(Default::default()), + Transport::Quic(Default::default()), + ) + .await; } #[tokio::test] async fn reject_request_websocket() { - reject_request( - Transport::WebSocket(WebSocketConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], - ..Default::default() - }), - Transport::WebSocket(WebSocketConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], - ..Default::default() - }), - ) - .await; + reject_request( + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }), + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }), + ) + .await; } async fn reject_request(transport1: Transport, transport2: Transport) { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (req_resp_config1, mut handle1) = RequestResponseConfig::new( - ProtocolName::from("/protocol/1"), - Vec::new(), - 1024, - Duration::from_secs(5), - None, - ); - - let config1 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_request_response_protocol(req_resp_config1); - - let config1 = match transport1 { - Transport::Tcp(config) => config1.with_tcp(config), - Transport::Quic(config) => config1.with_quic(config), - Transport::WebSocket(config) => config1.with_websocket(config), - } - .build(); - - let (req_resp_config2, mut handle2) = RequestResponseConfig::new( - ProtocolName::from("/protocol/1"), - Vec::new(), - 1024, - Duration::from_secs(5), - None, - ); - let config2 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_request_response_protocol(req_resp_config2); - - let config2 = match transport2 { - Transport::Tcp(config) => config2.with_tcp(config), - Transport::Quic(config) => config2.with_quic(config), - Transport::WebSocket(config) => config2.with_websocket(config), - } - .build(); - - let mut litep2p1 = Litep2p::new(config1).unwrap(); - let mut litep2p2 = Litep2p::new(config2).unwrap(); - - let peer1 = *litep2p1.local_peer_id(); - let peer2 = *litep2p2.local_peer_id(); - - // wait until peers have connected - connect_peers(&mut litep2p1, &mut litep2p2).await; - tokio::spawn(async move { - loop { - tokio::select! { - _ = litep2p1.next_event() => {}, - _ = litep2p2.next_event() => {}, - } - } - }); - - // send request to remote peer - let request_id = handle1 - .send_request(peer2, vec![1, 3, 3, 7], DialOptions::Reject) - .await - .unwrap(); - if let RequestResponseEvent::RequestReceived { peer, fallback: None, request_id, request } = - handle2.next().await.unwrap() - { - assert_eq!(peer, peer1); - assert_eq!(request, vec![1, 3, 3, 7]); - handle2.reject_request(request_id); - } else { - panic!("invalid event received"); - }; - - assert_eq!( - handle1.next().await.unwrap(), - RequestResponseEvent::RequestFailed { - peer: peer2, - request_id, - error: RequestResponseError::Rejected - } - ); + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (req_resp_config1, mut handle1) = RequestResponseConfig::new( + ProtocolName::from("/protocol/1"), + Vec::new(), + 1024, + Duration::from_secs(5), + None, + ); + + let config1 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_request_response_protocol(req_resp_config1); + + let config1 = match transport1 { + Transport::Tcp(config) => config1.with_tcp(config), + Transport::Quic(config) => config1.with_quic(config), + Transport::WebSocket(config) => config1.with_websocket(config), + } + .build(); + + let (req_resp_config2, mut handle2) = RequestResponseConfig::new( + ProtocolName::from("/protocol/1"), + Vec::new(), + 1024, + Duration::from_secs(5), + None, + ); + let config2 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_request_response_protocol(req_resp_config2); + + let config2 = match transport2 { + Transport::Tcp(config) => config2.with_tcp(config), + Transport::Quic(config) => config2.with_quic(config), + Transport::WebSocket(config) => config2.with_websocket(config), + } + .build(); + + let mut litep2p1 = Litep2p::new(config1).unwrap(); + let mut litep2p2 = Litep2p::new(config2).unwrap(); + + let peer1 = *litep2p1.local_peer_id(); + let peer2 = *litep2p2.local_peer_id(); + + // wait until peers have connected + connect_peers(&mut litep2p1, &mut litep2p2).await; + tokio::spawn(async move { + loop { + tokio::select! { + _ = litep2p1.next_event() => {}, + _ = litep2p2.next_event() => {}, + } + } + }); + + // send request to remote peer + let request_id = handle1 + .send_request(peer2, vec![1, 3, 3, 7], DialOptions::Reject) + .await + .unwrap(); + if let RequestResponseEvent::RequestReceived { + peer, + fallback: None, + request_id, + request, + } = handle2.next().await.unwrap() + { + assert_eq!(peer, peer1); + assert_eq!(request, vec![1, 3, 3, 7]); + handle2.reject_request(request_id); + } else { + panic!("invalid event received"); + }; + + assert_eq!( + handle1.next().await.unwrap(), + RequestResponseEvent::RequestFailed { + peer: peer2, + request_id, + error: RequestResponseError::Rejected + } + ); } #[tokio::test] async fn multiple_simultaneous_requests_tcp() { - multiple_simultaneous_requests( - Transport::Tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }), - Transport::Tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }), - ) - .await; + multiple_simultaneous_requests( + Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }), + Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }), + ) + .await; } #[tokio::test] async fn multiple_simultaneous_requests_quic() { - multiple_simultaneous_requests( - Transport::Quic(Default::default()), - Transport::Quic(Default::default()), - ) - .await; + multiple_simultaneous_requests( + Transport::Quic(Default::default()), + Transport::Quic(Default::default()), + ) + .await; } #[tokio::test] async fn multiple_simultaneous_requests_websocket() { - multiple_simultaneous_requests( - Transport::WebSocket(WebSocketConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], - ..Default::default() - }), - Transport::WebSocket(WebSocketConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], - ..Default::default() - }), - ) - .await; + multiple_simultaneous_requests( + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }), + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }), + ) + .await; } async fn multiple_simultaneous_requests(transport1: Transport, transport2: Transport) { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (req_resp_config1, mut handle1) = RequestResponseConfig::new( - ProtocolName::from("/protocol/1"), - Vec::new(), - 1024, - Duration::from_secs(5), - None, - ); - - let config1 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_request_response_protocol(req_resp_config1); - - let config1 = match transport1 { - Transport::Tcp(config) => config1.with_tcp(config), - Transport::Quic(config) => config1.with_quic(config), - Transport::WebSocket(config) => config1.with_websocket(config), - } - .build(); - - let (req_resp_config2, mut handle2) = RequestResponseConfig::new( - ProtocolName::from("/protocol/1"), - Vec::new(), - 1024, - Duration::from_secs(5), - None, - ); - let config2 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_request_response_protocol(req_resp_config2); - - let config2 = match transport2 { - Transport::Tcp(config) => config2.with_tcp(config), - Transport::Quic(config) => config2.with_quic(config), - Transport::WebSocket(config) => config2.with_websocket(config), - } - .build(); - - let mut litep2p1 = Litep2p::new(config1).unwrap(); - let mut litep2p2 = Litep2p::new(config2).unwrap(); - - let peer1 = *litep2p1.local_peer_id(); - let peer2 = *litep2p2.local_peer_id(); - - // wait until peers have connected - connect_peers(&mut litep2p1, &mut litep2p2).await; - tokio::spawn(async move { - loop { - tokio::select! { - _ = litep2p1.next_event() => {}, - _ = litep2p2.next_event() => {}, - } - } - }); - - // send multiple requests to remote peer - let request_id1 = handle1 - .send_request(peer2, vec![1, 3, 3, 6], DialOptions::Reject) - .await - .unwrap(); - let request_id2 = handle1 - .send_request(peer2, vec![1, 3, 3, 7], DialOptions::Reject) - .await - .unwrap(); - let request_id3 = handle1 - .send_request(peer2, vec![1, 3, 3, 8], DialOptions::Reject) - .await - .unwrap(); - let request_id4 = handle1 - .send_request(peer2, vec![1, 3, 3, 9], DialOptions::Reject) - .await - .unwrap(); - let expected: HashMap> = HashMap::from_iter([ - (request_id1, vec![2, 3, 3, 6]), - (request_id2, vec![2, 3, 3, 7]), - (request_id3, vec![2, 3, 3, 8]), - (request_id4, vec![2, 3, 3, 9]), - ]); - let expected_requests: Vec> = - vec![vec![1, 3, 3, 6], vec![1, 3, 3, 7], vec![1, 3, 3, 8], vec![1, 3, 3, 9]]; - - for _ in 0..4 { - if let RequestResponseEvent::RequestReceived { - peer, - fallback: None, - request_id, - mut request, - } = handle2.next().await.unwrap() - { - assert_eq!(peer, peer1); - if expected_requests.iter().any(|req| req == &request) { - request[0] = 2; - handle2.send_response(request_id, request); - } else { - panic!("invalid request received"); - } - } else { - panic!("invalid event received"); - }; - } - - for _ in 0..4 { - if let RequestResponseEvent::ResponseReceived { peer, request_id, response, .. } = - handle1.next().await.unwrap() - { - assert_eq!(peer, peer2); - assert_eq!(response, expected.get(&request_id).unwrap().to_vec()); - } else { - panic!("invalid event received"); - }; - } + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (req_resp_config1, mut handle1) = RequestResponseConfig::new( + ProtocolName::from("/protocol/1"), + Vec::new(), + 1024, + Duration::from_secs(5), + None, + ); + + let config1 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_request_response_protocol(req_resp_config1); + + let config1 = match transport1 { + Transport::Tcp(config) => config1.with_tcp(config), + Transport::Quic(config) => config1.with_quic(config), + Transport::WebSocket(config) => config1.with_websocket(config), + } + .build(); + + let (req_resp_config2, mut handle2) = RequestResponseConfig::new( + ProtocolName::from("/protocol/1"), + Vec::new(), + 1024, + Duration::from_secs(5), + None, + ); + let config2 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_request_response_protocol(req_resp_config2); + + let config2 = match transport2 { + Transport::Tcp(config) => config2.with_tcp(config), + Transport::Quic(config) => config2.with_quic(config), + Transport::WebSocket(config) => config2.with_websocket(config), + } + .build(); + + let mut litep2p1 = Litep2p::new(config1).unwrap(); + let mut litep2p2 = Litep2p::new(config2).unwrap(); + + let peer1 = *litep2p1.local_peer_id(); + let peer2 = *litep2p2.local_peer_id(); + + // wait until peers have connected + connect_peers(&mut litep2p1, &mut litep2p2).await; + tokio::spawn(async move { + loop { + tokio::select! { + _ = litep2p1.next_event() => {}, + _ = litep2p2.next_event() => {}, + } + } + }); + + // send multiple requests to remote peer + let request_id1 = handle1 + .send_request(peer2, vec![1, 3, 3, 6], DialOptions::Reject) + .await + .unwrap(); + let request_id2 = handle1 + .send_request(peer2, vec![1, 3, 3, 7], DialOptions::Reject) + .await + .unwrap(); + let request_id3 = handle1 + .send_request(peer2, vec![1, 3, 3, 8], DialOptions::Reject) + .await + .unwrap(); + let request_id4 = handle1 + .send_request(peer2, vec![1, 3, 3, 9], DialOptions::Reject) + .await + .unwrap(); + let expected: HashMap> = HashMap::from_iter([ + (request_id1, vec![2, 3, 3, 6]), + (request_id2, vec![2, 3, 3, 7]), + (request_id3, vec![2, 3, 3, 8]), + (request_id4, vec![2, 3, 3, 9]), + ]); + let expected_requests: Vec> = vec![ + vec![1, 3, 3, 6], + vec![1, 3, 3, 7], + vec![1, 3, 3, 8], + vec![1, 3, 3, 9], + ]; + + for _ in 0..4 { + if let RequestResponseEvent::RequestReceived { + peer, + fallback: None, + request_id, + mut request, + } = handle2.next().await.unwrap() + { + assert_eq!(peer, peer1); + if expected_requests.iter().any(|req| req == &request) { + request[0] = 2; + handle2.send_response(request_id, request); + } else { + panic!("invalid request received"); + } + } else { + panic!("invalid event received"); + }; + } + + for _ in 0..4 { + if let RequestResponseEvent::ResponseReceived { + peer, + request_id, + response, + .. + } = handle1.next().await.unwrap() + { + assert_eq!(peer, peer2); + assert_eq!(response, expected.get(&request_id).unwrap().to_vec()); + } else { + panic!("invalid event received"); + }; + } } #[tokio::test] async fn request_timeout_tcp() { - request_timeout( - Transport::Tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }), - Transport::Tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }), - ) - .await; + request_timeout( + Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }), + Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }), + ) + .await; } #[tokio::test] async fn request_timeout_quic() { - request_timeout(Transport::Quic(Default::default()), Transport::Quic(Default::default())).await; + request_timeout( + Transport::Quic(Default::default()), + Transport::Quic(Default::default()), + ) + .await; } #[tokio::test] async fn request_timeout_websocket() { - request_timeout( - Transport::WebSocket(WebSocketConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], - ..Default::default() - }), - Transport::WebSocket(WebSocketConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], - ..Default::default() - }), - ) - .await; + request_timeout( + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }), + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }), + ) + .await; } // TODO: configure longer keep-alive timeout for the protocol async fn request_timeout(transport1: Transport, transport2: Transport) { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (req_resp_config1, mut handle1) = RequestResponseConfig::new( - ProtocolName::from("/protocol/1"), - Vec::new(), - 1024, - Duration::from_secs(5), - None, - ); - let config1 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_request_response_protocol(req_resp_config1); - - let config1 = match transport1 { - Transport::Tcp(config) => config1.with_tcp(config), - Transport::Quic(config) => config1.with_quic(config), - Transport::WebSocket(config) => config1.with_websocket(config), - } - .build(); - - let (req_resp_config2, _handle2) = RequestResponseConfig::new( - ProtocolName::from("/protocol/1"), - Vec::new(), - 1024, - Duration::from_secs(5), - None, - ); - let config2 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_request_response_protocol(req_resp_config2); - - let config2 = match transport2 { - Transport::Tcp(config) => config2.with_tcp(config), - Transport::Quic(config) => config2.with_quic(config), - Transport::WebSocket(config) => config2.with_websocket(config), - } - .build(); - - let mut litep2p1 = Litep2p::new(config1).unwrap(); - let mut litep2p2 = Litep2p::new(config2).unwrap(); - - let _peer1 = *litep2p1.local_peer_id(); - let peer2 = *litep2p2.local_peer_id(); - - // wait until peers have connected - connect_peers(&mut litep2p1, &mut litep2p2).await; - tokio::spawn(async move { - loop { - tokio::select! { - _ = litep2p1.next_event() => {}, - _ = litep2p2.next_event() => {}, - } - } - }); - - // send request to remote peer and wait until the requet timeout occurs - let request_id = handle1 - .send_request(peer2, vec![1, 3, 3, 7], DialOptions::Reject) - .await - .unwrap(); - - sleep(Duration::from_secs(7)).await; - - assert_eq!( - handle1.next().await.unwrap(), - RequestResponseEvent::RequestFailed { - peer: peer2, - request_id, - error: RequestResponseError::Timeout, - } - ); + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (req_resp_config1, mut handle1) = RequestResponseConfig::new( + ProtocolName::from("/protocol/1"), + Vec::new(), + 1024, + Duration::from_secs(5), + None, + ); + let config1 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_request_response_protocol(req_resp_config1); + + let config1 = match transport1 { + Transport::Tcp(config) => config1.with_tcp(config), + Transport::Quic(config) => config1.with_quic(config), + Transport::WebSocket(config) => config1.with_websocket(config), + } + .build(); + + let (req_resp_config2, _handle2) = RequestResponseConfig::new( + ProtocolName::from("/protocol/1"), + Vec::new(), + 1024, + Duration::from_secs(5), + None, + ); + let config2 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_request_response_protocol(req_resp_config2); + + let config2 = match transport2 { + Transport::Tcp(config) => config2.with_tcp(config), + Transport::Quic(config) => config2.with_quic(config), + Transport::WebSocket(config) => config2.with_websocket(config), + } + .build(); + + let mut litep2p1 = Litep2p::new(config1).unwrap(); + let mut litep2p2 = Litep2p::new(config2).unwrap(); + + let _peer1 = *litep2p1.local_peer_id(); + let peer2 = *litep2p2.local_peer_id(); + + // wait until peers have connected + connect_peers(&mut litep2p1, &mut litep2p2).await; + tokio::spawn(async move { + loop { + tokio::select! { + _ = litep2p1.next_event() => {}, + _ = litep2p2.next_event() => {}, + } + } + }); + + // send request to remote peer and wait until the requet timeout occurs + let request_id = handle1 + .send_request(peer2, vec![1, 3, 3, 7], DialOptions::Reject) + .await + .unwrap(); + + sleep(Duration::from_secs(7)).await; + + assert_eq!( + handle1.next().await.unwrap(), + RequestResponseEvent::RequestFailed { + peer: peer2, + request_id, + error: RequestResponseError::Timeout, + } + ); } #[tokio::test] async fn protocol_not_supported_tcp() { - protocol_not_supported( - Transport::Tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }), - Transport::Tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }), - ) - .await; + protocol_not_supported( + Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }), + Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }), + ) + .await; } #[tokio::test] async fn protocol_not_supported_quic() { - protocol_not_supported( - Transport::Quic(Default::default()), - Transport::Quic(Default::default()), - ) - .await; + protocol_not_supported( + Transport::Quic(Default::default()), + Transport::Quic(Default::default()), + ) + .await; } #[tokio::test] async fn protocol_not_supported_websocket() { - protocol_not_supported( - Transport::WebSocket(WebSocketConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], - ..Default::default() - }), - Transport::WebSocket(WebSocketConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], - ..Default::default() - }), - ) - .await; + protocol_not_supported( + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }), + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }), + ) + .await; } async fn protocol_not_supported(transport1: Transport, transport2: Transport) { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (req_resp_config1, mut handle1) = RequestResponseConfig::new( - ProtocolName::from("/protocol/1"), - Vec::new(), - 1024, - Duration::from_secs(5), - None, - ); - - let config1 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_request_response_protocol(req_resp_config1); - - let config1 = match transport1 { - Transport::Tcp(config) => config1.with_tcp(config), - Transport::Quic(config) => config1.with_quic(config), - Transport::WebSocket(config) => config1.with_websocket(config), - } - .build(); - - let (req_resp_config2, _handle2) = RequestResponseConfig::new( - ProtocolName::from("/protocol/2"), - Vec::new(), - 1024, - Duration::from_secs(5), - None, - ); - let config2 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_request_response_protocol(req_resp_config2); - - let config2 = match transport2 { - Transport::Tcp(config) => config2.with_tcp(config), - Transport::Quic(config) => config2.with_quic(config), - Transport::WebSocket(config) => config2.with_websocket(config), - } - .build(); - - let mut litep2p1 = Litep2p::new(config1).unwrap(); - let mut litep2p2 = Litep2p::new(config2).unwrap(); - - let _peer1 = *litep2p1.local_peer_id(); - let peer2 = *litep2p2.local_peer_id(); - - // wait until peers have connected - connect_peers(&mut litep2p1, &mut litep2p2).await; - tokio::spawn(async move { - loop { - tokio::select! { - _ = litep2p1.next_event() => {}, - _ = litep2p2.next_event() => {}, - } - } - }); - - // send request to remote peer and wait until the requet timeout occurs - let request_id = handle1 - .send_request(peer2, vec![1, 3, 3, 7], DialOptions::Reject) - .await - .unwrap(); - - assert_eq!( - handle1.next().await.unwrap(), - RequestResponseEvent::RequestFailed { - peer: peer2, - request_id, - error: RequestResponseError::UnsupportedProtocol, - } - ); + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (req_resp_config1, mut handle1) = RequestResponseConfig::new( + ProtocolName::from("/protocol/1"), + Vec::new(), + 1024, + Duration::from_secs(5), + None, + ); + + let config1 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_request_response_protocol(req_resp_config1); + + let config1 = match transport1 { + Transport::Tcp(config) => config1.with_tcp(config), + Transport::Quic(config) => config1.with_quic(config), + Transport::WebSocket(config) => config1.with_websocket(config), + } + .build(); + + let (req_resp_config2, _handle2) = RequestResponseConfig::new( + ProtocolName::from("/protocol/2"), + Vec::new(), + 1024, + Duration::from_secs(5), + None, + ); + let config2 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_request_response_protocol(req_resp_config2); + + let config2 = match transport2 { + Transport::Tcp(config) => config2.with_tcp(config), + Transport::Quic(config) => config2.with_quic(config), + Transport::WebSocket(config) => config2.with_websocket(config), + } + .build(); + + let mut litep2p1 = Litep2p::new(config1).unwrap(); + let mut litep2p2 = Litep2p::new(config2).unwrap(); + + let _peer1 = *litep2p1.local_peer_id(); + let peer2 = *litep2p2.local_peer_id(); + + // wait until peers have connected + connect_peers(&mut litep2p1, &mut litep2p2).await; + tokio::spawn(async move { + loop { + tokio::select! { + _ = litep2p1.next_event() => {}, + _ = litep2p2.next_event() => {}, + } + } + }); + + // send request to remote peer and wait until the requet timeout occurs + let request_id = handle1 + .send_request(peer2, vec![1, 3, 3, 7], DialOptions::Reject) + .await + .unwrap(); + + assert_eq!( + handle1.next().await.unwrap(), + RequestResponseEvent::RequestFailed { + peer: peer2, + request_id, + error: RequestResponseError::UnsupportedProtocol, + } + ); } #[tokio::test] async fn connection_close_while_request_is_pending_tcp() { - connection_close_while_request_is_pending( - Transport::Tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }), - Transport::Tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }), - ) - .await; + connection_close_while_request_is_pending( + Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }), + Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }), + ) + .await; } #[tokio::test] async fn connection_close_while_request_is_pending_quic() { - connection_close_while_request_is_pending( - Transport::Quic(Default::default()), - Transport::Quic(Default::default()), - ) - .await; + connection_close_while_request_is_pending( + Transport::Quic(Default::default()), + Transport::Quic(Default::default()), + ) + .await; } #[tokio::test] async fn connection_close_while_request_is_pending_websocket() { - connection_close_while_request_is_pending( - Transport::WebSocket(WebSocketConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], - ..Default::default() - }), - Transport::WebSocket(WebSocketConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], - ..Default::default() - }), - ) - .await; + connection_close_while_request_is_pending( + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }), + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }), + ) + .await; } async fn connection_close_while_request_is_pending(transport1: Transport, transport2: Transport) { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (req_resp_config1, mut handle1) = RequestResponseConfig::new( - ProtocolName::from("/protocol/1"), - Vec::new(), - 1024, - Duration::from_secs(5), - None, - ); - let config1 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_request_response_protocol(req_resp_config1); - - let config1 = match transport1 { - Transport::Tcp(config) => config1.with_tcp(config), - Transport::Quic(config) => config1.with_quic(config), - Transport::WebSocket(config) => config1.with_websocket(config), - } - .build(); - - let (req_resp_config2, handle2) = RequestResponseConfig::new( - ProtocolName::from("/protocol/1"), - Vec::new(), - 1024, - Duration::from_secs(5), - None, - ); - let config2 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_request_response_protocol(req_resp_config2); - - let config2 = match transport2 { - Transport::Tcp(config) => config2.with_tcp(config), - Transport::Quic(config) => config2.with_quic(config), - Transport::WebSocket(config) => config2.with_websocket(config), - } - .build(); - - let mut litep2p1 = Litep2p::new(config1).unwrap(); - let mut litep2p2 = Litep2p::new(config2).unwrap(); - - let _peer1 = *litep2p1.local_peer_id(); - let peer2 = *litep2p2.local_peer_id(); - - // wait until peers have connected - connect_peers(&mut litep2p1, &mut litep2p2).await; - tokio::spawn(async move { - loop { - let _ = litep2p1.next_event().await; - } - }); - - // send request to remote peer and wait until the requet timeout occurs - let request_id = handle1 - .send_request(peer2, vec![1, 3, 3, 7], DialOptions::Reject) - .await - .unwrap(); - - drop(handle2); - drop(litep2p2); - - assert_eq!( - handle1.next().await.unwrap(), - RequestResponseEvent::RequestFailed { - peer: peer2, - request_id, - error: RequestResponseError::Rejected, - } - ); + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (req_resp_config1, mut handle1) = RequestResponseConfig::new( + ProtocolName::from("/protocol/1"), + Vec::new(), + 1024, + Duration::from_secs(5), + None, + ); + let config1 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_request_response_protocol(req_resp_config1); + + let config1 = match transport1 { + Transport::Tcp(config) => config1.with_tcp(config), + Transport::Quic(config) => config1.with_quic(config), + Transport::WebSocket(config) => config1.with_websocket(config), + } + .build(); + + let (req_resp_config2, handle2) = RequestResponseConfig::new( + ProtocolName::from("/protocol/1"), + Vec::new(), + 1024, + Duration::from_secs(5), + None, + ); + let config2 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_request_response_protocol(req_resp_config2); + + let config2 = match transport2 { + Transport::Tcp(config) => config2.with_tcp(config), + Transport::Quic(config) => config2.with_quic(config), + Transport::WebSocket(config) => config2.with_websocket(config), + } + .build(); + + let mut litep2p1 = Litep2p::new(config1).unwrap(); + let mut litep2p2 = Litep2p::new(config2).unwrap(); + + let _peer1 = *litep2p1.local_peer_id(); + let peer2 = *litep2p2.local_peer_id(); + + // wait until peers have connected + connect_peers(&mut litep2p1, &mut litep2p2).await; + tokio::spawn(async move { + loop { + let _ = litep2p1.next_event().await; + } + }); + + // send request to remote peer and wait until the requet timeout occurs + let request_id = handle1 + .send_request(peer2, vec![1, 3, 3, 7], DialOptions::Reject) + .await + .unwrap(); + + drop(handle2); + drop(litep2p2); + + assert_eq!( + handle1.next().await.unwrap(), + RequestResponseEvent::RequestFailed { + peer: peer2, + request_id, + error: RequestResponseError::Rejected, + } + ); } #[tokio::test] async fn request_too_big_tcp() { - request_too_big( - Transport::Tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }), - Transport::Tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }), - ) - .await; + request_too_big( + Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }), + Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }), + ) + .await; } #[tokio::test] async fn request_too_big_quic() { - request_too_big(Transport::Quic(Default::default()), Transport::Quic(Default::default())).await; + request_too_big( + Transport::Quic(Default::default()), + Transport::Quic(Default::default()), + ) + .await; } #[tokio::test] async fn request_too_big_websocket() { - request_too_big( - Transport::WebSocket(WebSocketConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], - ..Default::default() - }), - Transport::WebSocket(WebSocketConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], - ..Default::default() - }), - ) - .await; + request_too_big( + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }), + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }), + ) + .await; } async fn request_too_big(transport1: Transport, transport2: Transport) { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (req_resp_config1, mut handle1) = RequestResponseConfig::new( - ProtocolName::from("/protocol/1"), - Vec::new(), - 256, - Duration::from_secs(5), - None, - ); - let config1 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_request_response_protocol(req_resp_config1); - - let config1 = match transport1 { - Transport::Tcp(config) => config1.with_tcp(config), - Transport::Quic(config) => config1.with_quic(config), - Transport::WebSocket(config) => config1.with_websocket(config), - } - .build(); - - let (req_resp_config2, _handle2) = RequestResponseConfig::new( - ProtocolName::from("/protocol/1"), - Vec::new(), - 1024, - Duration::from_secs(5), - None, - ); - let config2 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_request_response_protocol(req_resp_config2); - - let config2 = match transport2 { - Transport::Tcp(config) => config2.with_tcp(config), - Transport::Quic(config) => config2.with_quic(config), - Transport::WebSocket(config) => config2.with_websocket(config), - } - .build(); - - let mut litep2p1 = Litep2p::new(config1).unwrap(); - let mut litep2p2 = Litep2p::new(config2).unwrap(); - - let peer2 = *litep2p2.local_peer_id(); - - // wait until peers have connected - connect_peers(&mut litep2p1, &mut litep2p2).await; - tokio::spawn(async move { - loop { - tokio::select! { - _ = litep2p1.next_event() => {}, - _ = litep2p2.next_event() => {}, - } - } - }); - - // try to send too large request to remote peer - let request_id = - handle1.send_request(peer2, vec![0u8; 257], DialOptions::Reject).await.unwrap(); - assert_eq!( - handle1.next().await.unwrap(), - RequestResponseEvent::RequestFailed { - peer: peer2, - request_id, - error: RequestResponseError::TooLargePayload, - } - ); + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (req_resp_config1, mut handle1) = RequestResponseConfig::new( + ProtocolName::from("/protocol/1"), + Vec::new(), + 256, + Duration::from_secs(5), + None, + ); + let config1 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_request_response_protocol(req_resp_config1); + + let config1 = match transport1 { + Transport::Tcp(config) => config1.with_tcp(config), + Transport::Quic(config) => config1.with_quic(config), + Transport::WebSocket(config) => config1.with_websocket(config), + } + .build(); + + let (req_resp_config2, _handle2) = RequestResponseConfig::new( + ProtocolName::from("/protocol/1"), + Vec::new(), + 1024, + Duration::from_secs(5), + None, + ); + let config2 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_request_response_protocol(req_resp_config2); + + let config2 = match transport2 { + Transport::Tcp(config) => config2.with_tcp(config), + Transport::Quic(config) => config2.with_quic(config), + Transport::WebSocket(config) => config2.with_websocket(config), + } + .build(); + + let mut litep2p1 = Litep2p::new(config1).unwrap(); + let mut litep2p2 = Litep2p::new(config2).unwrap(); + + let peer2 = *litep2p2.local_peer_id(); + + // wait until peers have connected + connect_peers(&mut litep2p1, &mut litep2p2).await; + tokio::spawn(async move { + loop { + tokio::select! { + _ = litep2p1.next_event() => {}, + _ = litep2p2.next_event() => {}, + } + } + }); + + // try to send too large request to remote peer + let request_id = + handle1.send_request(peer2, vec![0u8; 257], DialOptions::Reject).await.unwrap(); + assert_eq!( + handle1.next().await.unwrap(), + RequestResponseEvent::RequestFailed { + peer: peer2, + request_id, + error: RequestResponseError::TooLargePayload, + } + ); } #[tokio::test] async fn response_too_big_tcp() { - response_too_big( - Transport::Tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }), - Transport::Tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }), - ) - .await; + response_too_big( + Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }), + Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }), + ) + .await; } #[tokio::test] async fn response_too_big_quic() { - response_too_big(Transport::Quic(Default::default()), Transport::Quic(Default::default())) - .await; + response_too_big( + Transport::Quic(Default::default()), + Transport::Quic(Default::default()), + ) + .await; } #[tokio::test] async fn response_too_big_websocket() { - response_too_big( - Transport::WebSocket(WebSocketConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], - ..Default::default() - }), - Transport::WebSocket(WebSocketConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], - ..Default::default() - }), - ) - .await; + response_too_big( + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }), + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }), + ) + .await; } async fn response_too_big(transport1: Transport, transport2: Transport) { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (req_resp_config1, mut handle1) = RequestResponseConfig::new( - ProtocolName::from("/protocol/1"), - Vec::new(), - 256, - Duration::from_secs(5), - None, - ); - let config1 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_request_response_protocol(req_resp_config1); - - let config1 = match transport1 { - Transport::Tcp(config) => config1.with_tcp(config), - Transport::Quic(config) => config1.with_quic(config), - Transport::WebSocket(config) => config1.with_websocket(config), - } - .build(); - - let (req_resp_config2, mut handle2) = RequestResponseConfig::new( - ProtocolName::from("/protocol/1"), - Vec::new(), - 256, - Duration::from_secs(5), - None, - ); - let config2 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_request_response_protocol(req_resp_config2); - - let config2 = match transport2 { - Transport::Tcp(config) => config2.with_tcp(config), - Transport::Quic(config) => config2.with_quic(config), - Transport::WebSocket(config) => config2.with_websocket(config), - } - .build(); - - let mut litep2p1 = Litep2p::new(config1).unwrap(); - let mut litep2p2 = Litep2p::new(config2).unwrap(); - - let peer1 = *litep2p1.local_peer_id(); - let peer2 = *litep2p2.local_peer_id(); - - // wait until peers have connected - connect_peers(&mut litep2p1, &mut litep2p2).await; - tokio::spawn(async move { - loop { - tokio::select! { - _ = litep2p1.next_event() => {}, - _ = litep2p2.next_event() => {}, - } - } - }); - - // send request to remote peer - let request_id = - handle1.send_request(peer2, vec![0u8; 256], DialOptions::Reject).await.unwrap(); - assert_eq!( - handle2.next().await.unwrap(), - RequestResponseEvent::RequestReceived { - peer: peer1, - fallback: None, - request_id, - request: vec![0u8; 256], - } - ); - - // try to send too large response to the received request - handle2.send_response(request_id, vec![0u8; 257]); - - assert_eq!( - handle1.next().await.unwrap(), - RequestResponseEvent::RequestFailed { - peer: peer2, - request_id, - error: RequestResponseError::Rejected, - } - ); + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (req_resp_config1, mut handle1) = RequestResponseConfig::new( + ProtocolName::from("/protocol/1"), + Vec::new(), + 256, + Duration::from_secs(5), + None, + ); + let config1 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_request_response_protocol(req_resp_config1); + + let config1 = match transport1 { + Transport::Tcp(config) => config1.with_tcp(config), + Transport::Quic(config) => config1.with_quic(config), + Transport::WebSocket(config) => config1.with_websocket(config), + } + .build(); + + let (req_resp_config2, mut handle2) = RequestResponseConfig::new( + ProtocolName::from("/protocol/1"), + Vec::new(), + 256, + Duration::from_secs(5), + None, + ); + let config2 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_request_response_protocol(req_resp_config2); + + let config2 = match transport2 { + Transport::Tcp(config) => config2.with_tcp(config), + Transport::Quic(config) => config2.with_quic(config), + Transport::WebSocket(config) => config2.with_websocket(config), + } + .build(); + + let mut litep2p1 = Litep2p::new(config1).unwrap(); + let mut litep2p2 = Litep2p::new(config2).unwrap(); + + let peer1 = *litep2p1.local_peer_id(); + let peer2 = *litep2p2.local_peer_id(); + + // wait until peers have connected + connect_peers(&mut litep2p1, &mut litep2p2).await; + tokio::spawn(async move { + loop { + tokio::select! { + _ = litep2p1.next_event() => {}, + _ = litep2p2.next_event() => {}, + } + } + }); + + // send request to remote peer + let request_id = + handle1.send_request(peer2, vec![0u8; 256], DialOptions::Reject).await.unwrap(); + assert_eq!( + handle2.next().await.unwrap(), + RequestResponseEvent::RequestReceived { + peer: peer1, + fallback: None, + request_id, + request: vec![0u8; 256], + } + ); + + // try to send too large response to the received request + handle2.send_response(request_id, vec![0u8; 257]); + + assert_eq!( + handle1.next().await.unwrap(), + RequestResponseEvent::RequestFailed { + peer: peer2, + request_id, + error: RequestResponseError::Rejected, + } + ); } #[tokio::test] async fn too_many_pending_requests() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (req_resp_config1, mut handle1) = RequestResponseConfig::new( - ProtocolName::from("/protocol/1"), - Vec::new(), - 1024, - Duration::from_secs(5), - None, - ); - let mut yamux_config = litep2p::yamux::Config::default(); - yamux_config.set_max_num_streams(4); - - let config1 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }) - .with_request_response_protocol(req_resp_config1) - .build(); - - let (req_resp_config2, _handle2) = RequestResponseConfig::new( - ProtocolName::from("/protocol/1"), - Vec::new(), - 1024, - Duration::from_secs(5), - None, - ); - let mut yamux_config = litep2p::yamux::Config::default(); - yamux_config.set_max_num_streams(4); - - let config2 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }) - .with_request_response_protocol(req_resp_config2) - .build(); - - let mut litep2p1 = Litep2p::new(config1).unwrap(); - let mut litep2p2 = Litep2p::new(config2).unwrap(); - let peer2 = *litep2p2.local_peer_id(); - - // wait until peers have connected - connect_peers(&mut litep2p1, &mut litep2p2).await; - - // send one over the max requests to remote peer - let mut request_ids = HashSet::new(); - - request_ids.insert( - handle1 - .send_request(peer2, vec![1, 3, 3, 6], DialOptions::Reject) - .await - .unwrap(), - ); - request_ids.insert( - handle1 - .send_request(peer2, vec![1, 3, 3, 7], DialOptions::Reject) - .await - .unwrap(), - ); - request_ids.insert( - handle1 - .send_request(peer2, vec![1, 3, 3, 8], DialOptions::Reject) - .await - .unwrap(), - ); - request_ids.insert( - handle1 - .send_request(peer2, vec![1, 3, 3, 9], DialOptions::Reject) - .await - .unwrap(), - ); - request_ids.insert( - handle1 - .send_request(peer2, vec![1, 3, 3, 9], DialOptions::Reject) - .await - .unwrap(), - ); - - let mut litep2p1_closed = false; - let mut litep2p2_closed = false; - - while !litep2p1_closed || !litep2p2_closed || !request_ids.is_empty() { - tokio::select! { - event = litep2p1.next_event() => match event { - Some(Litep2pEvent::ConnectionClosed { .. }) => { - litep2p1_closed = true; - } - _ => {} - }, - event = litep2p2.next_event() => match event { - Some(Litep2pEvent::ConnectionClosed { .. }) => { - litep2p2_closed = true; - } - _ => {} - }, - event = handle1.next() => match event { - Some(RequestResponseEvent::RequestFailed { - request_id, - .. - }) => { - request_ids.remove(&request_id); - } - _ => {} - } - } - } + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (req_resp_config1, mut handle1) = RequestResponseConfig::new( + ProtocolName::from("/protocol/1"), + Vec::new(), + 1024, + Duration::from_secs(5), + None, + ); + let mut yamux_config = litep2p::yamux::Config::default(); + yamux_config.set_max_num_streams(4); + + let config1 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }) + .with_request_response_protocol(req_resp_config1) + .build(); + + let (req_resp_config2, _handle2) = RequestResponseConfig::new( + ProtocolName::from("/protocol/1"), + Vec::new(), + 1024, + Duration::from_secs(5), + None, + ); + let mut yamux_config = litep2p::yamux::Config::default(); + yamux_config.set_max_num_streams(4); + + let config2 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }) + .with_request_response_protocol(req_resp_config2) + .build(); + + let mut litep2p1 = Litep2p::new(config1).unwrap(); + let mut litep2p2 = Litep2p::new(config2).unwrap(); + let peer2 = *litep2p2.local_peer_id(); + + // wait until peers have connected + connect_peers(&mut litep2p1, &mut litep2p2).await; + + // send one over the max requests to remote peer + let mut request_ids = HashSet::new(); + + request_ids.insert( + handle1 + .send_request(peer2, vec![1, 3, 3, 6], DialOptions::Reject) + .await + .unwrap(), + ); + request_ids.insert( + handle1 + .send_request(peer2, vec![1, 3, 3, 7], DialOptions::Reject) + .await + .unwrap(), + ); + request_ids.insert( + handle1 + .send_request(peer2, vec![1, 3, 3, 8], DialOptions::Reject) + .await + .unwrap(), + ); + request_ids.insert( + handle1 + .send_request(peer2, vec![1, 3, 3, 9], DialOptions::Reject) + .await + .unwrap(), + ); + request_ids.insert( + handle1 + .send_request(peer2, vec![1, 3, 3, 9], DialOptions::Reject) + .await + .unwrap(), + ); + + let mut litep2p1_closed = false; + let mut litep2p2_closed = false; + + while !litep2p1_closed || !litep2p2_closed || !request_ids.is_empty() { + tokio::select! { + event = litep2p1.next_event() => match event { + Some(Litep2pEvent::ConnectionClosed { .. }) => { + litep2p1_closed = true; + } + _ => {} + }, + event = litep2p2.next_event() => match event { + Some(Litep2pEvent::ConnectionClosed { .. }) => { + litep2p2_closed = true; + } + _ => {} + }, + event = handle1.next() => match event { + Some(RequestResponseEvent::RequestFailed { + request_id, + .. + }) => { + request_ids.remove(&request_id); + } + _ => {} + } + } + } } #[tokio::test] async fn dialer_fallback_protocol_works_tcp() { - dialer_fallback_protocol_works( - Transport::Tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }), - Transport::Tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }), - ) - .await; + dialer_fallback_protocol_works( + Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }), + Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }), + ) + .await; } #[tokio::test] async fn dialer_fallback_protocol_works_quic() { - dialer_fallback_protocol_works( - Transport::Quic(Default::default()), - Transport::Quic(Default::default()), - ) - .await; + dialer_fallback_protocol_works( + Transport::Quic(Default::default()), + Transport::Quic(Default::default()), + ) + .await; } #[tokio::test] async fn dialer_fallback_protocol_works_websocket() { - dialer_fallback_protocol_works( - Transport::WebSocket(WebSocketConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], - ..Default::default() - }), - Transport::WebSocket(WebSocketConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], - ..Default::default() - }), - ) - .await; + dialer_fallback_protocol_works( + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }), + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }), + ) + .await; } async fn dialer_fallback_protocol_works(transport1: Transport, transport2: Transport) { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (req_resp_config1, mut handle1) = - ConfigBuilder::new(ProtocolName::from("/protocol/1/improved")) - .with_max_size(1024usize) - .with_fallback_names(vec![ProtocolName::from("/protocol/1")]) - .build(); - - let config1 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_request_response_protocol(req_resp_config1); - - let config1 = match transport1 { - Transport::Tcp(config) => config1.with_tcp(config), - Transport::Quic(config) => config1.with_quic(config), - Transport::WebSocket(config) => config1.with_websocket(config), - } - .build(); - - let (req_resp_config2, mut handle2) = RequestResponseConfig::new( - ProtocolName::from("/protocol/1"), - Vec::new(), - 1024, - Duration::from_secs(5), - None, - ); - let config2 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_request_response_protocol(req_resp_config2); - - let config2 = match transport2 { - Transport::Tcp(config) => config2.with_tcp(config), - Transport::Quic(config) => config2.with_quic(config), - Transport::WebSocket(config) => config2.with_websocket(config), - } - .build(); - - let mut litep2p1 = Litep2p::new(config1).unwrap(); - let mut litep2p2 = Litep2p::new(config2).unwrap(); - - let peer1 = *litep2p1.local_peer_id(); - let peer2 = *litep2p2.local_peer_id(); - - // wait until peers have connected - connect_peers(&mut litep2p1, &mut litep2p2).await; - tokio::spawn(async move { - loop { - tokio::select! { - _ = litep2p1.next_event() => {}, - _ = litep2p2.next_event() => {}, - } - } - }); - - // send request to remote peer - let request_id = handle1 - .send_request(peer2, vec![1, 3, 3, 7], DialOptions::Reject) - .await - .unwrap(); - assert_eq!( - handle2.next().await.unwrap(), - RequestResponseEvent::RequestReceived { - peer: peer1, - fallback: None, - request_id, - request: vec![1, 3, 3, 7], - } - ); - - // send response to the received request - handle2.send_response(request_id, vec![1, 3, 3, 8]); - assert_eq!( - handle1.next().await.unwrap(), - RequestResponseEvent::ResponseReceived { - peer: peer2, - request_id, - response: vec![1, 3, 3, 8], - fallback: Some(ProtocolName::from("/protocol/1")), - } - ); + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (req_resp_config1, mut handle1) = + ConfigBuilder::new(ProtocolName::from("/protocol/1/improved")) + .with_max_size(1024usize) + .with_fallback_names(vec![ProtocolName::from("/protocol/1")]) + .build(); + + let config1 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_request_response_protocol(req_resp_config1); + + let config1 = match transport1 { + Transport::Tcp(config) => config1.with_tcp(config), + Transport::Quic(config) => config1.with_quic(config), + Transport::WebSocket(config) => config1.with_websocket(config), + } + .build(); + + let (req_resp_config2, mut handle2) = RequestResponseConfig::new( + ProtocolName::from("/protocol/1"), + Vec::new(), + 1024, + Duration::from_secs(5), + None, + ); + let config2 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_request_response_protocol(req_resp_config2); + + let config2 = match transport2 { + Transport::Tcp(config) => config2.with_tcp(config), + Transport::Quic(config) => config2.with_quic(config), + Transport::WebSocket(config) => config2.with_websocket(config), + } + .build(); + + let mut litep2p1 = Litep2p::new(config1).unwrap(); + let mut litep2p2 = Litep2p::new(config2).unwrap(); + + let peer1 = *litep2p1.local_peer_id(); + let peer2 = *litep2p2.local_peer_id(); + + // wait until peers have connected + connect_peers(&mut litep2p1, &mut litep2p2).await; + tokio::spawn(async move { + loop { + tokio::select! { + _ = litep2p1.next_event() => {}, + _ = litep2p2.next_event() => {}, + } + } + }); + + // send request to remote peer + let request_id = handle1 + .send_request(peer2, vec![1, 3, 3, 7], DialOptions::Reject) + .await + .unwrap(); + assert_eq!( + handle2.next().await.unwrap(), + RequestResponseEvent::RequestReceived { + peer: peer1, + fallback: None, + request_id, + request: vec![1, 3, 3, 7], + } + ); + + // send response to the received request + handle2.send_response(request_id, vec![1, 3, 3, 8]); + assert_eq!( + handle1.next().await.unwrap(), + RequestResponseEvent::ResponseReceived { + peer: peer2, + request_id, + response: vec![1, 3, 3, 8], + fallback: Some(ProtocolName::from("/protocol/1")), + } + ); } #[tokio::test] async fn listener_fallback_protocol_works_tcp() { - listener_fallback_protocol_works( - Transport::Tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }), - Transport::Tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }), - ) - .await; + listener_fallback_protocol_works( + Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }), + Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }), + ) + .await; } #[tokio::test] async fn listener_fallback_protocol_works_quic() { - listener_fallback_protocol_works( - Transport::Quic(Default::default()), - Transport::Quic(Default::default()), - ) - .await; + listener_fallback_protocol_works( + Transport::Quic(Default::default()), + Transport::Quic(Default::default()), + ) + .await; } #[tokio::test] async fn listener_fallback_protocol_works_websocket() { - listener_fallback_protocol_works( - Transport::WebSocket(WebSocketConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], - ..Default::default() - }), - Transport::WebSocket(WebSocketConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], - ..Default::default() - }), - ) - .await; + listener_fallback_protocol_works( + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }), + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }), + ) + .await; } async fn listener_fallback_protocol_works(transport1: Transport, transport2: Transport) { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (req_resp_config1, mut handle1) = RequestResponseConfig::new( - ProtocolName::from("/protocol/1"), - Vec::new(), - 1024, - Duration::from_secs(5), - None, - ); - let config1 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_request_response_protocol(req_resp_config1); - - let config1 = match transport1 { - Transport::Tcp(config) => config1.with_tcp(config), - Transport::Quic(config) => config1.with_quic(config), - Transport::WebSocket(config) => config1.with_websocket(config), - } - .build(); - - let (req_resp_config2, mut handle2) = RequestResponseConfig::new( - ProtocolName::from("/protocol/1/improved"), - vec![ProtocolName::from("/protocol/1")], - 1024, - Duration::from_secs(5), - None, - ); - let config2 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_request_response_protocol(req_resp_config2); - - let config2 = match transport2 { - Transport::Tcp(config) => config2.with_tcp(config), - Transport::Quic(config) => config2.with_quic(config), - Transport::WebSocket(config) => config2.with_websocket(config), - } - .build(); - - let mut litep2p1 = Litep2p::new(config1).unwrap(); - let mut litep2p2 = Litep2p::new(config2).unwrap(); - - let peer1 = *litep2p1.local_peer_id(); - let peer2 = *litep2p2.local_peer_id(); - - // wait until peers have connected - connect_peers(&mut litep2p1, &mut litep2p2).await; - tokio::spawn(async move { - loop { - tokio::select! { - _ = litep2p1.next_event() => {}, - _ = litep2p2.next_event() => {}, - } - } - }); - - // send request to remote peer - let request_id = handle1 - .send_request(peer2, vec![1, 3, 3, 7], DialOptions::Reject) - .await - .unwrap(); - assert_eq!( - handle2.next().await.unwrap(), - RequestResponseEvent::RequestReceived { - peer: peer1, - fallback: Some(ProtocolName::from("/protocol/1")), - request_id, - request: vec![1, 3, 3, 7], - } - ); - - // send response to the received request - handle2.send_response(request_id, vec![1, 3, 3, 8]); - assert_eq!( - handle1.next().await.unwrap(), - RequestResponseEvent::ResponseReceived { - peer: peer2, - request_id, - response: vec![1, 3, 3, 8], - fallback: None, - } - ); + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (req_resp_config1, mut handle1) = RequestResponseConfig::new( + ProtocolName::from("/protocol/1"), + Vec::new(), + 1024, + Duration::from_secs(5), + None, + ); + let config1 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_request_response_protocol(req_resp_config1); + + let config1 = match transport1 { + Transport::Tcp(config) => config1.with_tcp(config), + Transport::Quic(config) => config1.with_quic(config), + Transport::WebSocket(config) => config1.with_websocket(config), + } + .build(); + + let (req_resp_config2, mut handle2) = RequestResponseConfig::new( + ProtocolName::from("/protocol/1/improved"), + vec![ProtocolName::from("/protocol/1")], + 1024, + Duration::from_secs(5), + None, + ); + let config2 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_request_response_protocol(req_resp_config2); + + let config2 = match transport2 { + Transport::Tcp(config) => config2.with_tcp(config), + Transport::Quic(config) => config2.with_quic(config), + Transport::WebSocket(config) => config2.with_websocket(config), + } + .build(); + + let mut litep2p1 = Litep2p::new(config1).unwrap(); + let mut litep2p2 = Litep2p::new(config2).unwrap(); + + let peer1 = *litep2p1.local_peer_id(); + let peer2 = *litep2p2.local_peer_id(); + + // wait until peers have connected + connect_peers(&mut litep2p1, &mut litep2p2).await; + tokio::spawn(async move { + loop { + tokio::select! { + _ = litep2p1.next_event() => {}, + _ = litep2p2.next_event() => {}, + } + } + }); + + // send request to remote peer + let request_id = handle1 + .send_request(peer2, vec![1, 3, 3, 7], DialOptions::Reject) + .await + .unwrap(); + assert_eq!( + handle2.next().await.unwrap(), + RequestResponseEvent::RequestReceived { + peer: peer1, + fallback: Some(ProtocolName::from("/protocol/1")), + request_id, + request: vec![1, 3, 3, 7], + } + ); + + // send response to the received request + handle2.send_response(request_id, vec![1, 3, 3, 8]); + assert_eq!( + handle1.next().await.unwrap(), + RequestResponseEvent::ResponseReceived { + peer: peer2, + request_id, + response: vec![1, 3, 3, 8], + fallback: None, + } + ); } #[tokio::test] async fn dial_peer_when_sending_request_tcp() { - dial_peer_when_sending_request( - Transport::Tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }), - Transport::Tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }), - ) - .await; + dial_peer_when_sending_request( + Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }), + Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }), + ) + .await; } #[tokio::test] async fn dial_peer_when_sending_request_quic() { - dial_peer_when_sending_request( - Transport::Quic(Default::default()), - Transport::Quic(Default::default()), - ) - .await; + dial_peer_when_sending_request( + Transport::Quic(Default::default()), + Transport::Quic(Default::default()), + ) + .await; } #[tokio::test] async fn dial_peer_when_sending_request_websocket() { - dial_peer_when_sending_request( - Transport::WebSocket(WebSocketConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], - ..Default::default() - }), - Transport::WebSocket(WebSocketConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], - ..Default::default() - }), - ) - .await; + dial_peer_when_sending_request( + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }), + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }), + ) + .await; } async fn dial_peer_when_sending_request(transport1: Transport, transport2: Transport) { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (req_resp_config1, mut handle1) = RequestResponseConfig::new( - ProtocolName::from("/protocol/1"), - Vec::new(), - 1024, - Duration::from_secs(5), - None, - ); - let config1 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_request_response_protocol(req_resp_config1); - - let config1 = match transport1 { - Transport::Tcp(config) => config1.with_tcp(config), - Transport::Quic(config) => config1.with_quic(config), - Transport::WebSocket(config) => config1.with_websocket(config), - } - .build(); - - let (req_resp_config2, mut handle2) = RequestResponseConfig::new( - ProtocolName::from("/protocol/1/improved"), - vec![ProtocolName::from("/protocol/1")], - 1024, - Duration::from_secs(5), - None, - ); - let config2 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_request_response_protocol(req_resp_config2); - - let config2 = match transport2 { - Transport::Tcp(config) => config2.with_tcp(config), - Transport::Quic(config) => config2.with_quic(config), - Transport::WebSocket(config) => config2.with_websocket(config), - } - .build(); - - let mut litep2p1 = Litep2p::new(config1).unwrap(); - let mut litep2p2 = Litep2p::new(config2).unwrap(); - - let peer1 = *litep2p1.local_peer_id(); - let peer2 = *litep2p2.local_peer_id(); - let address = litep2p2.listen_addresses().next().unwrap().clone(); - - // add known address for `peer2` and start event loop for both litep2ps - litep2p1.add_known_address(peer2, std::iter::once(address)); - - tokio::spawn(async move { - loop { - tokio::select! { - _ = litep2p1.next_event() => {} - _ = litep2p2.next_event() => {} - } - } - }); - - // send request to remote peer - let request_id = - handle1.send_request(peer2, vec![1, 3, 3, 7], DialOptions::Dial).await.unwrap(); - assert_eq!( - handle2.next().await.unwrap(), - RequestResponseEvent::RequestReceived { - peer: peer1, - fallback: Some(ProtocolName::from("/protocol/1")), - request_id, - request: vec![1, 3, 3, 7], - } - ); - - // send response to the received request - handle2.send_response(request_id, vec![1, 3, 3, 8]); - assert_eq!( - handle1.next().await.unwrap(), - RequestResponseEvent::ResponseReceived { - peer: peer2, - request_id, - response: vec![1, 3, 3, 8], - fallback: None, - } - ); + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (req_resp_config1, mut handle1) = RequestResponseConfig::new( + ProtocolName::from("/protocol/1"), + Vec::new(), + 1024, + Duration::from_secs(5), + None, + ); + let config1 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_request_response_protocol(req_resp_config1); + + let config1 = match transport1 { + Transport::Tcp(config) => config1.with_tcp(config), + Transport::Quic(config) => config1.with_quic(config), + Transport::WebSocket(config) => config1.with_websocket(config), + } + .build(); + + let (req_resp_config2, mut handle2) = RequestResponseConfig::new( + ProtocolName::from("/protocol/1/improved"), + vec![ProtocolName::from("/protocol/1")], + 1024, + Duration::from_secs(5), + None, + ); + let config2 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_request_response_protocol(req_resp_config2); + + let config2 = match transport2 { + Transport::Tcp(config) => config2.with_tcp(config), + Transport::Quic(config) => config2.with_quic(config), + Transport::WebSocket(config) => config2.with_websocket(config), + } + .build(); + + let mut litep2p1 = Litep2p::new(config1).unwrap(); + let mut litep2p2 = Litep2p::new(config2).unwrap(); + + let peer1 = *litep2p1.local_peer_id(); + let peer2 = *litep2p2.local_peer_id(); + let address = litep2p2.listen_addresses().next().unwrap().clone(); + + // add known address for `peer2` and start event loop for both litep2ps + litep2p1.add_known_address(peer2, std::iter::once(address)); + + tokio::spawn(async move { + loop { + tokio::select! { + _ = litep2p1.next_event() => {} + _ = litep2p2.next_event() => {} + } + } + }); + + // send request to remote peer + let request_id = + handle1.send_request(peer2, vec![1, 3, 3, 7], DialOptions::Dial).await.unwrap(); + assert_eq!( + handle2.next().await.unwrap(), + RequestResponseEvent::RequestReceived { + peer: peer1, + fallback: Some(ProtocolName::from("/protocol/1")), + request_id, + request: vec![1, 3, 3, 7], + } + ); + + // send response to the received request + handle2.send_response(request_id, vec![1, 3, 3, 8]); + assert_eq!( + handle1.next().await.unwrap(), + RequestResponseEvent::ResponseReceived { + peer: peer2, + request_id, + response: vec![1, 3, 3, 8], + fallback: None, + } + ); } #[tokio::test] async fn dial_peer_but_no_known_address_tcp() { - dial_peer_but_no_known_address( - Transport::Tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }), - Transport::Tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }), - ) - .await; + dial_peer_but_no_known_address( + Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }), + Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }), + ) + .await; } #[tokio::test] async fn dial_peer_but_no_known_address_quic() { - dial_peer_but_no_known_address( - Transport::Quic(Default::default()), - Transport::Quic(Default::default()), - ) - .await; + dial_peer_but_no_known_address( + Transport::Quic(Default::default()), + Transport::Quic(Default::default()), + ) + .await; } #[tokio::test] async fn dial_peer_but_no_known_address_websocket() { - dial_peer_but_no_known_address( - Transport::WebSocket(WebSocketConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], - ..Default::default() - }), - Transport::WebSocket(WebSocketConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], - ..Default::default() - }), - ) - .await; + dial_peer_but_no_known_address( + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }), + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }), + ) + .await; } async fn dial_peer_but_no_known_address(transport1: Transport, transport2: Transport) { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (req_resp_config1, mut handle1) = RequestResponseConfig::new( - ProtocolName::from("/protocol/1"), - Vec::new(), - 1024, - Duration::from_secs(5), - None, - ); - let config1 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_request_response_protocol(req_resp_config1); - - let config1 = match transport1 { - Transport::Tcp(config) => config1.with_tcp(config), - Transport::Quic(config) => config1.with_quic(config), - Transport::WebSocket(config) => config1.with_websocket(config), - } - .build(); - - let (req_resp_config2, _handle2) = RequestResponseConfig::new( - ProtocolName::from("/protocol/1/improved"), - vec![ProtocolName::from("/protocol/1")], - 1024, - Duration::from_secs(5), - None, - ); - let config2 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_request_response_protocol(req_resp_config2); - - let config2 = match transport2 { - Transport::Tcp(config) => config2.with_tcp(config), - Transport::Quic(config) => config2.with_quic(config), - Transport::WebSocket(config) => config2.with_websocket(config), - } - .build(); - - let mut litep2p1 = Litep2p::new(config1).unwrap(); - let mut litep2p2 = Litep2p::new(config2).unwrap(); - - let peer2 = *litep2p2.local_peer_id(); - - tokio::spawn(async move { - loop { - tokio::select! { - _ = litep2p1.next_event() => {} - _ = litep2p2.next_event() => {} - } - } - }); - - // send request to remote peer - let request_id = - handle1.send_request(peer2, vec![1, 3, 3, 7], DialOptions::Dial).await.unwrap(); - assert_eq!( - handle1.next().await.unwrap(), - RequestResponseEvent::RequestFailed { - peer: peer2, - request_id, - error: RequestResponseError::Rejected, - } - ); + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (req_resp_config1, mut handle1) = RequestResponseConfig::new( + ProtocolName::from("/protocol/1"), + Vec::new(), + 1024, + Duration::from_secs(5), + None, + ); + let config1 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_request_response_protocol(req_resp_config1); + + let config1 = match transport1 { + Transport::Tcp(config) => config1.with_tcp(config), + Transport::Quic(config) => config1.with_quic(config), + Transport::WebSocket(config) => config1.with_websocket(config), + } + .build(); + + let (req_resp_config2, _handle2) = RequestResponseConfig::new( + ProtocolName::from("/protocol/1/improved"), + vec![ProtocolName::from("/protocol/1")], + 1024, + Duration::from_secs(5), + None, + ); + let config2 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_request_response_protocol(req_resp_config2); + + let config2 = match transport2 { + Transport::Tcp(config) => config2.with_tcp(config), + Transport::Quic(config) => config2.with_quic(config), + Transport::WebSocket(config) => config2.with_websocket(config), + } + .build(); + + let mut litep2p1 = Litep2p::new(config1).unwrap(); + let mut litep2p2 = Litep2p::new(config2).unwrap(); + + let peer2 = *litep2p2.local_peer_id(); + + tokio::spawn(async move { + loop { + tokio::select! { + _ = litep2p1.next_event() => {} + _ = litep2p2.next_event() => {} + } + } + }); + + // send request to remote peer + let request_id = + handle1.send_request(peer2, vec![1, 3, 3, 7], DialOptions::Dial).await.unwrap(); + assert_eq!( + handle1.next().await.unwrap(), + RequestResponseEvent::RequestFailed { + peer: peer2, + request_id, + error: RequestResponseError::Rejected, + } + ); } #[tokio::test] async fn cancel_request_tcp() { - cancel_request( - Transport::Tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }), - Transport::Tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }), - ) - .await; + cancel_request( + Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }), + Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }), + ) + .await; } #[tokio::test] async fn cancel_request_quic() { - cancel_request(Transport::Quic(Default::default()), Transport::Quic(Default::default())).await; + cancel_request( + Transport::Quic(Default::default()), + Transport::Quic(Default::default()), + ) + .await; } #[tokio::test] async fn cancel_request_websocket() { - cancel_request( - Transport::WebSocket(WebSocketConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], - ..Default::default() - }), - Transport::WebSocket(WebSocketConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], - ..Default::default() - }), - ) - .await; + cancel_request( + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }), + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }), + ) + .await; } async fn cancel_request(transport1: Transport, transport2: Transport) { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (req_resp_config1, mut handle1) = RequestResponseConfig::new( - ProtocolName::from("/protocol/1"), - Vec::new(), - 1024, - Duration::from_secs(5), - None, - ); - let config1 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_request_response_protocol(req_resp_config1); - - let config1 = match transport1 { - Transport::Tcp(config) => config1.with_tcp(config), - Transport::Quic(config) => config1.with_quic(config), - Transport::WebSocket(config) => config1.with_websocket(config), - } - .build(); - - let (req_resp_config2, mut handle2) = RequestResponseConfig::new( - ProtocolName::from("/protocol/1"), - Vec::new(), - 1024, - Duration::from_secs(5), - None, - ); - let config2 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_request_response_protocol(req_resp_config2); - - let config2 = match transport2 { - Transport::Tcp(config) => config2.with_tcp(config), - Transport::Quic(config) => config2.with_quic(config), - Transport::WebSocket(config) => config2.with_websocket(config), - } - .build(); - - let mut litep2p1 = Litep2p::new(config1).unwrap(); - let mut litep2p2 = Litep2p::new(config2).unwrap(); - - let peer1 = *litep2p1.local_peer_id(); - let peer2 = *litep2p2.local_peer_id(); - - // wait until peers have connected - connect_peers(&mut litep2p1, &mut litep2p2).await; - tokio::spawn(async move { - loop { - tokio::select! { - _ = litep2p1.next_event() => {}, - _ = litep2p2.next_event() => {}, - } - } - }); - - // send request to remote peer - let request_id = handle1 - .send_request(peer2, vec![1, 3, 3, 7], DialOptions::Reject) - .await - .unwrap(); - assert_eq!( - handle2.next().await.unwrap(), - RequestResponseEvent::RequestReceived { - peer: peer1, - fallback: None, - request_id, - request: vec![1, 3, 3, 7], - } - ); - - // cancel request - handle1.cancel_request(request_id).await; - - // try to send response to the canceled request - handle2.send_response(request_id, vec![1, 3, 3, 8]); - - // verify that nothing is receieved since the request was canceled - match tokio::time::timeout(Duration::from_secs(2), handle1.next()).await { - Err(_) => {}, - Ok(event) => panic!("invalid event received: {event:?}"), - } + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (req_resp_config1, mut handle1) = RequestResponseConfig::new( + ProtocolName::from("/protocol/1"), + Vec::new(), + 1024, + Duration::from_secs(5), + None, + ); + let config1 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_request_response_protocol(req_resp_config1); + + let config1 = match transport1 { + Transport::Tcp(config) => config1.with_tcp(config), + Transport::Quic(config) => config1.with_quic(config), + Transport::WebSocket(config) => config1.with_websocket(config), + } + .build(); + + let (req_resp_config2, mut handle2) = RequestResponseConfig::new( + ProtocolName::from("/protocol/1"), + Vec::new(), + 1024, + Duration::from_secs(5), + None, + ); + let config2 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_request_response_protocol(req_resp_config2); + + let config2 = match transport2 { + Transport::Tcp(config) => config2.with_tcp(config), + Transport::Quic(config) => config2.with_quic(config), + Transport::WebSocket(config) => config2.with_websocket(config), + } + .build(); + + let mut litep2p1 = Litep2p::new(config1).unwrap(); + let mut litep2p2 = Litep2p::new(config2).unwrap(); + + let peer1 = *litep2p1.local_peer_id(); + let peer2 = *litep2p2.local_peer_id(); + + // wait until peers have connected + connect_peers(&mut litep2p1, &mut litep2p2).await; + tokio::spawn(async move { + loop { + tokio::select! { + _ = litep2p1.next_event() => {}, + _ = litep2p2.next_event() => {}, + } + } + }); + + // send request to remote peer + let request_id = handle1 + .send_request(peer2, vec![1, 3, 3, 7], DialOptions::Reject) + .await + .unwrap(); + assert_eq!( + handle2.next().await.unwrap(), + RequestResponseEvent::RequestReceived { + peer: peer1, + fallback: None, + request_id, + request: vec![1, 3, 3, 7], + } + ); + + // cancel request + handle1.cancel_request(request_id).await; + + // try to send response to the canceled request + handle2.send_response(request_id, vec![1, 3, 3, 8]); + + // verify that nothing is receieved since the request was canceled + match tokio::time::timeout(Duration::from_secs(2), handle1.next()).await { + Err(_) => {} + Ok(event) => panic!("invalid event received: {event:?}"), + } } #[tokio::test] async fn substream_open_failure_reported_once_tcp() { - substream_open_failure_reported_once( - Transport::Tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }), - Transport::Tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }), - ) - .await; + substream_open_failure_reported_once( + Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }), + Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }), + ) + .await; } #[tokio::test] async fn substream_open_failure_reported_once_quic() { - substream_open_failure_reported_once( - Transport::Quic(Default::default()), - Transport::Quic(Default::default()), - ) - .await; + substream_open_failure_reported_once( + Transport::Quic(Default::default()), + Transport::Quic(Default::default()), + ) + .await; } #[tokio::test] async fn substream_open_failure_reported_once_websocket() { - substream_open_failure_reported_once( - Transport::WebSocket(WebSocketConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], - ..Default::default() - }), - Transport::WebSocket(WebSocketConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], - ..Default::default() - }), - ) - .await; + substream_open_failure_reported_once( + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }), + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }), + ) + .await; } async fn substream_open_failure_reported_once(transport1: Transport, transport2: Transport) { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (req_resp_config1, mut handle1) = RequestResponseConfig::new( - ProtocolName::from("/protocol/1"), - Vec::new(), - 1024, - Duration::from_secs(5), - None, - ); - let config1 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_request_response_protocol(req_resp_config1); - - let config1 = match transport1 { - Transport::Tcp(config) => config1.with_tcp(config), - Transport::Quic(config) => config1.with_quic(config), - Transport::WebSocket(config) => config1.with_websocket(config), - } - .build(); - - let (req_resp_config2, _handle2) = RequestResponseConfig::new( - ProtocolName::from("/protocol/2"), - Vec::new(), - 1024, - Duration::from_secs(5), - None, - ); - let config2 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_request_response_protocol(req_resp_config2); - - let config2 = match transport2 { - Transport::Tcp(config) => config2.with_tcp(config), - Transport::Quic(config) => config2.with_quic(config), - Transport::WebSocket(config) => config2.with_websocket(config), - } - .build(); - - let mut litep2p1 = Litep2p::new(config1).unwrap(); - let mut litep2p2 = Litep2p::new(config2).unwrap(); - let peer2 = *litep2p2.local_peer_id(); - - // wait until peers have connected - connect_peers(&mut litep2p1, &mut litep2p2).await; - tokio::spawn(async move { - loop { - tokio::select! { - _ = litep2p2.next_event() => {}, - } - } - }); - - // send request to remote peer - let request_id = handle1 - .send_request(peer2, vec![1, 3, 3, 7], DialOptions::Reject) - .await - .unwrap(); - - assert_eq!( - handle1.next().await.unwrap(), - RequestResponseEvent::RequestFailed { - peer: peer2, - request_id, - error: RequestResponseError::UnsupportedProtocol, - } - ); - - loop { - match litep2p1.next_event().await { - Some(Litep2pEvent::ConnectionClosed { peer, .. }) => { - assert_eq!(peer, peer2); - break; - }, - event => panic!("invalid event received: {event:?}"), - } - } - - // verify that nothing is received from the handle as the request failure was already reported - if let Ok(event) = tokio::time::timeout(Duration::from_secs(5), handle1.next()).await { - panic!("didn't expect to receive event: {event:?}"); - } + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (req_resp_config1, mut handle1) = RequestResponseConfig::new( + ProtocolName::from("/protocol/1"), + Vec::new(), + 1024, + Duration::from_secs(5), + None, + ); + let config1 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_request_response_protocol(req_resp_config1); + + let config1 = match transport1 { + Transport::Tcp(config) => config1.with_tcp(config), + Transport::Quic(config) => config1.with_quic(config), + Transport::WebSocket(config) => config1.with_websocket(config), + } + .build(); + + let (req_resp_config2, _handle2) = RequestResponseConfig::new( + ProtocolName::from("/protocol/2"), + Vec::new(), + 1024, + Duration::from_secs(5), + None, + ); + let config2 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_request_response_protocol(req_resp_config2); + + let config2 = match transport2 { + Transport::Tcp(config) => config2.with_tcp(config), + Transport::Quic(config) => config2.with_quic(config), + Transport::WebSocket(config) => config2.with_websocket(config), + } + .build(); + + let mut litep2p1 = Litep2p::new(config1).unwrap(); + let mut litep2p2 = Litep2p::new(config2).unwrap(); + let peer2 = *litep2p2.local_peer_id(); + + // wait until peers have connected + connect_peers(&mut litep2p1, &mut litep2p2).await; + tokio::spawn(async move { + loop { + tokio::select! { + _ = litep2p2.next_event() => {}, + } + } + }); + + // send request to remote peer + let request_id = handle1 + .send_request(peer2, vec![1, 3, 3, 7], DialOptions::Reject) + .await + .unwrap(); + + assert_eq!( + handle1.next().await.unwrap(), + RequestResponseEvent::RequestFailed { + peer: peer2, + request_id, + error: RequestResponseError::UnsupportedProtocol, + } + ); + + loop { + match litep2p1.next_event().await { + Some(Litep2pEvent::ConnectionClosed { peer, .. }) => { + assert_eq!(peer, peer2); + break; + } + event => panic!("invalid event received: {event:?}"), + } + } + + // verify that nothing is received from the handle as the request failure was already reported + if let Ok(event) = tokio::time::timeout(Duration::from_secs(5), handle1.next()).await { + panic!("didn't expect to receive event: {event:?}"); + } } #[tokio::test] async fn excess_inbound_request_rejected_tcp() { - excess_inbound_request_rejected( - Transport::Tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }), - Transport::Tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }), - ) - .await; + excess_inbound_request_rejected( + Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }), + Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }), + ) + .await; } #[tokio::test] async fn excess_inbound_request_rejected_quic() { - excess_inbound_request_rejected( - Transport::Quic(Default::default()), - Transport::Quic(Default::default()), - ) - .await; + excess_inbound_request_rejected( + Transport::Quic(Default::default()), + Transport::Quic(Default::default()), + ) + .await; } #[tokio::test] async fn excess_inbound_request_rejected_websocket() { - excess_inbound_request_rejected( - Transport::WebSocket(WebSocketConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], - ..Default::default() - }), - Transport::WebSocket(WebSocketConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], - ..Default::default() - }), - ) - .await; + excess_inbound_request_rejected( + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }), + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }), + ) + .await; } async fn excess_inbound_request_rejected(transport1: Transport, transport2: Transport) { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (req_resp_config1, mut handle1) = ConfigBuilder::new(ProtocolName::from("/protocol/1")) - .with_max_size(1024) - .build(); - - let config1 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_request_response_protocol(req_resp_config1); - - let config1 = match transport1 { - Transport::Tcp(config) => config1.with_tcp(config), - Transport::Quic(config) => config1.with_quic(config), - Transport::WebSocket(config) => config1.with_websocket(config), - } - .build(); - - let (req_resp_config2, _handle2) = ConfigBuilder::new(ProtocolName::from("/protocol/1")) - .with_max_size(1024) - .with_max_concurrent_inbound_requests(2) - .build(); - - let config2 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_request_response_protocol(req_resp_config2); - - let config2 = match transport2 { - Transport::Tcp(config) => config2.with_tcp(config), - Transport::Quic(config) => config2.with_quic(config), - Transport::WebSocket(config) => config2.with_websocket(config), - } - .build(); - - let mut litep2p1 = Litep2p::new(config1).unwrap(); - let mut litep2p2 = Litep2p::new(config2).unwrap(); - let peer2 = *litep2p2.local_peer_id(); - - // wait until peers have connected - connect_peers(&mut litep2p1, &mut litep2p2).await; - tokio::spawn(async move { - loop { - tokio::select! { - _ = litep2p2.next_event() => {}, - _ = litep2p1.next_event() => {}, - } - } - }); - - // send two requests and verify that nothing is returned back (yet) - for _ in 0..2 { - let _ = handle1 - .send_request(peer2, vec![1, 3, 3, 7], DialOptions::Reject) - .await - .unwrap(); - } - - futures::future::poll_fn(|cx| match handle1.poll_next_unpin(cx) { - Poll::Pending => Poll::Ready(()), - Poll::Ready(_) => panic!("didn't expect an event"), - }) - .await; - - // send another request to peer and since there's two requests already pending - // and the limit was set at 2, the third request must be rejeced - let request_id = handle1 - .send_request(peer2, vec![1, 3, 3, 7], DialOptions::Reject) - .await - .unwrap(); - - assert_eq!( - handle1.next().await.unwrap(), - RequestResponseEvent::RequestFailed { - peer: peer2, - request_id, - error: RequestResponseError::Rejected - } - ); + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (req_resp_config1, mut handle1) = ConfigBuilder::new(ProtocolName::from("/protocol/1")) + .with_max_size(1024) + .build(); + + let config1 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_request_response_protocol(req_resp_config1); + + let config1 = match transport1 { + Transport::Tcp(config) => config1.with_tcp(config), + Transport::Quic(config) => config1.with_quic(config), + Transport::WebSocket(config) => config1.with_websocket(config), + } + .build(); + + let (req_resp_config2, _handle2) = ConfigBuilder::new(ProtocolName::from("/protocol/1")) + .with_max_size(1024) + .with_max_concurrent_inbound_requests(2) + .build(); + + let config2 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_request_response_protocol(req_resp_config2); + + let config2 = match transport2 { + Transport::Tcp(config) => config2.with_tcp(config), + Transport::Quic(config) => config2.with_quic(config), + Transport::WebSocket(config) => config2.with_websocket(config), + } + .build(); + + let mut litep2p1 = Litep2p::new(config1).unwrap(); + let mut litep2p2 = Litep2p::new(config2).unwrap(); + let peer2 = *litep2p2.local_peer_id(); + + // wait until peers have connected + connect_peers(&mut litep2p1, &mut litep2p2).await; + tokio::spawn(async move { + loop { + tokio::select! { + _ = litep2p2.next_event() => {}, + _ = litep2p1.next_event() => {}, + } + } + }); + + // send two requests and verify that nothing is returned back (yet) + for _ in 0..2 { + let _ = handle1 + .send_request(peer2, vec![1, 3, 3, 7], DialOptions::Reject) + .await + .unwrap(); + } + + futures::future::poll_fn(|cx| match handle1.poll_next_unpin(cx) { + Poll::Pending => Poll::Ready(()), + Poll::Ready(_) => panic!("didn't expect an event"), + }) + .await; + + // send another request to peer and since there's two requests already pending + // and the limit was set at 2, the third request must be rejeced + let request_id = handle1 + .send_request(peer2, vec![1, 3, 3, 7], DialOptions::Reject) + .await + .unwrap(); + + assert_eq!( + handle1.next().await.unwrap(), + RequestResponseEvent::RequestFailed { + peer: peer2, + request_id, + error: RequestResponseError::Rejected + } + ); } #[tokio::test] async fn feedback_received_for_succesful_response_tcp() { - feedback_received_for_succesful_response( - Transport::Tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }), - Transport::Tcp(TcpConfig { - listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], - ..Default::default() - }), - ) - .await; + feedback_received_for_succesful_response( + Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }), + Transport::Tcp(TcpConfig { + listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()], + ..Default::default() + }), + ) + .await; } #[tokio::test] async fn feedback_received_for_succesful_response_quic() { - feedback_received_for_succesful_response( - Transport::Quic(Default::default()), - Transport::Quic(Default::default()), - ) - .await; + feedback_received_for_succesful_response( + Transport::Quic(Default::default()), + Transport::Quic(Default::default()), + ) + .await; } #[tokio::test] async fn feedback_received_for_succesful_response_websocket() { - feedback_received_for_succesful_response( - Transport::WebSocket(WebSocketConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], - ..Default::default() - }), - Transport::WebSocket(WebSocketConfig { - listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], - ..Default::default() - }), - ) - .await; + feedback_received_for_succesful_response( + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }), + Transport::WebSocket(WebSocketConfig { + listen_addresses: vec!["/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()], + ..Default::default() + }), + ) + .await; } async fn feedback_received_for_succesful_response(transport1: Transport, transport2: Transport) { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (req_resp_config1, mut handle1) = ConfigBuilder::new(ProtocolName::from("/protocol/1")) - .with_max_size(1024) - .build(); - - let config1 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_request_response_protocol(req_resp_config1); - - let config1 = match transport1 { - Transport::Tcp(config) => config1.with_tcp(config), - Transport::Quic(config) => config1.with_quic(config), - Transport::WebSocket(config) => config1.with_websocket(config), - } - .build(); - - let (req_resp_config2, mut handle2) = ConfigBuilder::new(ProtocolName::from("/protocol/1")) - .with_max_size(1024) - .build(); - - let config2 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_request_response_protocol(req_resp_config2); - - let config2 = match transport2 { - Transport::Tcp(config) => config2.with_tcp(config), - Transport::Quic(config) => config2.with_quic(config), - Transport::WebSocket(config) => config2.with_websocket(config), - } - .build(); - - let mut litep2p1 = Litep2p::new(config1).unwrap(); - let mut litep2p2 = Litep2p::new(config2).unwrap(); - let peer1 = *litep2p1.local_peer_id(); - let peer2 = *litep2p2.local_peer_id(); - - // wait until peers have connected - connect_peers(&mut litep2p1, &mut litep2p2).await; - tokio::spawn(async move { - loop { - tokio::select! { - _ = litep2p2.next_event() => {}, - _ = litep2p1.next_event() => {}, - } - } - }); - - let request_id = handle1 - .send_request(peer2, vec![1, 3, 3, 7], DialOptions::Reject) - .await - .unwrap(); - - assert_eq!( - handle2.next().await.unwrap(), - RequestResponseEvent::RequestReceived { - peer: peer1, - fallback: None, - request_id, - request: vec![1, 3, 3, 7] - }, - ); - - // send response with feedback and verify that the response was sent successfully - let (feedback_tx, feedback_rx) = channel::oneshot::channel(); - handle2.send_response_with_feedback(request_id, vec![1, 3, 3, 8], feedback_tx); - - assert_eq!( - handle1.next().await.unwrap(), - RequestResponseEvent::ResponseReceived { - peer: peer2, - request_id, - response: vec![1, 3, 3, 8], - fallback: None, - } - ); - assert!(feedback_rx.await.is_ok()); + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (req_resp_config1, mut handle1) = ConfigBuilder::new(ProtocolName::from("/protocol/1")) + .with_max_size(1024) + .build(); + + let config1 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_request_response_protocol(req_resp_config1); + + let config1 = match transport1 { + Transport::Tcp(config) => config1.with_tcp(config), + Transport::Quic(config) => config1.with_quic(config), + Transport::WebSocket(config) => config1.with_websocket(config), + } + .build(); + + let (req_resp_config2, mut handle2) = ConfigBuilder::new(ProtocolName::from("/protocol/1")) + .with_max_size(1024) + .build(); + + let config2 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_request_response_protocol(req_resp_config2); + + let config2 = match transport2 { + Transport::Tcp(config) => config2.with_tcp(config), + Transport::Quic(config) => config2.with_quic(config), + Transport::WebSocket(config) => config2.with_websocket(config), + } + .build(); + + let mut litep2p1 = Litep2p::new(config1).unwrap(); + let mut litep2p2 = Litep2p::new(config2).unwrap(); + let peer1 = *litep2p1.local_peer_id(); + let peer2 = *litep2p2.local_peer_id(); + + // wait until peers have connected + connect_peers(&mut litep2p1, &mut litep2p2).await; + tokio::spawn(async move { + loop { + tokio::select! { + _ = litep2p2.next_event() => {}, + _ = litep2p1.next_event() => {}, + } + } + }); + + let request_id = handle1 + .send_request(peer2, vec![1, 3, 3, 7], DialOptions::Reject) + .await + .unwrap(); + + assert_eq!( + handle2.next().await.unwrap(), + RequestResponseEvent::RequestReceived { + peer: peer1, + fallback: None, + request_id, + request: vec![1, 3, 3, 7] + }, + ); + + // send response with feedback and verify that the response was sent successfully + let (feedback_tx, feedback_rx) = channel::oneshot::channel(); + handle2.send_response_with_feedback(request_id, vec![1, 3, 3, 8], feedback_tx); + + assert_eq!( + handle1.next().await.unwrap(), + RequestResponseEvent::ResponseReceived { + peer: peer2, + request_id, + response: vec![1, 3, 3, 8], + fallback: None, + } + ); + assert!(feedback_rx.await.is_ok()); } // #[tokio::test] @@ -2150,11 +2181,11 @@ async fn feedback_received_for_succesful_response(transport1: Transport, transpo #[tokio::test] async fn feedback_not_received_for_failed_response_quic() { - feedback_not_received_for_failed_response( - Transport::Quic(Default::default()), - Transport::Quic(Default::default()), - ) - .await; + feedback_not_received_for_failed_response( + Transport::Quic(Default::default()), + Transport::Quic(Default::default()), + ) + .await; } // #[tokio::test] @@ -2173,986 +2204,1002 @@ async fn feedback_not_received_for_failed_response_quic() { // } async fn feedback_not_received_for_failed_response(transport1: Transport, transport2: Transport) { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (req_resp_config1, mut handle1) = ConfigBuilder::new(ProtocolName::from("/protocol/1")) - .with_max_size(1024) - .build(); - - let config1 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_request_response_protocol(req_resp_config1); - - let config1 = match transport1 { - Transport::Tcp(config) => config1.with_tcp(config), - Transport::Quic(config) => config1.with_quic(config), - Transport::WebSocket(config) => config1.with_websocket(config), - } - .build(); - - let (req_resp_config2, mut handle2) = ConfigBuilder::new(ProtocolName::from("/protocol/1")) - .with_max_size(1024) - .build(); - - let config2 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_request_response_protocol(req_resp_config2); - - let config2 = match transport2 { - Transport::Tcp(config) => config2.with_tcp(config), - Transport::Quic(config) => config2.with_quic(config), - Transport::WebSocket(config) => config2.with_websocket(config), - } - .build(); - - let mut litep2p1 = Litep2p::new(config1).unwrap(); - let mut litep2p2 = Litep2p::new(config2).unwrap(); - let peer1 = *litep2p1.local_peer_id(); - let peer2 = *litep2p2.local_peer_id(); - - // wait until peers have connected - connect_peers(&mut litep2p1, &mut litep2p2).await; - tokio::spawn(async move { - loop { - tokio::select! { - _ = litep2p2.next_event() => {}, - _ = litep2p1.next_event() => {}, - } - } - }); - - let request_id = handle1 - .send_request(peer2, vec![1, 3, 3, 7], DialOptions::Reject) - .await - .unwrap(); - - assert_eq!( - handle2.next().await.unwrap(), - RequestResponseEvent::RequestReceived { - peer: peer1, - fallback: None, - request_id, - request: vec![1, 3, 3, 7] - }, - ); - - // cancel the request and give a moment to register - handle1.cancel_request(request_id).await; - tokio::time::sleep(std::time::Duration::from_secs(1)).await; - - // send response with feedback and verify that sending the response fails - let (feedback_tx, feedback_rx) = channel::oneshot::channel(); - handle2.send_response_with_feedback(request_id, vec![1, 3, 3, 8], feedback_tx); - - assert!(feedback_rx.await.is_err()); + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (req_resp_config1, mut handle1) = ConfigBuilder::new(ProtocolName::from("/protocol/1")) + .with_max_size(1024) + .build(); + + let config1 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_request_response_protocol(req_resp_config1); + + let config1 = match transport1 { + Transport::Tcp(config) => config1.with_tcp(config), + Transport::Quic(config) => config1.with_quic(config), + Transport::WebSocket(config) => config1.with_websocket(config), + } + .build(); + + let (req_resp_config2, mut handle2) = ConfigBuilder::new(ProtocolName::from("/protocol/1")) + .with_max_size(1024) + .build(); + + let config2 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_request_response_protocol(req_resp_config2); + + let config2 = match transport2 { + Transport::Tcp(config) => config2.with_tcp(config), + Transport::Quic(config) => config2.with_quic(config), + Transport::WebSocket(config) => config2.with_websocket(config), + } + .build(); + + let mut litep2p1 = Litep2p::new(config1).unwrap(); + let mut litep2p2 = Litep2p::new(config2).unwrap(); + let peer1 = *litep2p1.local_peer_id(); + let peer2 = *litep2p2.local_peer_id(); + + // wait until peers have connected + connect_peers(&mut litep2p1, &mut litep2p2).await; + tokio::spawn(async move { + loop { + tokio::select! { + _ = litep2p2.next_event() => {}, + _ = litep2p1.next_event() => {}, + } + } + }); + + let request_id = handle1 + .send_request(peer2, vec![1, 3, 3, 7], DialOptions::Reject) + .await + .unwrap(); + + assert_eq!( + handle2.next().await.unwrap(), + RequestResponseEvent::RequestReceived { + peer: peer1, + fallback: None, + request_id, + request: vec![1, 3, 3, 7] + }, + ); + + // cancel the request and give a moment to register + handle1.cancel_request(request_id).await; + tokio::time::sleep(std::time::Duration::from_secs(1)).await; + + // send response with feedback and verify that sending the response fails + let (feedback_tx, feedback_rx) = channel::oneshot::channel(); + handle2.send_response_with_feedback(request_id, vec![1, 3, 3, 8], feedback_tx); + + assert!(feedback_rx.await.is_err()); } #[tokio::test] async fn custom_timeout_tcp() { - custom_timeout(Transport::Tcp(Default::default()), Transport::Tcp(Default::default())).await; + custom_timeout( + Transport::Tcp(Default::default()), + Transport::Tcp(Default::default()), + ) + .await; } #[tokio::test] async fn custom_timeout_quic() { - custom_timeout(Transport::Quic(Default::default()), Transport::Quic(Default::default())).await; + custom_timeout( + Transport::Quic(Default::default()), + Transport::Quic(Default::default()), + ) + .await; } #[tokio::test] async fn custom_timeout_websocket() { - custom_timeout( - Transport::WebSocket(Default::default()), - Transport::WebSocket(Default::default()), - ) - .await; + custom_timeout( + Transport::WebSocket(Default::default()), + Transport::WebSocket(Default::default()), + ) + .await; } async fn custom_timeout(transport1: Transport, transport2: Transport) { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (req_resp_config1, mut handle1) = ConfigBuilder::new(ProtocolName::from("/protocol/1")) - .with_max_size(1024) - .with_timeout(Duration::from_secs(8)) - .build(); - - let config1 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_request_response_protocol(req_resp_config1); - - let config1 = match transport1 { - Transport::Tcp(config) => config1.with_tcp(config), - Transport::Quic(config) => config1.with_quic(config), - Transport::WebSocket(config) => config1.with_websocket(config), - } - .build(); - - let (req_resp_config2, _handle2) = ConfigBuilder::new(ProtocolName::from("/protocol/1")) - .with_max_size(1024) - .build(); - - let config2 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_request_response_protocol(req_resp_config2); - - let config2 = match transport2 { - Transport::Tcp(config) => config2.with_tcp(config), - Transport::Quic(config) => config2.with_quic(config), - Transport::WebSocket(config) => config2.with_websocket(config), - } - .build(); - - let mut litep2p1 = Litep2p::new(config1).unwrap(); - let mut litep2p2 = Litep2p::new(config2).unwrap(); - let peer2 = *litep2p2.local_peer_id(); - - // wait until peers have connected - connect_peers(&mut litep2p1, &mut litep2p2).await; - tokio::spawn(async move { - loop { - tokio::select! { - _ = litep2p2.next_event() => {}, - _ = litep2p1.next_event() => {}, - } - } - }); - - let request_id = - handle1.try_send_request(peer2, vec![1, 3, 3, 7], DialOptions::Reject).unwrap(); - - // verify that the request doesn't timeout after the default timeout - match tokio::time::timeout(Duration::from_secs(5), handle1.next()).await { - Err(_) => {}, - Ok(_) => panic!("expected request to timeout"), - }; - - // verify that the request times out - assert_eq!( - handle1.next().await.unwrap(), - RequestResponseEvent::RequestFailed { - peer: peer2, - request_id, - error: RequestResponseError::Timeout - } - ); + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (req_resp_config1, mut handle1) = ConfigBuilder::new(ProtocolName::from("/protocol/1")) + .with_max_size(1024) + .with_timeout(Duration::from_secs(8)) + .build(); + + let config1 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_request_response_protocol(req_resp_config1); + + let config1 = match transport1 { + Transport::Tcp(config) => config1.with_tcp(config), + Transport::Quic(config) => config1.with_quic(config), + Transport::WebSocket(config) => config1.with_websocket(config), + } + .build(); + + let (req_resp_config2, _handle2) = ConfigBuilder::new(ProtocolName::from("/protocol/1")) + .with_max_size(1024) + .build(); + + let config2 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_request_response_protocol(req_resp_config2); + + let config2 = match transport2 { + Transport::Tcp(config) => config2.with_tcp(config), + Transport::Quic(config) => config2.with_quic(config), + Transport::WebSocket(config) => config2.with_websocket(config), + } + .build(); + + let mut litep2p1 = Litep2p::new(config1).unwrap(); + let mut litep2p2 = Litep2p::new(config2).unwrap(); + let peer2 = *litep2p2.local_peer_id(); + + // wait until peers have connected + connect_peers(&mut litep2p1, &mut litep2p2).await; + tokio::spawn(async move { + loop { + tokio::select! { + _ = litep2p2.next_event() => {}, + _ = litep2p1.next_event() => {}, + } + } + }); + + let request_id = + handle1.try_send_request(peer2, vec![1, 3, 3, 7], DialOptions::Reject).unwrap(); + + // verify that the request doesn't timeout after the default timeout + match tokio::time::timeout(Duration::from_secs(5), handle1.next()).await { + Err(_) => {} + Ok(_) => panic!("expected request to timeout"), + }; + + // verify that the request times out + assert_eq!( + handle1.next().await.unwrap(), + RequestResponseEvent::RequestFailed { + peer: peer2, + request_id, + error: RequestResponseError::Timeout + } + ); } #[tokio::test] async fn outbound_request_for_unconnected_peer_tcp() { - outbound_request_for_unconnected_peer(Transport::Tcp(Default::default())).await; + outbound_request_for_unconnected_peer(Transport::Tcp(Default::default())).await; } #[tokio::test] async fn outbound_request_for_unconnected_peer_quic() { - outbound_request_for_unconnected_peer(Transport::Quic(Default::default())).await; + outbound_request_for_unconnected_peer(Transport::Quic(Default::default())).await; } #[tokio::test] async fn outbound_request_for_unconnected_peer_websocket() { - outbound_request_for_unconnected_peer(Transport::WebSocket(Default::default())).await; + outbound_request_for_unconnected_peer(Transport::WebSocket(Default::default())).await; } async fn outbound_request_for_unconnected_peer(transport1: Transport) { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (req_resp_config1, mut handle1) = ConfigBuilder::new(ProtocolName::from("/protocol/1")) - .with_max_size(1024) - .build(); - - let config1 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_request_response_protocol(req_resp_config1); - - let config1 = match transport1 { - Transport::Tcp(config) => config1.with_tcp(config), - Transport::Quic(config) => config1.with_quic(config), - Transport::WebSocket(config) => config1.with_websocket(config), - } - .build(); - - tokio::spawn(async move { - let mut litep2p1 = Litep2p::new(config1).unwrap(); - while let Some(_) = litep2p1.next_event().await {} - }); - - let peer2 = PeerId::random(); - let request_id = handle1 - .send_request(peer2, vec![1, 3, 3, 7], DialOptions::Reject) - .await - .unwrap(); - - // verify that the request times out - assert_eq!( - handle1.next().await.unwrap(), - RequestResponseEvent::RequestFailed { - peer: peer2, - request_id, - error: RequestResponseError::NotConnected - } - ); + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (req_resp_config1, mut handle1) = ConfigBuilder::new(ProtocolName::from("/protocol/1")) + .with_max_size(1024) + .build(); + + let config1 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_request_response_protocol(req_resp_config1); + + let config1 = match transport1 { + Transport::Tcp(config) => config1.with_tcp(config), + Transport::Quic(config) => config1.with_quic(config), + Transport::WebSocket(config) => config1.with_websocket(config), + } + .build(); + + tokio::spawn(async move { + let mut litep2p1 = Litep2p::new(config1).unwrap(); + while let Some(_) = litep2p1.next_event().await {} + }); + + let peer2 = PeerId::random(); + let request_id = handle1 + .send_request(peer2, vec![1, 3, 3, 7], DialOptions::Reject) + .await + .unwrap(); + + // verify that the request times out + assert_eq!( + handle1.next().await.unwrap(), + RequestResponseEvent::RequestFailed { + peer: peer2, + request_id, + error: RequestResponseError::NotConnected + } + ); } #[tokio::test] async fn dial_failure_tcp() { - dial_failure(Transport::Tcp(Default::default())).await; + dial_failure(Transport::Tcp(Default::default())).await; } #[tokio::test] async fn dial_failure_quic() { - dial_failure(Transport::Quic(Default::default())).await; + dial_failure(Transport::Quic(Default::default())).await; } #[tokio::test] async fn dial_failure_websocket() { - dial_failure(Transport::WebSocket(Default::default())).await; + dial_failure(Transport::WebSocket(Default::default())).await; } async fn dial_failure(transport: Transport) { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (req_resp_config, mut handle) = ConfigBuilder::new(ProtocolName::from("/protocol/1")) - .with_max_size(1024) - .build(); - - let litep2p_config = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_request_response_protocol(req_resp_config); - - let peer = PeerId::random(); - let known_address = match &transport { - Transport::Tcp(_) => Multiaddr::empty() - .with(Protocol::Ip6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1))) - .with(Protocol::Tcp(5)) - .with(Protocol::P2p(Multihash::from(peer))), - Transport::Quic(_) => Multiaddr::empty() - .with(Protocol::Ip4(Ipv4Addr::new(127, 0, 0, 1))) - .with(Protocol::Udp(5)) - .with(Protocol::QuicV1) - .with(Protocol::P2p(Multihash::from(peer))), - Transport::WebSocket(_) => Multiaddr::empty() - .with(Protocol::Ip6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1))) - .with(Protocol::Tcp(5)) - .with(Protocol::Ws(std::borrow::Cow::Owned("/".to_string()))) - .with(Protocol::P2p(Multihash::from(peer))), - }; - - let config = match transport { - Transport::Tcp(config) => litep2p_config.with_tcp(config), - Transport::Quic(config) => litep2p_config.with_quic(config), - Transport::WebSocket(config) => litep2p_config.with_websocket(config), - } - .build(); - - let mut litep2p = Litep2p::new(config).unwrap(); - litep2p.add_known_address(peer, vec![known_address].into_iter()); - tokio::spawn(async move { while let Some(_) = litep2p.next_event().await {} }); - - let request_id = handle.send_request(peer, vec![1, 3, 3, 7], DialOptions::Dial).await.unwrap(); - - // verify that the request is reported as rejected since the dial failed - assert_eq!( - handle.next().await.unwrap(), - RequestResponseEvent::RequestFailed { - peer, - request_id, - error: RequestResponseError::Rejected - } - ); + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (req_resp_config, mut handle) = ConfigBuilder::new(ProtocolName::from("/protocol/1")) + .with_max_size(1024) + .build(); + + let litep2p_config = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_request_response_protocol(req_resp_config); + + let peer = PeerId::random(); + let known_address = match &transport { + Transport::Tcp(_) => Multiaddr::empty() + .with(Protocol::Ip6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1))) + .with(Protocol::Tcp(5)) + .with(Protocol::P2p(Multihash::from(peer))), + Transport::Quic(_) => Multiaddr::empty() + .with(Protocol::Ip4(Ipv4Addr::new(127, 0, 0, 1))) + .with(Protocol::Udp(5)) + .with(Protocol::QuicV1) + .with(Protocol::P2p(Multihash::from(peer))), + Transport::WebSocket(_) => Multiaddr::empty() + .with(Protocol::Ip6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1))) + .with(Protocol::Tcp(5)) + .with(Protocol::Ws(std::borrow::Cow::Owned("/".to_string()))) + .with(Protocol::P2p(Multihash::from(peer))), + }; + + let config = match transport { + Transport::Tcp(config) => litep2p_config.with_tcp(config), + Transport::Quic(config) => litep2p_config.with_quic(config), + Transport::WebSocket(config) => litep2p_config.with_websocket(config), + } + .build(); + + let mut litep2p = Litep2p::new(config).unwrap(); + litep2p.add_known_address(peer, vec![known_address].into_iter()); + tokio::spawn(async move { while let Some(_) = litep2p.next_event().await {} }); + + let request_id = handle.send_request(peer, vec![1, 3, 3, 7], DialOptions::Dial).await.unwrap(); + + // verify that the request is reported as rejected since the dial failed + assert_eq!( + handle.next().await.unwrap(), + RequestResponseEvent::RequestFailed { + peer, + request_id, + error: RequestResponseError::Rejected + } + ); } #[tokio::test] async fn large_response_tcp() { - large_response(Transport::Tcp(Default::default()), Transport::Tcp(Default::default())).await; + large_response( + Transport::Tcp(Default::default()), + Transport::Tcp(Default::default()), + ) + .await; } #[tokio::test] async fn large_response_quic() { - large_response(Transport::Quic(Default::default()), Transport::Quic(Default::default())).await; + large_response( + Transport::Quic(Default::default()), + Transport::Quic(Default::default()), + ) + .await; } #[tokio::test] async fn large_response_websocket() { - large_response( - Transport::WebSocket(Default::default()), - Transport::WebSocket(Default::default()), - ) - .await; + large_response( + Transport::WebSocket(Default::default()), + Transport::WebSocket(Default::default()), + ) + .await; } async fn large_response(transport1: Transport, transport2: Transport) { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (req_resp_config1, mut handle1) = ConfigBuilder::new(ProtocolName::from("/protocol/1")) - .with_max_size(16 * 1024 * 1024) - .with_timeout(Duration::from_secs(8)) - .build(); - - let config1 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_request_response_protocol(req_resp_config1); - - let config1 = match transport1 { - Transport::Tcp(config) => config1.with_tcp(config), - Transport::Quic(config) => config1.with_quic(config), - Transport::WebSocket(config) => config1.with_websocket(config), - } - .build(); - - let (req_resp_config2, mut handle2) = ConfigBuilder::new(ProtocolName::from("/protocol/1")) - .with_max_size(16 * 1024 * 1024) - .build(); - - let config2 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_request_response_protocol(req_resp_config2); - - let config2 = match transport2 { - Transport::Tcp(config) => config2.with_tcp(config), - Transport::Quic(config) => config2.with_quic(config), - Transport::WebSocket(config) => config2.with_websocket(config), - } - .build(); - - let mut litep2p1 = Litep2p::new(config1).unwrap(); - let mut litep2p2 = Litep2p::new(config2).unwrap(); - let peer1 = *litep2p1.local_peer_id(); - let peer2 = *litep2p2.local_peer_id(); - - // wait until peers have connected - connect_peers(&mut litep2p1, &mut litep2p2).await; - tokio::spawn(async move { - loop { - tokio::select! { - _ = litep2p2.next_event() => {}, - _ = litep2p1.next_event() => {}, - } - } - }); - - // Generate the response first and use a fast insecure RNG to make the test not timeout on - // GitHub CI when generating 15 MB of data. - let mut rng = XorShiftRng::from_rng(rand::thread_rng()).expect("`thread_rng` to seed"); - let response = (0..15 * 1024 * 1024).map(|_| rng.gen::()).collect::>(); - - let request_id = - handle1.try_send_request(peer2, vec![1, 3, 3, 7], DialOptions::Reject).unwrap(); - - assert_eq!( - handle2.next().await.unwrap(), - RequestResponseEvent::RequestReceived { - peer: peer1, - fallback: None, - request_id, - request: vec![1, 3, 3, 7], - } - ); - - // send response to the received request - handle2.send_response(request_id, response.clone()); - assert_eq!( - handle1.next().await.unwrap(), - RequestResponseEvent::ResponseReceived { - peer: peer2, - request_id, - response, - fallback: None, - } - ); + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (req_resp_config1, mut handle1) = ConfigBuilder::new(ProtocolName::from("/protocol/1")) + .with_max_size(16 * 1024 * 1024) + .with_timeout(Duration::from_secs(8)) + .build(); + + let config1 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_request_response_protocol(req_resp_config1); + + let config1 = match transport1 { + Transport::Tcp(config) => config1.with_tcp(config), + Transport::Quic(config) => config1.with_quic(config), + Transport::WebSocket(config) => config1.with_websocket(config), + } + .build(); + + let (req_resp_config2, mut handle2) = ConfigBuilder::new(ProtocolName::from("/protocol/1")) + .with_max_size(16 * 1024 * 1024) + .build(); + + let config2 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_request_response_protocol(req_resp_config2); + + let config2 = match transport2 { + Transport::Tcp(config) => config2.with_tcp(config), + Transport::Quic(config) => config2.with_quic(config), + Transport::WebSocket(config) => config2.with_websocket(config), + } + .build(); + + let mut litep2p1 = Litep2p::new(config1).unwrap(); + let mut litep2p2 = Litep2p::new(config2).unwrap(); + let peer1 = *litep2p1.local_peer_id(); + let peer2 = *litep2p2.local_peer_id(); + + // wait until peers have connected + connect_peers(&mut litep2p1, &mut litep2p2).await; + tokio::spawn(async move { + loop { + tokio::select! { + _ = litep2p2.next_event() => {}, + _ = litep2p1.next_event() => {}, + } + } + }); + + // Generate the response first and use a fast insecure RNG to make the test not timeout on + // GitHub CI when generating 15 MB of data. + let mut rng = XorShiftRng::from_rng(rand::thread_rng()).expect("`thread_rng` to seed"); + let response = (0..15 * 1024 * 1024).map(|_| rng.gen::()).collect::>(); + + let request_id = + handle1.try_send_request(peer2, vec![1, 3, 3, 7], DialOptions::Reject).unwrap(); + + assert_eq!( + handle2.next().await.unwrap(), + RequestResponseEvent::RequestReceived { + peer: peer1, + fallback: None, + request_id, + request: vec![1, 3, 3, 7], + } + ); + + // send response to the received request + handle2.send_response(request_id, response.clone()); + assert_eq!( + handle1.next().await.unwrap(), + RequestResponseEvent::ResponseReceived { + peer: peer2, + request_id, + response, + fallback: None, + } + ); } #[tokio::test] async fn binary_incompatible_fallback_tcp() { - binary_incompatible_fallback( - Transport::Tcp(Default::default()), - Transport::Tcp(Default::default()), - ) - .await; + binary_incompatible_fallback( + Transport::Tcp(Default::default()), + Transport::Tcp(Default::default()), + ) + .await; } #[tokio::test] async fn binary_incompatible_fallback_quic() { - binary_incompatible_fallback( - Transport::Quic(Default::default()), - Transport::Quic(Default::default()), - ) - .await; + binary_incompatible_fallback( + Transport::Quic(Default::default()), + Transport::Quic(Default::default()), + ) + .await; } #[tokio::test] async fn binary_incompatible_fallback_websocket() { - binary_incompatible_fallback( - Transport::WebSocket(Default::default()), - Transport::WebSocket(Default::default()), - ) - .await; + binary_incompatible_fallback( + Transport::WebSocket(Default::default()), + Transport::WebSocket(Default::default()), + ) + .await; } async fn binary_incompatible_fallback(transport1: Transport, transport2: Transport) { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (req_resp_config1, mut handle1) = ConfigBuilder::new(ProtocolName::from("/protocol/2")) - .with_max_size(16 * 1024 * 1024) - .with_fallback_names(vec![ProtocolName::from("/protocol/1")]) - .with_timeout(Duration::from_secs(8)) - .build(); - - let config1 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_request_response_protocol(req_resp_config1); - - let config1 = match transport1 { - Transport::Tcp(config) => config1.with_tcp(config), - Transport::Quic(config) => config1.with_quic(config), - Transport::WebSocket(config) => config1.with_websocket(config), - } - .build(); - - let (req_resp_config2, mut handle2) = ConfigBuilder::new(ProtocolName::from("/protocol/1")) - .with_max_size(16 * 1024 * 1024) - .build(); - - let config2 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_request_response_protocol(req_resp_config2); - - let config2 = match transport2 { - Transport::Tcp(config) => config2.with_tcp(config), - Transport::Quic(config) => config2.with_quic(config), - Transport::WebSocket(config) => config2.with_websocket(config), - } - .build(); - - let mut litep2p1 = Litep2p::new(config1).unwrap(); - let mut litep2p2 = Litep2p::new(config2).unwrap(); - let peer1 = *litep2p1.local_peer_id(); - let peer2 = *litep2p2.local_peer_id(); - - // wait until peers have connected - connect_peers(&mut litep2p1, &mut litep2p2).await; - tokio::spawn(async move { - loop { - tokio::select! { - _ = litep2p2.next_event() => {}, - _ = litep2p1.next_event() => {}, - } - } - }); - - let request_id = handle1 - .send_request_with_fallback( - peer2, - vec![1, 2, 3, 4], - (ProtocolName::from("/protocol/1"), vec![5, 6, 7, 8]), - DialOptions::Reject, - ) - .await - .unwrap(); - - assert_eq!( - handle2.next().await.unwrap(), - RequestResponseEvent::RequestReceived { - peer: peer1, - fallback: None, - request_id, - request: vec![5, 6, 7, 8], - } - ); - - handle2.send_response(request_id, vec![1, 3, 3, 7]); - - assert_eq!( - handle1.next().await.unwrap(), - RequestResponseEvent::ResponseReceived { - peer: peer2, - request_id, - response: vec![1, 3, 3, 7], - fallback: Some(ProtocolName::from("/protocol/1")), - } - ); + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (req_resp_config1, mut handle1) = ConfigBuilder::new(ProtocolName::from("/protocol/2")) + .with_max_size(16 * 1024 * 1024) + .with_fallback_names(vec![ProtocolName::from("/protocol/1")]) + .with_timeout(Duration::from_secs(8)) + .build(); + + let config1 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_request_response_protocol(req_resp_config1); + + let config1 = match transport1 { + Transport::Tcp(config) => config1.with_tcp(config), + Transport::Quic(config) => config1.with_quic(config), + Transport::WebSocket(config) => config1.with_websocket(config), + } + .build(); + + let (req_resp_config2, mut handle2) = ConfigBuilder::new(ProtocolName::from("/protocol/1")) + .with_max_size(16 * 1024 * 1024) + .build(); + + let config2 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_request_response_protocol(req_resp_config2); + + let config2 = match transport2 { + Transport::Tcp(config) => config2.with_tcp(config), + Transport::Quic(config) => config2.with_quic(config), + Transport::WebSocket(config) => config2.with_websocket(config), + } + .build(); + + let mut litep2p1 = Litep2p::new(config1).unwrap(); + let mut litep2p2 = Litep2p::new(config2).unwrap(); + let peer1 = *litep2p1.local_peer_id(); + let peer2 = *litep2p2.local_peer_id(); + + // wait until peers have connected + connect_peers(&mut litep2p1, &mut litep2p2).await; + tokio::spawn(async move { + loop { + tokio::select! { + _ = litep2p2.next_event() => {}, + _ = litep2p1.next_event() => {}, + } + } + }); + + let request_id = handle1 + .send_request_with_fallback( + peer2, + vec![1, 2, 3, 4], + (ProtocolName::from("/protocol/1"), vec![5, 6, 7, 8]), + DialOptions::Reject, + ) + .await + .unwrap(); + + assert_eq!( + handle2.next().await.unwrap(), + RequestResponseEvent::RequestReceived { + peer: peer1, + fallback: None, + request_id, + request: vec![5, 6, 7, 8], + } + ); + + handle2.send_response(request_id, vec![1, 3, 3, 7]); + + assert_eq!( + handle1.next().await.unwrap(), + RequestResponseEvent::ResponseReceived { + peer: peer2, + request_id, + response: vec![1, 3, 3, 7], + fallback: Some(ProtocolName::from("/protocol/1")), + } + ); } #[tokio::test] async fn binary_incompatible_fallback_inbound_request_tcp() { - binary_incompatible_fallback_inbound_request( - Transport::Tcp(Default::default()), - Transport::Tcp(Default::default()), - ) - .await; + binary_incompatible_fallback_inbound_request( + Transport::Tcp(Default::default()), + Transport::Tcp(Default::default()), + ) + .await; } #[tokio::test] async fn binary_incompatible_fallback_inbound_request_quic() { - binary_incompatible_fallback_inbound_request( - Transport::Quic(Default::default()), - Transport::Quic(Default::default()), - ) - .await; + binary_incompatible_fallback_inbound_request( + Transport::Quic(Default::default()), + Transport::Quic(Default::default()), + ) + .await; } #[tokio::test] async fn binary_incompatible_fallback_inbound_request_websocket() { - binary_incompatible_fallback_inbound_request( - Transport::WebSocket(Default::default()), - Transport::WebSocket(Default::default()), - ) - .await; + binary_incompatible_fallback_inbound_request( + Transport::WebSocket(Default::default()), + Transport::WebSocket(Default::default()), + ) + .await; } async fn binary_incompatible_fallback_inbound_request( - transport1: Transport, - transport2: Transport, + transport1: Transport, + transport2: Transport, ) { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (req_resp_config1, mut handle1) = ConfigBuilder::new(ProtocolName::from("/protocol/2")) - .with_max_size(16 * 1024 * 1024) - .with_fallback_names(vec![ProtocolName::from("/protocol/1")]) - .with_timeout(Duration::from_secs(8)) - .build(); - - let config1 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_request_response_protocol(req_resp_config1); - - let config1 = match transport1 { - Transport::Tcp(config) => config1.with_tcp(config), - Transport::Quic(config) => config1.with_quic(config), - Transport::WebSocket(config) => config1.with_websocket(config), - } - .build(); - - let (req_resp_config2, mut handle2) = ConfigBuilder::new(ProtocolName::from("/protocol/1")) - .with_max_size(16 * 1024 * 1024) - .build(); - - let config2 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_request_response_protocol(req_resp_config2); - - let config2 = match transport2 { - Transport::Tcp(config) => config2.with_tcp(config), - Transport::Quic(config) => config2.with_quic(config), - Transport::WebSocket(config) => config2.with_websocket(config), - } - .build(); - - let mut litep2p1 = Litep2p::new(config1).unwrap(); - let mut litep2p2 = Litep2p::new(config2).unwrap(); - let peer1 = *litep2p1.local_peer_id(); - let peer2 = *litep2p2.local_peer_id(); - - // wait until peers have connected - connect_peers(&mut litep2p1, &mut litep2p2).await; - tokio::spawn(async move { - loop { - tokio::select! { - _ = litep2p2.next_event() => {}, - _ = litep2p1.next_event() => {}, - } - } - }); - - let request_id = handle2 - .send_request(peer1, vec![1, 2, 3, 4], DialOptions::Reject) - .await - .unwrap(); - - assert_eq!( - handle1.next().await.unwrap(), - RequestResponseEvent::RequestReceived { - peer: peer2, - fallback: Some(ProtocolName::from("/protocol/1")), - request_id, - request: vec![1, 2, 3, 4], - } - ); - - handle1.send_response(request_id, vec![1, 3, 3, 8]); - - assert_eq!( - handle2.next().await.unwrap(), - RequestResponseEvent::ResponseReceived { - peer: peer1, - request_id, - response: vec![1, 3, 3, 8], - fallback: None, - } - ); + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (req_resp_config1, mut handle1) = ConfigBuilder::new(ProtocolName::from("/protocol/2")) + .with_max_size(16 * 1024 * 1024) + .with_fallback_names(vec![ProtocolName::from("/protocol/1")]) + .with_timeout(Duration::from_secs(8)) + .build(); + + let config1 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_request_response_protocol(req_resp_config1); + + let config1 = match transport1 { + Transport::Tcp(config) => config1.with_tcp(config), + Transport::Quic(config) => config1.with_quic(config), + Transport::WebSocket(config) => config1.with_websocket(config), + } + .build(); + + let (req_resp_config2, mut handle2) = ConfigBuilder::new(ProtocolName::from("/protocol/1")) + .with_max_size(16 * 1024 * 1024) + .build(); + + let config2 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_request_response_protocol(req_resp_config2); + + let config2 = match transport2 { + Transport::Tcp(config) => config2.with_tcp(config), + Transport::Quic(config) => config2.with_quic(config), + Transport::WebSocket(config) => config2.with_websocket(config), + } + .build(); + + let mut litep2p1 = Litep2p::new(config1).unwrap(); + let mut litep2p2 = Litep2p::new(config2).unwrap(); + let peer1 = *litep2p1.local_peer_id(); + let peer2 = *litep2p2.local_peer_id(); + + // wait until peers have connected + connect_peers(&mut litep2p1, &mut litep2p2).await; + tokio::spawn(async move { + loop { + tokio::select! { + _ = litep2p2.next_event() => {}, + _ = litep2p1.next_event() => {}, + } + } + }); + + let request_id = handle2 + .send_request(peer1, vec![1, 2, 3, 4], DialOptions::Reject) + .await + .unwrap(); + + assert_eq!( + handle1.next().await.unwrap(), + RequestResponseEvent::RequestReceived { + peer: peer2, + fallback: Some(ProtocolName::from("/protocol/1")), + request_id, + request: vec![1, 2, 3, 4], + } + ); + + handle1.send_response(request_id, vec![1, 3, 3, 8]); + + assert_eq!( + handle2.next().await.unwrap(), + RequestResponseEvent::ResponseReceived { + peer: peer1, + request_id, + response: vec![1, 3, 3, 8], + fallback: None, + } + ); } #[tokio::test] async fn binary_incompatible_fallback_two_fallback_protocols_tcp() { - binary_incompatible_fallback_two_fallback_protocols( - Transport::Tcp(Default::default()), - Transport::Tcp(Default::default()), - ) - .await; + binary_incompatible_fallback_two_fallback_protocols( + Transport::Tcp(Default::default()), + Transport::Tcp(Default::default()), + ) + .await; } #[tokio::test] async fn binary_incompatible_fallback_two_fallback_protocols_quic() { - binary_incompatible_fallback_two_fallback_protocols( - Transport::Quic(Default::default()), - Transport::Quic(Default::default()), - ) - .await; + binary_incompatible_fallback_two_fallback_protocols( + Transport::Quic(Default::default()), + Transport::Quic(Default::default()), + ) + .await; } #[tokio::test] async fn binary_incompatible_fallback_two_fallback_protocols_websocket() { - binary_incompatible_fallback_two_fallback_protocols( - Transport::WebSocket(Default::default()), - Transport::WebSocket(Default::default()), - ) - .await; + binary_incompatible_fallback_two_fallback_protocols( + Transport::WebSocket(Default::default()), + Transport::WebSocket(Default::default()), + ) + .await; } async fn binary_incompatible_fallback_two_fallback_protocols( - transport1: Transport, - transport2: Transport, + transport1: Transport, + transport2: Transport, ) { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (req_resp_config1, mut handle1) = - ConfigBuilder::new(ProtocolName::from("/genesis/protocol/2")) - .with_max_size(16 * 1024 * 1024) - .with_fallback_names(vec![ - ProtocolName::from("/genesis/protocol/1"), - ProtocolName::from("/dot/protocol/1"), - ]) - .with_timeout(Duration::from_secs(8)) - .build(); - - let config1 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_request_response_protocol(req_resp_config1); - - let config1 = match transport1 { - Transport::Tcp(config) => config1.with_tcp(config), - Transport::Quic(config) => config1.with_quic(config), - Transport::WebSocket(config) => config1.with_websocket(config), - } - .build(); - - let (req_resp_config2, mut handle2) = - ConfigBuilder::new(ProtocolName::from("/genesis/protocol/1")) - .with_fallback_names(vec![ProtocolName::from("/dot/protocol/1")]) - .with_max_size(16 * 1024 * 1024) - .build(); - - let config2 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_request_response_protocol(req_resp_config2); - - let config2 = match transport2 { - Transport::Tcp(config) => config2.with_tcp(config), - Transport::Quic(config) => config2.with_quic(config), - Transport::WebSocket(config) => config2.with_websocket(config), - } - .build(); - - let mut litep2p1 = Litep2p::new(config1).unwrap(); - let mut litep2p2 = Litep2p::new(config2).unwrap(); - let peer1 = *litep2p1.local_peer_id(); - let peer2 = *litep2p2.local_peer_id(); - - // wait until peers have connected - connect_peers(&mut litep2p1, &mut litep2p2).await; - tokio::spawn(async move { - loop { - tokio::select! { - _ = litep2p2.next_event() => {}, - _ = litep2p1.next_event() => {}, - } - } - }); - - let request_id = handle1 - .send_request_with_fallback( - peer2, - vec![1, 2, 3, 4], - (ProtocolName::from("/genesis/protocol/1"), vec![5, 6, 7, 8]), - DialOptions::Reject, - ) - .await - .unwrap(); - - assert_eq!( - handle2.next().await.unwrap(), - RequestResponseEvent::RequestReceived { - peer: peer1, - fallback: None, - request_id, - request: vec![5, 6, 7, 8], - } - ); - - handle2.send_response(request_id, vec![1, 3, 3, 7]); - - assert_eq!( - handle1.next().await.unwrap(), - RequestResponseEvent::ResponseReceived { - peer: peer2, - request_id, - response: vec![1, 3, 3, 7], - fallback: Some(ProtocolName::from("/genesis/protocol/1")), - } - ); + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (req_resp_config1, mut handle1) = + ConfigBuilder::new(ProtocolName::from("/genesis/protocol/2")) + .with_max_size(16 * 1024 * 1024) + .with_fallback_names(vec![ + ProtocolName::from("/genesis/protocol/1"), + ProtocolName::from("/dot/protocol/1"), + ]) + .with_timeout(Duration::from_secs(8)) + .build(); + + let config1 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_request_response_protocol(req_resp_config1); + + let config1 = match transport1 { + Transport::Tcp(config) => config1.with_tcp(config), + Transport::Quic(config) => config1.with_quic(config), + Transport::WebSocket(config) => config1.with_websocket(config), + } + .build(); + + let (req_resp_config2, mut handle2) = + ConfigBuilder::new(ProtocolName::from("/genesis/protocol/1")) + .with_fallback_names(vec![ProtocolName::from("/dot/protocol/1")]) + .with_max_size(16 * 1024 * 1024) + .build(); + + let config2 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_request_response_protocol(req_resp_config2); + + let config2 = match transport2 { + Transport::Tcp(config) => config2.with_tcp(config), + Transport::Quic(config) => config2.with_quic(config), + Transport::WebSocket(config) => config2.with_websocket(config), + } + .build(); + + let mut litep2p1 = Litep2p::new(config1).unwrap(); + let mut litep2p2 = Litep2p::new(config2).unwrap(); + let peer1 = *litep2p1.local_peer_id(); + let peer2 = *litep2p2.local_peer_id(); + + // wait until peers have connected + connect_peers(&mut litep2p1, &mut litep2p2).await; + tokio::spawn(async move { + loop { + tokio::select! { + _ = litep2p2.next_event() => {}, + _ = litep2p1.next_event() => {}, + } + } + }); + + let request_id = handle1 + .send_request_with_fallback( + peer2, + vec![1, 2, 3, 4], + (ProtocolName::from("/genesis/protocol/1"), vec![5, 6, 7, 8]), + DialOptions::Reject, + ) + .await + .unwrap(); + + assert_eq!( + handle2.next().await.unwrap(), + RequestResponseEvent::RequestReceived { + peer: peer1, + fallback: None, + request_id, + request: vec![5, 6, 7, 8], + } + ); + + handle2.send_response(request_id, vec![1, 3, 3, 7]); + + assert_eq!( + handle1.next().await.unwrap(), + RequestResponseEvent::ResponseReceived { + peer: peer2, + request_id, + response: vec![1, 3, 3, 7], + fallback: Some(ProtocolName::from("/genesis/protocol/1")), + } + ); } #[tokio::test] async fn binary_incompatible_fallback_two_fallback_protocols_inbound_request_tcp() { - binary_incompatible_fallback_two_fallback_protocols_inbound_request( - Transport::Tcp(Default::default()), - Transport::Tcp(Default::default()), - ) - .await; + binary_incompatible_fallback_two_fallback_protocols_inbound_request( + Transport::Tcp(Default::default()), + Transport::Tcp(Default::default()), + ) + .await; } #[tokio::test] async fn binary_incompatible_fallback_two_fallback_protocols_inbound_request_quic() { - binary_incompatible_fallback_two_fallback_protocols_inbound_request( - Transport::Quic(Default::default()), - Transport::Quic(Default::default()), - ) - .await; + binary_incompatible_fallback_two_fallback_protocols_inbound_request( + Transport::Quic(Default::default()), + Transport::Quic(Default::default()), + ) + .await; } #[tokio::test] async fn binary_incompatible_fallback_two_fallback_protocols_inbound_request_websocket() { - binary_incompatible_fallback_two_fallback_protocols_inbound_request( - Transport::WebSocket(Default::default()), - Transport::WebSocket(Default::default()), - ) - .await; + binary_incompatible_fallback_two_fallback_protocols_inbound_request( + Transport::WebSocket(Default::default()), + Transport::WebSocket(Default::default()), + ) + .await; } async fn binary_incompatible_fallback_two_fallback_protocols_inbound_request( - transport1: Transport, - transport2: Transport, + transport1: Transport, + transport2: Transport, ) { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (req_resp_config1, mut handle1) = - ConfigBuilder::new(ProtocolName::from("/genesis/protocol/2")) - .with_max_size(16 * 1024 * 1024) - .with_fallback_names(vec![ - ProtocolName::from("/genesis/protocol/1"), - ProtocolName::from("/dot/protocol/1"), - ]) - .with_timeout(Duration::from_secs(8)) - .build(); - - let config1 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_request_response_protocol(req_resp_config1); - - let config1 = match transport1 { - Transport::Tcp(config) => config1.with_tcp(config), - Transport::Quic(config) => config1.with_quic(config), - Transport::WebSocket(config) => config1.with_websocket(config), - } - .build(); - - let (req_resp_config2, mut handle2) = - ConfigBuilder::new(ProtocolName::from("/genesis/protocol/1")) - .with_fallback_names(vec![ProtocolName::from("/dot/protocol/1")]) - .with_max_size(16 * 1024 * 1024) - .build(); - - let config2 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_request_response_protocol(req_resp_config2); - - let config2 = match transport2 { - Transport::Tcp(config) => config2.with_tcp(config), - Transport::Quic(config) => config2.with_quic(config), - Transport::WebSocket(config) => config2.with_websocket(config), - } - .build(); - - let mut litep2p1 = Litep2p::new(config1).unwrap(); - let mut litep2p2 = Litep2p::new(config2).unwrap(); - let peer1 = *litep2p1.local_peer_id(); - let peer2 = *litep2p2.local_peer_id(); - - // wait until peers have connected - connect_peers(&mut litep2p1, &mut litep2p2).await; - tokio::spawn(async move { - loop { - tokio::select! { - _ = litep2p2.next_event() => {}, - _ = litep2p1.next_event() => {}, - } - } - }); - - let request_id = handle2 - .send_request(peer1, vec![1, 2, 3, 4], DialOptions::Reject) - .await - .unwrap(); - - assert_eq!( - handle1.next().await.unwrap(), - RequestResponseEvent::RequestReceived { - peer: peer2, - fallback: Some(ProtocolName::from("/genesis/protocol/1")), - request_id, - request: vec![1, 2, 3, 4], - } - ); - - handle1.send_response(request_id, vec![1, 3, 3, 7]); - - assert_eq!( - handle2.next().await.unwrap(), - RequestResponseEvent::ResponseReceived { - peer: peer1, - request_id, - response: vec![1, 3, 3, 7], - fallback: None, - } - ); + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (req_resp_config1, mut handle1) = + ConfigBuilder::new(ProtocolName::from("/genesis/protocol/2")) + .with_max_size(16 * 1024 * 1024) + .with_fallback_names(vec![ + ProtocolName::from("/genesis/protocol/1"), + ProtocolName::from("/dot/protocol/1"), + ]) + .with_timeout(Duration::from_secs(8)) + .build(); + + let config1 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_request_response_protocol(req_resp_config1); + + let config1 = match transport1 { + Transport::Tcp(config) => config1.with_tcp(config), + Transport::Quic(config) => config1.with_quic(config), + Transport::WebSocket(config) => config1.with_websocket(config), + } + .build(); + + let (req_resp_config2, mut handle2) = + ConfigBuilder::new(ProtocolName::from("/genesis/protocol/1")) + .with_fallback_names(vec![ProtocolName::from("/dot/protocol/1")]) + .with_max_size(16 * 1024 * 1024) + .build(); + + let config2 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_request_response_protocol(req_resp_config2); + + let config2 = match transport2 { + Transport::Tcp(config) => config2.with_tcp(config), + Transport::Quic(config) => config2.with_quic(config), + Transport::WebSocket(config) => config2.with_websocket(config), + } + .build(); + + let mut litep2p1 = Litep2p::new(config1).unwrap(); + let mut litep2p2 = Litep2p::new(config2).unwrap(); + let peer1 = *litep2p1.local_peer_id(); + let peer2 = *litep2p2.local_peer_id(); + + // wait until peers have connected + connect_peers(&mut litep2p1, &mut litep2p2).await; + tokio::spawn(async move { + loop { + tokio::select! { + _ = litep2p2.next_event() => {}, + _ = litep2p1.next_event() => {}, + } + } + }); + + let request_id = handle2 + .send_request(peer1, vec![1, 2, 3, 4], DialOptions::Reject) + .await + .unwrap(); + + assert_eq!( + handle1.next().await.unwrap(), + RequestResponseEvent::RequestReceived { + peer: peer2, + fallback: Some(ProtocolName::from("/genesis/protocol/1")), + request_id, + request: vec![1, 2, 3, 4], + } + ); + + handle1.send_response(request_id, vec![1, 3, 3, 7]); + + assert_eq!( + handle2.next().await.unwrap(), + RequestResponseEvent::ResponseReceived { + peer: peer1, + request_id, + response: vec![1, 3, 3, 7], + fallback: None, + } + ); } #[tokio::test] async fn binary_incompatible_fallback_compatible_nodes_tcp() { - binary_incompatible_fallback_compatible_nodes( - Transport::Tcp(Default::default()), - Transport::Tcp(Default::default()), - ) - .await; + binary_incompatible_fallback_compatible_nodes( + Transport::Tcp(Default::default()), + Transport::Tcp(Default::default()), + ) + .await; } #[tokio::test] async fn binary_incompatible_fallback_compatible_nodes_quic() { - binary_incompatible_fallback_compatible_nodes( - Transport::Quic(Default::default()), - Transport::Quic(Default::default()), - ) - .await; + binary_incompatible_fallback_compatible_nodes( + Transport::Quic(Default::default()), + Transport::Quic(Default::default()), + ) + .await; } #[tokio::test] async fn binary_incompatible_fallback_compatible_nodes_websocket() { - binary_incompatible_fallback_compatible_nodes( - Transport::WebSocket(Default::default()), - Transport::WebSocket(Default::default()), - ) - .await; + binary_incompatible_fallback_compatible_nodes( + Transport::WebSocket(Default::default()), + Transport::WebSocket(Default::default()), + ) + .await; } async fn binary_incompatible_fallback_compatible_nodes( - transport1: Transport, - transport2: Transport, + transport1: Transport, + transport2: Transport, ) { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (req_resp_config1, mut handle1) = - ConfigBuilder::new(ProtocolName::from("/genesis/protocol/2")) - .with_max_size(16 * 1024 * 1024) - .with_fallback_names(vec![ - ProtocolName::from("/genesis/protocol/1"), - ProtocolName::from("/dot/protocol/1"), - ]) - .with_timeout(Duration::from_secs(8)) - .build(); - - let config1 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_request_response_protocol(req_resp_config1); - - let config1 = match transport1 { - Transport::Tcp(config) => config1.with_tcp(config), - Transport::Quic(config) => config1.with_quic(config), - Transport::WebSocket(config) => config1.with_websocket(config), - } - .build(); - - let (req_resp_config2, mut handle2) = - ConfigBuilder::new(ProtocolName::from("/genesis/protocol/2")) - .with_max_size(16 * 1024 * 1024) - .with_fallback_names(vec![ - ProtocolName::from("/genesis/protocol/1"), - ProtocolName::from("/dot/protocol/1"), - ]) - .with_timeout(Duration::from_secs(8)) - .build(); - - let config2 = Litep2pConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_request_response_protocol(req_resp_config2); - - let config2 = match transport2 { - Transport::Tcp(config) => config2.with_tcp(config), - Transport::Quic(config) => config2.with_quic(config), - Transport::WebSocket(config) => config2.with_websocket(config), - } - .build(); - - let mut litep2p1 = Litep2p::new(config1).unwrap(); - let mut litep2p2 = Litep2p::new(config2).unwrap(); - let peer1 = *litep2p1.local_peer_id(); - let peer2 = *litep2p2.local_peer_id(); - - // wait until peers have connected - connect_peers(&mut litep2p1, &mut litep2p2).await; - tokio::spawn(async move { - loop { - tokio::select! { - _ = litep2p2.next_event() => {}, - _ = litep2p1.next_event() => {}, - } - } - }); - - let request_id = handle1 - .send_request_with_fallback( - peer2, - vec![1, 2, 3, 4], - (ProtocolName::from("/genesis/protocol/1"), vec![5, 6, 7, 8]), - DialOptions::Reject, - ) - .await - .unwrap(); - - assert_eq!( - handle2.next().await.unwrap(), - RequestResponseEvent::RequestReceived { - peer: peer1, - fallback: None, - request_id, - request: vec![1, 2, 3, 4], - } - ); - - handle2.send_response(request_id, vec![1, 3, 3, 7]); - - assert_eq!( - handle1.next().await.unwrap(), - RequestResponseEvent::ResponseReceived { - peer: peer2, - request_id, - response: vec![1, 3, 3, 7], - fallback: None, - } - ); + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (req_resp_config1, mut handle1) = + ConfigBuilder::new(ProtocolName::from("/genesis/protocol/2")) + .with_max_size(16 * 1024 * 1024) + .with_fallback_names(vec![ + ProtocolName::from("/genesis/protocol/1"), + ProtocolName::from("/dot/protocol/1"), + ]) + .with_timeout(Duration::from_secs(8)) + .build(); + + let config1 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_request_response_protocol(req_resp_config1); + + let config1 = match transport1 { + Transport::Tcp(config) => config1.with_tcp(config), + Transport::Quic(config) => config1.with_quic(config), + Transport::WebSocket(config) => config1.with_websocket(config), + } + .build(); + + let (req_resp_config2, mut handle2) = + ConfigBuilder::new(ProtocolName::from("/genesis/protocol/2")) + .with_max_size(16 * 1024 * 1024) + .with_fallback_names(vec![ + ProtocolName::from("/genesis/protocol/1"), + ProtocolName::from("/dot/protocol/1"), + ]) + .with_timeout(Duration::from_secs(8)) + .build(); + + let config2 = Litep2pConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_request_response_protocol(req_resp_config2); + + let config2 = match transport2 { + Transport::Tcp(config) => config2.with_tcp(config), + Transport::Quic(config) => config2.with_quic(config), + Transport::WebSocket(config) => config2.with_websocket(config), + } + .build(); + + let mut litep2p1 = Litep2p::new(config1).unwrap(); + let mut litep2p2 = Litep2p::new(config2).unwrap(); + let peer1 = *litep2p1.local_peer_id(); + let peer2 = *litep2p2.local_peer_id(); + + // wait until peers have connected + connect_peers(&mut litep2p1, &mut litep2p2).await; + tokio::spawn(async move { + loop { + tokio::select! { + _ = litep2p2.next_event() => {}, + _ = litep2p1.next_event() => {}, + } + } + }); + + let request_id = handle1 + .send_request_with_fallback( + peer2, + vec![1, 2, 3, 4], + (ProtocolName::from("/genesis/protocol/1"), vec![5, 6, 7, 8]), + DialOptions::Reject, + ) + .await + .unwrap(); + + assert_eq!( + handle2.next().await.unwrap(), + RequestResponseEvent::RequestReceived { + peer: peer1, + fallback: None, + request_id, + request: vec![1, 2, 3, 4], + } + ); + + handle2.send_response(request_id, vec![1, 3, 3, 7]); + + assert_eq!( + handle1.next().await.unwrap(), + RequestResponseEvent::ResponseReceived { + peer: peer2, + request_id, + response: vec![1, 3, 3, 7], + fallback: None, + } + ); } diff --git a/tests/substream.rs b/tests/substream.rs index 91e8a229..5682195b 100644 --- a/tests/substream.rs +++ b/tests/substream.rs @@ -19,559 +19,559 @@ // DEALINGS IN THE SOFTWARE. use litep2p::{ - codec::ProtocolCodec, - config::ConfigBuilder, - protocol::{Direction, TransportEvent, TransportService, UserProtocol}, - substream::{Substream, SubstreamSet}, - transport::{ - quic::config::Config as QuicConfig, tcp::config::Config as TcpConfig, - websocket::config::Config as WebSocketConfig, - }, - types::{protocol::ProtocolName, SubstreamId}, - Error, Litep2p, Litep2pEvent, PeerId, + codec::ProtocolCodec, + config::ConfigBuilder, + protocol::{Direction, TransportEvent, TransportService, UserProtocol}, + substream::{Substream, SubstreamSet}, + transport::{ + quic::config::Config as QuicConfig, tcp::config::Config as TcpConfig, + websocket::config::Config as WebSocketConfig, + }, + types::{protocol::ProtocolName, SubstreamId}, + Error, Litep2p, Litep2pEvent, PeerId, }; use bytes::Bytes; use futures::{Sink, SinkExt, StreamExt}; use tokio::{ - io::AsyncWrite, - sync::{ - mpsc::{channel, Receiver, Sender}, - oneshot, - }, + io::AsyncWrite, + sync::{ + mpsc::{channel, Receiver, Sender}, + oneshot, + }, }; use std::{ - collections::{HashMap, HashSet}, - io::ErrorKind, - sync::Arc, - task::Poll, + collections::{HashMap, HashSet}, + io::ErrorKind, + sync::Arc, + task::Poll, }; enum Transport { - Tcp(TcpConfig), - Quic(QuicConfig), - WebSocket(WebSocketConfig), + Tcp(TcpConfig), + Quic(QuicConfig), + WebSocket(WebSocketConfig), } enum Command { - SendPayloadFramed(PeerId, Vec, oneshot::Sender>), - SendPayloadSink(PeerId, Vec, oneshot::Sender>), - SendPayloadAsyncWrite(PeerId, Vec, oneshot::Sender>), - OpenSubstream(PeerId, oneshot::Sender<()>), + SendPayloadFramed(PeerId, Vec, oneshot::Sender>), + SendPayloadSink(PeerId, Vec, oneshot::Sender>), + SendPayloadAsyncWrite(PeerId, Vec, oneshot::Sender>), + OpenSubstream(PeerId, oneshot::Sender<()>), } struct CustomProtocol { - protocol: ProtocolName, - codec: ProtocolCodec, - peers: HashSet, - rx: Receiver, - pending_opens: HashMap)>, - substreams: SubstreamSet, + protocol: ProtocolName, + codec: ProtocolCodec, + peers: HashSet, + rx: Receiver, + pending_opens: HashMap)>, + substreams: SubstreamSet, } impl CustomProtocol { - pub fn new(codec: ProtocolCodec) -> (Self, Sender) { - let protocol: Arc = Arc::from(String::from("/custom-protocol/1")); - let (tx, rx) = channel(64); - - ( - Self { - peers: HashSet::new(), - protocol: ProtocolName::from(protocol), - codec, - rx, - pending_opens: HashMap::new(), - substreams: SubstreamSet::new(), - }, - tx, - ) - } + pub fn new(codec: ProtocolCodec) -> (Self, Sender) { + let protocol: Arc = Arc::from(String::from("/custom-protocol/1")); + let (tx, rx) = channel(64); + + ( + Self { + peers: HashSet::new(), + protocol: ProtocolName::from(protocol), + codec, + rx, + pending_opens: HashMap::new(), + substreams: SubstreamSet::new(), + }, + tx, + ) + } } #[async_trait::async_trait] impl UserProtocol for CustomProtocol { - fn protocol(&self) -> ProtocolName { - self.protocol.clone() - } - - fn codec(&self) -> ProtocolCodec { - self.codec.clone() - } - - async fn run(mut self: Box, mut service: TransportService) -> litep2p::Result<()> { - loop { - tokio::select! { - event = service.next() => match event.unwrap() { - TransportEvent::ConnectionEstablished { peer, .. } => { - self.peers.insert(peer); - } - TransportEvent::ConnectionClosed { peer } => { - self.peers.remove(&peer); - } - TransportEvent::SubstreamOpened { - peer, - substream, - direction, - .. - } => { - self.substreams.insert(peer, substream); - - if let Direction::Outbound(substream_id) = direction { - self.pending_opens.remove(&substream_id).unwrap().1.send(()).unwrap(); - } - } - _ => {} - }, - event = self.substreams.next() => match event { - None => panic!("`SubstreamSet` returned `None`"), - Some((peer, Err(_))) => { - if let Some(mut substream) = self.substreams.remove(&peer) { - futures::future::poll_fn(|cx| { - match futures::ready!(Sink::poll_close(Pin::new(&mut substream), cx)) { - _ => Poll::Ready(()), - } - }).await; - } - } - Some((peer, Ok(_))) => { - if let Some(mut substream) = self.substreams.remove(&peer) { - futures::future::poll_fn(|cx| { - match futures::ready!(Sink::poll_close(Pin::new(&mut substream), cx)) { - _ => Poll::Ready(()), - } - }).await; - } - }, - }, - command = self.rx.recv() => match command.unwrap() { - Command::SendPayloadFramed(peer, payload, tx) => { - match self.substreams.remove(&peer) { - None => { - tx.send(Err(Error::PeerDoesntExist(peer))).unwrap(); - } - Some(mut substream) => { - let payload = Bytes::from(payload); - let res = substream.send_framed(payload).await; - tx.send(res).unwrap(); - let _ = substream.close().await; - } - } - } - Command::SendPayloadSink(peer, payload, tx) => { - match self.substreams.remove(&peer) { - None => { - tx.send(Err(Error::PeerDoesntExist(peer))).unwrap(); - } - Some(mut substream) => { - let payload = Bytes::from(payload); - let res = substream.send(payload).await; - tx.send(res).unwrap(); - let _ = substream.close().await; - } - } - } - Command::SendPayloadAsyncWrite(peer, payload, tx) => { - match self.substreams.remove(&peer) { - None => { - tx.send(Err(Error::PeerDoesntExist(peer))).unwrap(); - } - Some(mut substream) => { - let res = futures::future::poll_fn(|cx| { - if let Err(error) = futures::ready!(Pin::new(&mut substream).poll_write(cx, &payload)) { - return Poll::Ready(Err(error.into())); - } - - if let Err(error) = futures::ready!(tokio::io::AsyncWrite::poll_flush( - Pin::new(&mut substream), - cx - )) { - return Poll::Ready(Err(error.into())); - } - - if let Err(error) = futures::ready!(tokio::io::AsyncWrite::poll_shutdown( - Pin::new(&mut substream), - cx - )) { - return Poll::Ready(Err(error.into())); - } - - Poll::Ready(Ok(())) - }) - .await; - tx.send(res).unwrap(); - } - } - } - Command::OpenSubstream(peer, tx) => { - let substream_id = service.open_substream(peer).unwrap(); - self.pending_opens.insert(substream_id, (peer, tx)); - } - } - } - } - } + fn protocol(&self) -> ProtocolName { + self.protocol.clone() + } + + fn codec(&self) -> ProtocolCodec { + self.codec.clone() + } + + async fn run(mut self: Box, mut service: TransportService) -> litep2p::Result<()> { + loop { + tokio::select! { + event = service.next() => match event.unwrap() { + TransportEvent::ConnectionEstablished { peer, .. } => { + self.peers.insert(peer); + } + TransportEvent::ConnectionClosed { peer } => { + self.peers.remove(&peer); + } + TransportEvent::SubstreamOpened { + peer, + substream, + direction, + .. + } => { + self.substreams.insert(peer, substream); + + if let Direction::Outbound(substream_id) = direction { + self.pending_opens.remove(&substream_id).unwrap().1.send(()).unwrap(); + } + } + _ => {} + }, + event = self.substreams.next() => match event { + None => panic!("`SubstreamSet` returned `None`"), + Some((peer, Err(_))) => { + if let Some(mut substream) = self.substreams.remove(&peer) { + futures::future::poll_fn(|cx| { + match futures::ready!(Sink::poll_close(Pin::new(&mut substream), cx)) { + _ => Poll::Ready(()), + } + }).await; + } + } + Some((peer, Ok(_))) => { + if let Some(mut substream) = self.substreams.remove(&peer) { + futures::future::poll_fn(|cx| { + match futures::ready!(Sink::poll_close(Pin::new(&mut substream), cx)) { + _ => Poll::Ready(()), + } + }).await; + } + }, + }, + command = self.rx.recv() => match command.unwrap() { + Command::SendPayloadFramed(peer, payload, tx) => { + match self.substreams.remove(&peer) { + None => { + tx.send(Err(Error::PeerDoesntExist(peer))).unwrap(); + } + Some(mut substream) => { + let payload = Bytes::from(payload); + let res = substream.send_framed(payload).await; + tx.send(res).unwrap(); + let _ = substream.close().await; + } + } + } + Command::SendPayloadSink(peer, payload, tx) => { + match self.substreams.remove(&peer) { + None => { + tx.send(Err(Error::PeerDoesntExist(peer))).unwrap(); + } + Some(mut substream) => { + let payload = Bytes::from(payload); + let res = substream.send(payload).await; + tx.send(res).unwrap(); + let _ = substream.close().await; + } + } + } + Command::SendPayloadAsyncWrite(peer, payload, tx) => { + match self.substreams.remove(&peer) { + None => { + tx.send(Err(Error::PeerDoesntExist(peer))).unwrap(); + } + Some(mut substream) => { + let res = futures::future::poll_fn(|cx| { + if let Err(error) = futures::ready!(Pin::new(&mut substream).poll_write(cx, &payload)) { + return Poll::Ready(Err(error.into())); + } + + if let Err(error) = futures::ready!(tokio::io::AsyncWrite::poll_flush( + Pin::new(&mut substream), + cx + )) { + return Poll::Ready(Err(error.into())); + } + + if let Err(error) = futures::ready!(tokio::io::AsyncWrite::poll_shutdown( + Pin::new(&mut substream), + cx + )) { + return Poll::Ready(Err(error.into())); + } + + Poll::Ready(Ok(())) + }) + .await; + tx.send(res).unwrap(); + } + } + } + Command::OpenSubstream(peer, tx) => { + let substream_id = service.open_substream(peer).unwrap(); + self.pending_opens.insert(substream_id, (peer, tx)); + } + } + } + } + } } async fn connect_peers(litep2p1: &mut Litep2p, litep2p2: &mut Litep2p) { - let listen_address = litep2p1.listen_addresses().next().unwrap().clone(); - litep2p2.dial_address(listen_address).await.unwrap(); - - let mut litep2p1_ready = false; - let mut litep2p2_ready = false; - - while !litep2p1_ready && !litep2p2_ready { - tokio::select! { - event = litep2p1.next_event() => match event.unwrap() { - Litep2pEvent::ConnectionEstablished { .. } => litep2p1_ready = true, - _ => {} - }, - event = litep2p2.next_event() => match event.unwrap() { - Litep2pEvent::ConnectionEstablished { .. } => litep2p2_ready = true, - _ => {} - }, - } - } + let listen_address = litep2p1.listen_addresses().next().unwrap().clone(); + litep2p2.dial_address(listen_address).await.unwrap(); + + let mut litep2p1_ready = false; + let mut litep2p2_ready = false; + + while !litep2p1_ready && !litep2p2_ready { + tokio::select! { + event = litep2p1.next_event() => match event.unwrap() { + Litep2pEvent::ConnectionEstablished { .. } => litep2p1_ready = true, + _ => {} + }, + event = litep2p2.next_event() => match event.unwrap() { + Litep2pEvent::ConnectionEstablished { .. } => litep2p2_ready = true, + _ => {} + }, + } + } } #[tokio::test] async fn too_big_identity_payload_framed_tcp() { - too_big_identity_payload_framed( - Transport::Tcp(Default::default()), - Transport::Tcp(Default::default()), - ) - .await; + too_big_identity_payload_framed( + Transport::Tcp(Default::default()), + Transport::Tcp(Default::default()), + ) + .await; } #[tokio::test] async fn too_big_identity_payload_framed_quic() { - too_big_identity_payload_framed( - Transport::Quic(Default::default()), - Transport::Quic(Default::default()), - ) - .await; + too_big_identity_payload_framed( + Transport::Quic(Default::default()), + Transport::Quic(Default::default()), + ) + .await; } #[tokio::test] async fn too_big_identity_payload_framed_websocket() { - too_big_identity_payload_framed( - Transport::WebSocket(Default::default()), - Transport::WebSocket(Default::default()), - ) - .await; + too_big_identity_payload_framed( + Transport::WebSocket(Default::default()), + Transport::WebSocket(Default::default()), + ) + .await; } // send too big payload using `Substream::send_framed()` and verify it's rejected async fn too_big_identity_payload_framed(transport1: Transport, transport2: Transport) { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (custom_protocol1, tx1) = CustomProtocol::new(ProtocolCodec::Identity(10usize)); - let config1 = match transport1 { - Transport::Tcp(config) => ConfigBuilder::new().with_tcp(config), - Transport::Quic(config) => ConfigBuilder::new().with_quic(config), - Transport::WebSocket(config) => ConfigBuilder::new().with_websocket(config), - } - .with_user_protocol(Box::new(custom_protocol1)) - .build(); - - let (custom_protocol2, _tx2) = CustomProtocol::new(ProtocolCodec::Identity(10usize)); - let config2 = match transport2 { - Transport::Tcp(config) => ConfigBuilder::new().with_tcp(config), - Transport::Quic(config) => ConfigBuilder::new().with_quic(config), - Transport::WebSocket(config) => ConfigBuilder::new().with_websocket(config), - } - .with_user_protocol(Box::new(custom_protocol2)) - .build(); - - let mut litep2p1 = Litep2p::new(config1).unwrap(); - let mut litep2p2 = Litep2p::new(config2).unwrap(); - let peer2 = *litep2p2.local_peer_id(); - - // connect peers and start event loops for litep2ps - connect_peers(&mut litep2p1, &mut litep2p2).await; - tokio::spawn(async move { - loop { - tokio::select! { - _event = litep2p1.next_event() => {} - _event = litep2p2.next_event() => {} - } - } - }); - tokio::time::sleep(std::time::Duration::from_millis(1000)).await; - - // open substream to peer - let (tx, rx) = oneshot::channel(); - tx1.send(Command::OpenSubstream(peer2, tx)).await.unwrap(); - - let Ok(()) = rx.await else { - panic!("failed to open substream"); - }; - - // send too large paylod to peer - let (tx, rx) = oneshot::channel(); - tx1.send(Command::SendPayloadFramed(peer2, vec![0u8; 16], tx)).await.unwrap(); - - match rx.await { - Ok(Err(Error::IoError(ErrorKind::PermissionDenied))) => {}, - event => panic!("invalid event received: {event:?}"), - } + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (custom_protocol1, tx1) = CustomProtocol::new(ProtocolCodec::Identity(10usize)); + let config1 = match transport1 { + Transport::Tcp(config) => ConfigBuilder::new().with_tcp(config), + Transport::Quic(config) => ConfigBuilder::new().with_quic(config), + Transport::WebSocket(config) => ConfigBuilder::new().with_websocket(config), + } + .with_user_protocol(Box::new(custom_protocol1)) + .build(); + + let (custom_protocol2, _tx2) = CustomProtocol::new(ProtocolCodec::Identity(10usize)); + let config2 = match transport2 { + Transport::Tcp(config) => ConfigBuilder::new().with_tcp(config), + Transport::Quic(config) => ConfigBuilder::new().with_quic(config), + Transport::WebSocket(config) => ConfigBuilder::new().with_websocket(config), + } + .with_user_protocol(Box::new(custom_protocol2)) + .build(); + + let mut litep2p1 = Litep2p::new(config1).unwrap(); + let mut litep2p2 = Litep2p::new(config2).unwrap(); + let peer2 = *litep2p2.local_peer_id(); + + // connect peers and start event loops for litep2ps + connect_peers(&mut litep2p1, &mut litep2p2).await; + tokio::spawn(async move { + loop { + tokio::select! { + _event = litep2p1.next_event() => {} + _event = litep2p2.next_event() => {} + } + } + }); + tokio::time::sleep(std::time::Duration::from_millis(1000)).await; + + // open substream to peer + let (tx, rx) = oneshot::channel(); + tx1.send(Command::OpenSubstream(peer2, tx)).await.unwrap(); + + let Ok(()) = rx.await else { + panic!("failed to open substream"); + }; + + // send too large paylod to peer + let (tx, rx) = oneshot::channel(); + tx1.send(Command::SendPayloadFramed(peer2, vec![0u8; 16], tx)).await.unwrap(); + + match rx.await { + Ok(Err(Error::IoError(ErrorKind::PermissionDenied))) => {} + event => panic!("invalid event received: {event:?}"), + } } #[tokio::test] async fn too_big_identity_payload_sink_tcp() { - too_big_identity_payload_sink( - Transport::Tcp(Default::default()), - Transport::Tcp(Default::default()), - ) - .await; + too_big_identity_payload_sink( + Transport::Tcp(Default::default()), + Transport::Tcp(Default::default()), + ) + .await; } #[tokio::test] async fn too_big_identity_payload_sink_quic() { - too_big_identity_payload_sink( - Transport::Quic(Default::default()), - Transport::Quic(Default::default()), - ) - .await; + too_big_identity_payload_sink( + Transport::Quic(Default::default()), + Transport::Quic(Default::default()), + ) + .await; } #[tokio::test] async fn too_big_identity_payload_sink_websocket() { - too_big_identity_payload_sink( - Transport::WebSocket(Default::default()), - Transport::WebSocket(Default::default()), - ) - .await; + too_big_identity_payload_sink( + Transport::WebSocket(Default::default()), + Transport::WebSocket(Default::default()), + ) + .await; } // send too big payload using `::send()` and verify it's rejected async fn too_big_identity_payload_sink(transport1: Transport, transport2: Transport) { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (custom_protocol1, tx1) = CustomProtocol::new(ProtocolCodec::Identity(10usize)); - let config1 = match transport1 { - Transport::Tcp(config) => ConfigBuilder::new().with_tcp(config), - Transport::Quic(config) => ConfigBuilder::new().with_quic(config), - Transport::WebSocket(config) => ConfigBuilder::new().with_websocket(config), - } - .with_user_protocol(Box::new(custom_protocol1)) - .build(); - - let (custom_protocol2, _tx2) = CustomProtocol::new(ProtocolCodec::Identity(10usize)); - let config2 = match transport2 { - Transport::Tcp(config) => ConfigBuilder::new().with_tcp(config), - Transport::Quic(config) => ConfigBuilder::new().with_quic(config), - Transport::WebSocket(config) => ConfigBuilder::new().with_websocket(config), - } - .with_user_protocol(Box::new(custom_protocol2)) - .build(); - - let mut litep2p1 = Litep2p::new(config1).unwrap(); - let mut litep2p2 = Litep2p::new(config2).unwrap(); - let peer2 = *litep2p2.local_peer_id(); - - // connect peers and start event loops for litep2ps - connect_peers(&mut litep2p1, &mut litep2p2).await; - tokio::spawn(async move { - loop { - tokio::select! { - _event = litep2p1.next_event() => {} - _event = litep2p2.next_event() => {} - } - } - }); - tokio::time::sleep(std::time::Duration::from_millis(1000)).await; - - { - // open substream to peer - let (tx, rx) = oneshot::channel(); - tx1.send(Command::OpenSubstream(peer2, tx)).await.unwrap(); - - let Ok(()) = rx.await else { - panic!("failed to open substream"); - }; - - // send too large paylod to peer - let (tx, rx) = oneshot::channel(); - tx1.send(Command::SendPayloadSink(peer2, vec![0u8; 16], tx)).await.unwrap(); - - match rx.await { - Ok(Err(Error::IoError(ErrorKind::PermissionDenied))) => {}, - event => panic!("invalid event received: {event:?}"), - } - } + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (custom_protocol1, tx1) = CustomProtocol::new(ProtocolCodec::Identity(10usize)); + let config1 = match transport1 { + Transport::Tcp(config) => ConfigBuilder::new().with_tcp(config), + Transport::Quic(config) => ConfigBuilder::new().with_quic(config), + Transport::WebSocket(config) => ConfigBuilder::new().with_websocket(config), + } + .with_user_protocol(Box::new(custom_protocol1)) + .build(); + + let (custom_protocol2, _tx2) = CustomProtocol::new(ProtocolCodec::Identity(10usize)); + let config2 = match transport2 { + Transport::Tcp(config) => ConfigBuilder::new().with_tcp(config), + Transport::Quic(config) => ConfigBuilder::new().with_quic(config), + Transport::WebSocket(config) => ConfigBuilder::new().with_websocket(config), + } + .with_user_protocol(Box::new(custom_protocol2)) + .build(); + + let mut litep2p1 = Litep2p::new(config1).unwrap(); + let mut litep2p2 = Litep2p::new(config2).unwrap(); + let peer2 = *litep2p2.local_peer_id(); + + // connect peers and start event loops for litep2ps + connect_peers(&mut litep2p1, &mut litep2p2).await; + tokio::spawn(async move { + loop { + tokio::select! { + _event = litep2p1.next_event() => {} + _event = litep2p2.next_event() => {} + } + } + }); + tokio::time::sleep(std::time::Duration::from_millis(1000)).await; + + { + // open substream to peer + let (tx, rx) = oneshot::channel(); + tx1.send(Command::OpenSubstream(peer2, tx)).await.unwrap(); + + let Ok(()) = rx.await else { + panic!("failed to open substream"); + }; + + // send too large paylod to peer + let (tx, rx) = oneshot::channel(); + tx1.send(Command::SendPayloadSink(peer2, vec![0u8; 16], tx)).await.unwrap(); + + match rx.await { + Ok(Err(Error::IoError(ErrorKind::PermissionDenied))) => {} + event => panic!("invalid event received: {event:?}"), + } + } } #[tokio::test] async fn correct_payload_size_sink_tcp() { - correct_payload_size_sink( - Transport::Tcp(Default::default()), - Transport::Tcp(Default::default()), - ) - .await; + correct_payload_size_sink( + Transport::Tcp(Default::default()), + Transport::Tcp(Default::default()), + ) + .await; } #[tokio::test] async fn correct_payload_size_sink_quic() { - correct_payload_size_sink( - Transport::Quic(Default::default()), - Transport::Quic(Default::default()), - ) - .await; + correct_payload_size_sink( + Transport::Quic(Default::default()), + Transport::Quic(Default::default()), + ) + .await; } #[tokio::test] async fn correct_payload_size_sink_websocket() { - correct_payload_size_sink( - Transport::WebSocket(Default::default()), - Transport::WebSocket(Default::default()), - ) - .await; + correct_payload_size_sink( + Transport::WebSocket(Default::default()), + Transport::WebSocket(Default::default()), + ) + .await; } // send correctly-sized payload using `::send()` async fn correct_payload_size_sink(transport1: Transport, transport2: Transport) { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (custom_protocol1, tx1) = CustomProtocol::new(ProtocolCodec::Identity(10usize)); - let config1 = match transport1 { - Transport::Tcp(config) => ConfigBuilder::new().with_tcp(config), - Transport::Quic(config) => ConfigBuilder::new().with_quic(config), - Transport::WebSocket(config) => ConfigBuilder::new().with_websocket(config), - } - .with_user_protocol(Box::new(custom_protocol1)) - .build(); - - let (custom_protocol2, _tx2) = CustomProtocol::new(ProtocolCodec::Identity(10usize)); - let config2 = match transport2 { - Transport::Tcp(config) => ConfigBuilder::new().with_tcp(config), - Transport::Quic(config) => ConfigBuilder::new().with_quic(config), - Transport::WebSocket(config) => ConfigBuilder::new().with_websocket(config), - } - .with_user_protocol(Box::new(custom_protocol2)) - .build(); - - let mut litep2p1 = Litep2p::new(config1).unwrap(); - let mut litep2p2 = Litep2p::new(config2).unwrap(); - let peer2 = *litep2p2.local_peer_id(); - - // connect peers and start event loops for litep2ps - connect_peers(&mut litep2p1, &mut litep2p2).await; - tokio::spawn(async move { - loop { - tokio::select! { - _event = litep2p1.next_event() => {} - _event = litep2p2.next_event() => {} - } - } - }); - tokio::time::sleep(std::time::Duration::from_millis(1000)).await; - - // open substream to peer - let (tx, rx) = oneshot::channel(); - tx1.send(Command::OpenSubstream(peer2, tx)).await.unwrap(); - - let Ok(()) = rx.await else { - panic!("failed to open substream"); - }; - - let (tx, rx) = oneshot::channel(); - tx1.send(Command::SendPayloadSink(peer2, vec![0u8; 10], tx)).await.unwrap(); - - match rx.await { - Ok(_) => {}, - event => panic!("invalid event received: {event:?}"), - } + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (custom_protocol1, tx1) = CustomProtocol::new(ProtocolCodec::Identity(10usize)); + let config1 = match transport1 { + Transport::Tcp(config) => ConfigBuilder::new().with_tcp(config), + Transport::Quic(config) => ConfigBuilder::new().with_quic(config), + Transport::WebSocket(config) => ConfigBuilder::new().with_websocket(config), + } + .with_user_protocol(Box::new(custom_protocol1)) + .build(); + + let (custom_protocol2, _tx2) = CustomProtocol::new(ProtocolCodec::Identity(10usize)); + let config2 = match transport2 { + Transport::Tcp(config) => ConfigBuilder::new().with_tcp(config), + Transport::Quic(config) => ConfigBuilder::new().with_quic(config), + Transport::WebSocket(config) => ConfigBuilder::new().with_websocket(config), + } + .with_user_protocol(Box::new(custom_protocol2)) + .build(); + + let mut litep2p1 = Litep2p::new(config1).unwrap(); + let mut litep2p2 = Litep2p::new(config2).unwrap(); + let peer2 = *litep2p2.local_peer_id(); + + // connect peers and start event loops for litep2ps + connect_peers(&mut litep2p1, &mut litep2p2).await; + tokio::spawn(async move { + loop { + tokio::select! { + _event = litep2p1.next_event() => {} + _event = litep2p2.next_event() => {} + } + } + }); + tokio::time::sleep(std::time::Duration::from_millis(1000)).await; + + // open substream to peer + let (tx, rx) = oneshot::channel(); + tx1.send(Command::OpenSubstream(peer2, tx)).await.unwrap(); + + let Ok(()) = rx.await else { + panic!("failed to open substream"); + }; + + let (tx, rx) = oneshot::channel(); + tx1.send(Command::SendPayloadSink(peer2, vec![0u8; 10], tx)).await.unwrap(); + + match rx.await { + Ok(_) => {} + event => panic!("invalid event received: {event:?}"), + } } #[tokio::test] async fn correct_payload_size_async_write_tcp() { - correct_payload_size_async_write( - Transport::Tcp(Default::default()), - Transport::Tcp(Default::default()), - ) - .await; + correct_payload_size_async_write( + Transport::Tcp(Default::default()), + Transport::Tcp(Default::default()), + ) + .await; } #[tokio::test] async fn correct_payload_size_async_write_quic() { - correct_payload_size_async_write( - Transport::Quic(Default::default()), - Transport::Quic(Default::default()), - ) - .await; + correct_payload_size_async_write( + Transport::Quic(Default::default()), + Transport::Quic(Default::default()), + ) + .await; } #[tokio::test] async fn correct_payload_size_async_write_websocket() { - correct_payload_size_async_write( - Transport::WebSocket(Default::default()), - Transport::WebSocket(Default::default()), - ) - .await; + correct_payload_size_async_write( + Transport::WebSocket(Default::default()), + Transport::WebSocket(Default::default()), + ) + .await; } // send correctly-sized payload using `::poll_write()` async fn correct_payload_size_async_write(transport1: Transport, transport2: Transport) { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (custom_protocol1, tx1) = CustomProtocol::new(ProtocolCodec::Identity(10usize)); - let config1 = match transport1 { - Transport::Tcp(config) => ConfigBuilder::new().with_tcp(config), - Transport::Quic(config) => ConfigBuilder::new().with_quic(config), - Transport::WebSocket(config) => ConfigBuilder::new().with_websocket(config), - } - .with_user_protocol(Box::new(custom_protocol1)) - .build(); - - let (custom_protocol2, _tx2) = CustomProtocol::new(ProtocolCodec::Identity(10usize)); - let config2 = match transport2 { - Transport::Tcp(config) => ConfigBuilder::new().with_tcp(config), - Transport::Quic(config) => ConfigBuilder::new().with_quic(config), - Transport::WebSocket(config) => ConfigBuilder::new().with_websocket(config), - } - .with_user_protocol(Box::new(custom_protocol2)) - .build(); - - let mut litep2p1 = Litep2p::new(config1).unwrap(); - let mut litep2p2 = Litep2p::new(config2).unwrap(); - let peer2 = *litep2p2.local_peer_id(); - - // connect peers and start event loops for litep2ps - connect_peers(&mut litep2p1, &mut litep2p2).await; - tokio::spawn(async move { - loop { - tokio::select! { - _event = litep2p1.next_event() => {} - _event = litep2p2.next_event() => {} - } - } - }); - tokio::time::sleep(std::time::Duration::from_millis(1000)).await; - - // open substream to peer - let (tx, rx) = oneshot::channel(); - tx1.send(Command::OpenSubstream(peer2, tx)).await.unwrap(); - - let Ok(()) = rx.await else { - panic!("failed to open substream"); - }; - - let (tx, rx) = oneshot::channel(); - tx1.send(Command::SendPayloadAsyncWrite(peer2, vec![0u8; 10], tx)) - .await - .unwrap(); - - match rx.await { - Ok(_) => {}, - event => panic!("invalid event received: {event:?}"), - } + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (custom_protocol1, tx1) = CustomProtocol::new(ProtocolCodec::Identity(10usize)); + let config1 = match transport1 { + Transport::Tcp(config) => ConfigBuilder::new().with_tcp(config), + Transport::Quic(config) => ConfigBuilder::new().with_quic(config), + Transport::WebSocket(config) => ConfigBuilder::new().with_websocket(config), + } + .with_user_protocol(Box::new(custom_protocol1)) + .build(); + + let (custom_protocol2, _tx2) = CustomProtocol::new(ProtocolCodec::Identity(10usize)); + let config2 = match transport2 { + Transport::Tcp(config) => ConfigBuilder::new().with_tcp(config), + Transport::Quic(config) => ConfigBuilder::new().with_quic(config), + Transport::WebSocket(config) => ConfigBuilder::new().with_websocket(config), + } + .with_user_protocol(Box::new(custom_protocol2)) + .build(); + + let mut litep2p1 = Litep2p::new(config1).unwrap(); + let mut litep2p2 = Litep2p::new(config2).unwrap(); + let peer2 = *litep2p2.local_peer_id(); + + // connect peers and start event loops for litep2ps + connect_peers(&mut litep2p1, &mut litep2p2).await; + tokio::spawn(async move { + loop { + tokio::select! { + _event = litep2p1.next_event() => {} + _event = litep2p2.next_event() => {} + } + } + }); + tokio::time::sleep(std::time::Duration::from_millis(1000)).await; + + // open substream to peer + let (tx, rx) = oneshot::channel(); + tx1.send(Command::OpenSubstream(peer2, tx)).await.unwrap(); + + let Ok(()) = rx.await else { + panic!("failed to open substream"); + }; + + let (tx, rx) = oneshot::channel(); + tx1.send(Command::SendPayloadAsyncWrite(peer2, vec![0u8; 10], tx)) + .await + .unwrap(); + + match rx.await { + Ok(_) => {} + event => panic!("invalid event received: {event:?}"), + } } diff --git a/tests/user_protocol.rs b/tests/user_protocol.rs index be3f764c..7ea23428 100644 --- a/tests/user_protocol.rs +++ b/tests/user_protocol.rs @@ -19,13 +19,13 @@ // DEALINGS IN THE SOFTWARE. use litep2p::{ - codec::ProtocolCodec, - config::ConfigBuilder, - crypto::ed25519::Keypair, - protocol::{mdns::Config as MdnsConfig, TransportEvent, TransportService, UserProtocol}, - transport::tcp::config::Config as TcpConfig, - types::protocol::ProtocolName, - Litep2p, PeerId, + codec::ProtocolCodec, + config::ConfigBuilder, + crypto::ed25519::Keypair, + protocol::{mdns::Config as MdnsConfig, TransportEvent, TransportService, UserProtocol}, + transport::tcp::config::Config as TcpConfig, + types::protocol::ProtocolName, + Litep2p, PeerId, }; use futures::StreamExt; @@ -33,118 +33,122 @@ use futures::StreamExt; use std::{collections::HashSet, sync::Arc, time::Duration}; struct CustomProtocol { - protocol: ProtocolName, - codec: ProtocolCodec, - peers: HashSet, + protocol: ProtocolName, + codec: ProtocolCodec, + peers: HashSet, } impl CustomProtocol { - pub fn new() -> Self { - let protocol: Arc = Arc::from(String::from("/custom-protocol/1")); - - Self { - peers: HashSet::new(), - protocol: ProtocolName::from(protocol), - codec: ProtocolCodec::UnsignedVarint(None), - } - } + pub fn new() -> Self { + let protocol: Arc = Arc::from(String::from("/custom-protocol/1")); + + Self { + peers: HashSet::new(), + protocol: ProtocolName::from(protocol), + codec: ProtocolCodec::UnsignedVarint(None), + } + } } #[async_trait::async_trait] impl UserProtocol for CustomProtocol { - fn protocol(&self) -> ProtocolName { - self.protocol.clone() - } - - fn codec(&self) -> ProtocolCodec { - self.codec.clone() - } - - async fn run(mut self: Box, mut service: TransportService) -> litep2p::Result<()> { - loop { - while let Some(event) = service.next().await { - tracing::trace!("received event: {event:?}"); - - match event { - TransportEvent::ConnectionEstablished { peer, .. } => { - self.peers.insert(peer); - }, - TransportEvent::ConnectionClosed { peer } => { - self.peers.remove(&peer); - }, - _ => {}, - } - } - } - } + fn protocol(&self) -> ProtocolName { + self.protocol.clone() + } + + fn codec(&self) -> ProtocolCodec { + self.codec.clone() + } + + async fn run(mut self: Box, mut service: TransportService) -> litep2p::Result<()> { + loop { + while let Some(event) = service.next().await { + tracing::trace!("received event: {event:?}"); + + match event { + TransportEvent::ConnectionEstablished { peer, .. } => { + self.peers.insert(peer); + } + TransportEvent::ConnectionClosed { peer } => { + self.peers.remove(&peer); + } + _ => {} + } + } + } + } } #[tokio::test] async fn user_protocol() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let custom_protocol1 = Box::new(CustomProtocol::new()); - let (mdns_config, _stream) = MdnsConfig::new(Duration::from_secs(30)); - - let config1 = ConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_tcp(TcpConfig { ..Default::default() }) - .with_user_protocol(custom_protocol1) - .with_mdns(mdns_config) - .build(); - - let mut litep2p1 = Litep2p::new(config1).unwrap(); - let peer1 = *litep2p1.local_peer_id(); - let listen_address = litep2p1.listen_addresses().next().unwrap().clone(); - - let custom_protocol2 = Box::new(CustomProtocol::new()); - let config2 = ConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_tcp(TcpConfig { ..Default::default() }) - .with_user_protocol(custom_protocol2) - .with_known_addresses(vec![(peer1, vec![listen_address])].into_iter()) - .with_max_parallel_dials(8usize) - .build(); - - let mut litep2p2 = Litep2p::new(config2).unwrap(); - litep2p2.dial(&peer1).await.unwrap(); - - // wait until connection is established - let mut litep2p1_ready = false; - let mut litep2p2_ready = false; - - while !litep2p1_ready && !litep2p2_ready { - tokio::select! { - event = litep2p1.next_event() => { - tracing::trace!("litep2p1 event: {event:?}"); - litep2p1_ready = true; - } - event = litep2p2.next_event() => { - tracing::trace!("litep2p2 event: {event:?}"); - litep2p2_ready = true; - } - } - } - - // wait until connection is closed by the keep-alive timeout - let mut litep2p1_ready = false; - let mut litep2p2_ready = false; - - while !litep2p1_ready && !litep2p2_ready { - tokio::select! { - event = litep2p1.next_event() => { - tracing::trace!("litep2p1 event: {event:?}"); - litep2p1_ready = true; - } - event = litep2p2.next_event() => { - tracing::trace!("litep2p2 event: {event:?}"); - litep2p2_ready = true; - } - } - } - - let sink = litep2p2.bandwidth_sink(); - tracing::trace!("inbound {}, outbound {}", sink.outbound(), sink.inbound()); + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let custom_protocol1 = Box::new(CustomProtocol::new()); + let (mdns_config, _stream) = MdnsConfig::new(Duration::from_secs(30)); + + let config1 = ConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_tcp(TcpConfig { + ..Default::default() + }) + .with_user_protocol(custom_protocol1) + .with_mdns(mdns_config) + .build(); + + let mut litep2p1 = Litep2p::new(config1).unwrap(); + let peer1 = *litep2p1.local_peer_id(); + let listen_address = litep2p1.listen_addresses().next().unwrap().clone(); + + let custom_protocol2 = Box::new(CustomProtocol::new()); + let config2 = ConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_tcp(TcpConfig { + ..Default::default() + }) + .with_user_protocol(custom_protocol2) + .with_known_addresses(vec![(peer1, vec![listen_address])].into_iter()) + .with_max_parallel_dials(8usize) + .build(); + + let mut litep2p2 = Litep2p::new(config2).unwrap(); + litep2p2.dial(&peer1).await.unwrap(); + + // wait until connection is established + let mut litep2p1_ready = false; + let mut litep2p2_ready = false; + + while !litep2p1_ready && !litep2p2_ready { + tokio::select! { + event = litep2p1.next_event() => { + tracing::trace!("litep2p1 event: {event:?}"); + litep2p1_ready = true; + } + event = litep2p2.next_event() => { + tracing::trace!("litep2p2 event: {event:?}"); + litep2p2_ready = true; + } + } + } + + // wait until connection is closed by the keep-alive timeout + let mut litep2p1_ready = false; + let mut litep2p2_ready = false; + + while !litep2p1_ready && !litep2p2_ready { + tokio::select! { + event = litep2p1.next_event() => { + tracing::trace!("litep2p1 event: {event:?}"); + litep2p1_ready = true; + } + event = litep2p2.next_event() => { + tracing::trace!("litep2p2 event: {event:?}"); + litep2p2_ready = true; + } + } + } + + let sink = litep2p2.bandwidth_sink(); + tracing::trace!("inbound {}, outbound {}", sink.outbound(), sink.inbound()); } diff --git a/tests/user_protocol_2.rs b/tests/user_protocol_2.rs index b517286e..d2e0ce6d 100644 --- a/tests/user_protocol_2.rs +++ b/tests/user_protocol_2.rs @@ -19,13 +19,13 @@ // DEALINGS IN THE SOFTWARE. use litep2p::{ - codec::ProtocolCodec, - config::ConfigBuilder, - crypto::ed25519::Keypair, - protocol::{TransportEvent, TransportService, UserProtocol}, - transport::tcp::config::Config as TcpConfig, - types::protocol::ProtocolName, - Litep2p, Litep2pEvent, PeerId, + codec::ProtocolCodec, + config::ConfigBuilder, + crypto::ed25519::Keypair, + protocol::{TransportEvent, TransportService, UserProtocol}, + transport::tcp::config::Config as TcpConfig, + types::protocol::ProtocolName, + Litep2p, Litep2pEvent, PeerId, }; use futures::StreamExt; @@ -35,110 +35,114 @@ use tokio::sync::mpsc::{channel, Receiver, Sender}; use std::collections::HashSet; struct CustomProtocol { - protocol: ProtocolName, - codec: ProtocolCodec, - peers: HashSet, - rx: Receiver, + protocol: ProtocolName, + codec: ProtocolCodec, + peers: HashSet, + rx: Receiver, } impl CustomProtocol { - pub fn new() -> (Self, Sender) { - let (tx, rx) = channel(64); - - ( - Self { - rx, - peers: HashSet::new(), - protocol: ProtocolName::from("/custom-protocol/1"), - codec: ProtocolCodec::UnsignedVarint(None), - }, - tx, - ) - } + pub fn new() -> (Self, Sender) { + let (tx, rx) = channel(64); + + ( + Self { + rx, + peers: HashSet::new(), + protocol: ProtocolName::from("/custom-protocol/1"), + codec: ProtocolCodec::UnsignedVarint(None), + }, + tx, + ) + } } #[async_trait::async_trait] impl UserProtocol for CustomProtocol { - fn protocol(&self) -> ProtocolName { - self.protocol.clone() - } - - fn codec(&self) -> ProtocolCodec { - self.codec.clone() - } - - async fn run(mut self: Box, mut service: TransportService) -> litep2p::Result<()> { - loop { - tokio::select! { - event = service.next() => match event.unwrap() { - TransportEvent::ConnectionEstablished { peer, .. } => { - self.peers.insert(peer); - } - TransportEvent::ConnectionClosed { peer: _ } => {} - TransportEvent::SubstreamOpened { - peer: _, - protocol: _, - direction: _, - substream: _, - fallback: _, - } => {} - TransportEvent::SubstreamOpenFailure { - substream: _, - error: _, - } => {} - TransportEvent::DialFailure { .. } => {} - }, - address = self.rx.recv() => { - service.dial_address(address.unwrap()).unwrap(); - } - } - } - } + fn protocol(&self) -> ProtocolName { + self.protocol.clone() + } + + fn codec(&self) -> ProtocolCodec { + self.codec.clone() + } + + async fn run(mut self: Box, mut service: TransportService) -> litep2p::Result<()> { + loop { + tokio::select! { + event = service.next() => match event.unwrap() { + TransportEvent::ConnectionEstablished { peer, .. } => { + self.peers.insert(peer); + } + TransportEvent::ConnectionClosed { peer: _ } => {} + TransportEvent::SubstreamOpened { + peer: _, + protocol: _, + direction: _, + substream: _, + fallback: _, + } => {} + TransportEvent::SubstreamOpenFailure { + substream: _, + error: _, + } => {} + TransportEvent::DialFailure { .. } => {} + }, + address = self.rx.recv() => { + service.dial_address(address.unwrap()).unwrap(); + } + } + } + } } #[tokio::test] async fn user_protocol_2() { - let _ = tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) - .try_init(); - - let (custom_protocol1, sender1) = CustomProtocol::new(); - let config1 = ConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_tcp(TcpConfig { ..Default::default() }) - .with_user_protocol(Box::new(custom_protocol1)) - .build(); - - let (custom_protocol2, _sender2) = CustomProtocol::new(); - let config2 = ConfigBuilder::new() - .with_keypair(Keypair::generate()) - .with_tcp(TcpConfig { ..Default::default() }) - .with_user_protocol(Box::new(custom_protocol2)) - .build(); - - let mut litep2p1 = Litep2p::new(config1).unwrap(); - let mut litep2p2 = Litep2p::new(config2).unwrap(); - let address = litep2p2.listen_addresses().next().unwrap().clone(); - - sender1.send(address).await.unwrap(); - - let mut litep2p1_ready = false; - let mut litep2p2_ready = false; - - while !litep2p1_ready && !litep2p2_ready { - tokio::select! { - event = litep2p1.next_event() => match event.unwrap() { - Litep2pEvent::ConnectionEstablished { .. } => { - litep2p1_ready = true; - } - _ => {} - }, - event = litep2p2.next_event() => match event.unwrap() { - Litep2pEvent::ConnectionEstablished { .. } => { - litep2p2_ready = true; - } - _ => {} - } - } - } + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (custom_protocol1, sender1) = CustomProtocol::new(); + let config1 = ConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_tcp(TcpConfig { + ..Default::default() + }) + .with_user_protocol(Box::new(custom_protocol1)) + .build(); + + let (custom_protocol2, _sender2) = CustomProtocol::new(); + let config2 = ConfigBuilder::new() + .with_keypair(Keypair::generate()) + .with_tcp(TcpConfig { + ..Default::default() + }) + .with_user_protocol(Box::new(custom_protocol2)) + .build(); + + let mut litep2p1 = Litep2p::new(config1).unwrap(); + let mut litep2p2 = Litep2p::new(config2).unwrap(); + let address = litep2p2.listen_addresses().next().unwrap().clone(); + + sender1.send(address).await.unwrap(); + + let mut litep2p1_ready = false; + let mut litep2p2_ready = false; + + while !litep2p1_ready && !litep2p2_ready { + tokio::select! { + event = litep2p1.next_event() => match event.unwrap() { + Litep2pEvent::ConnectionEstablished { .. } => { + litep2p1_ready = true; + } + _ => {} + }, + event = litep2p2.next_event() => match event.unwrap() { + Litep2pEvent::ConnectionEstablished { .. } => { + litep2p2_ready = true; + } + _ => {} + } + } + } }