Skip to main content

hotshot_libp2p_networking/network/
transport.rs

1use std::{
2    future::Future,
3    io::{Error as IoError, ErrorKind as IoErrorKind},
4    pin::Pin,
5    sync::Arc,
6    task::Poll,
7};
8
9use anyhow::{Context, Result as AnyhowResult, ensure};
10use bimap::BiMap;
11use futures::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, future::poll_fn};
12use hotshot_types::traits::signature_key::SignatureKey;
13use libp2p::{
14    Transport,
15    core::{
16        StreamMuxer,
17        muxing::StreamMuxerExt,
18        transport::{DialOpts, TransportEvent},
19    },
20    identity::PeerId,
21};
22use parking_lot::Mutex;
23use pin_project::pin_project;
24use serde::{Deserialize, Serialize};
25use tokio::time::timeout;
26use tracing::debug;
27
28use crate::network::log_summary::LogEvent;
29
/// Upper bound, in bytes, on an authentication message we are willing to read
/// from a peer. Capping it keeps a peer from making us allocate large buffers.
const MAX_AUTH_MESSAGE_SIZE: usize = 1 << 10;

/// How long the authentication handshake (opening the handshake substream and
/// exchanging messages) may take. Capping it stops peers from pinning connections
/// open by only partially completing the handshake.
const AUTH_HANDSHAKE_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(5_000);
38
39/// A wrapper for a `Transport` that bidirectionally associates (and verifies)
40/// the corresponding consensus keys.
41#[pin_project]
42pub struct ConsensusKeyAuthentication<
43    T: Transport,
44    S: SignatureKey + 'static,
45    C: StreamMuxer + Unpin,
46> {
47    #[pin]
48    /// The underlying transport we are wrapping
49    pub inner: T,
50
51    /// A pre-signed message that we (depending on if it's specified or not) send to the remote peer for authentication
52    pub auth_message: Arc<Option<Vec<u8>>>,
53
54    /// The (verified) map of consensus keys to peer IDs
55    pub consensus_key_to_pid_map: Arc<Mutex<BiMap<S, PeerId>>>,
56
57    /// Phantom data for the connection type
58    pd: std::marker::PhantomData<(C, S)>,
59}
60
61/// A type alias for the future that upgrades a connection to perform the authentication handshake
62type UpgradeFuture<T> =
63    Pin<Box<dyn Future<Output = Result<<T as Transport>::Output, <T as Transport>::Error>> + Send>>;
64
65impl<T: Transport, S: SignatureKey + 'static, C: StreamMuxer + Unpin>
66    ConsensusKeyAuthentication<T, S, C>
67{
68    /// Create a new `ConsensusKeyAuthentication` transport that wraps the given transport
69    /// and authenticates connections against the stake table. If the auth message is `None`,
70    /// the authentication is disabled.
71    pub fn new(
72        inner: T,
73        auth_message: Option<Vec<u8>>,
74        consensus_key_to_pid_map: Arc<Mutex<BiMap<S, PeerId>>>,
75    ) -> Self {
76        Self {
77            inner,
78            auth_message: Arc::from(auth_message),
79            consensus_key_to_pid_map,
80            pd: std::marker::PhantomData,
81        }
82    }
83
84    /// Prove to the remote peer that we are in the stake table by sending
85    /// them our authentication message.
86    ///
87    /// # Errors
88    /// - If we fail to write the message to the stream
89    pub async fn authenticate_with_remote_peer<W: AsyncWrite + Unpin>(
90        stream: &mut W,
91        auth_message: &[u8],
92    ) -> AnyhowResult<()> {
93        // Write the length-delimited message
94        write_length_delimited(stream, auth_message).await?;
95
96        Ok(())
97    }
98
99    /// Verify that the remote peer is:
100    /// - In the stake table
101    /// - Sending us a valid authentication message
102    /// - Sending us a valid signature
103    /// - Matching the peer ID we expect
104    ///
105    /// # Errors
106    /// If the peer fails verification. This can happen if:
107    /// - We fail to read the message from the stream
108    /// - The message is too large
109    /// - The message is invalid
110    /// - The peer is not in the stake table
111    /// - The signature is invalid
112    pub async fn verify_peer_authentication<R: AsyncReadExt + Unpin>(
113        stream: &mut R,
114        required_peer_id: &PeerId,
115        consensus_key_to_pid_map: Arc<Mutex<BiMap<S, PeerId>>>,
116    ) -> AnyhowResult<()> {
117        // Read the length-delimited message from the remote peer
118        let message = read_length_delimited(stream, MAX_AUTH_MESSAGE_SIZE).await?;
119
120        // Deserialize the authentication message
121        let auth_message: AuthMessage<S> =
122            bincode::deserialize(&message).with_context(|| "Failed to deserialize auth message")?;
123
124        // Verify the signature on the public keys
125        let public_key = auth_message
126            .validate()
127            .with_context(|| "Failed to verify authentication message")?;
128
129        // Deserialize the `PeerId`
130        let peer_id = PeerId::from_bytes(&auth_message.peer_id_bytes)
131            .with_context(|| "Failed to deserialize peer ID")?;
132
133        // Verify that the peer ID is the same as the remote peer
134        if peer_id != *required_peer_id {
135            return Err(anyhow::anyhow!("Peer ID mismatch"));
136        }
137
138        // If we got here, the peer is authenticated. Add the consensus key to the map
139        consensus_key_to_pid_map.lock().insert(public_key, peer_id);
140
141        Ok(())
142    }
143
144    /// Wrap the supplied future in an upgrade that performs the authentication handshake.
145    ///
146    /// `outgoing` is a boolean that indicates if the connection is incoming or outgoing.
147    /// This is needed because the flow of the handshake is different for each.
148    fn gen_handshake<F: Future<Output = Result<T::Output, T::Error>> + Send + 'static>(
149        original_future: F,
150        outgoing: bool,
151        auth_message: Arc<Option<Vec<u8>>>,
152        consensus_key_to_pid_map: Arc<Mutex<BiMap<S, PeerId>>>,
153    ) -> UpgradeFuture<T>
154    where
155        T::Error: From<<C as StreamMuxer>::Error> + From<IoError>,
156        T::Output: AsOutput<C> + Send,
157        C::Substream: Unpin + Send,
158    {
159        // Create a new upgrade that performs the authentication handshake on top
160        Box::pin(async move {
161            // Wait for the original future to resolve
162            let mut stream = original_future.await?;
163
164            // Time out the authentication block
165            timeout(AUTH_HANDSHAKE_TIMEOUT, async {
166                // Open a substream for the handshake.
167                // The handshake order depends on whether the connection is incoming or outgoing.
168                let mut substream = if outgoing {
169                    poll_fn(|cx| stream.as_connection().poll_outbound_unpin(cx)).await?
170                } else {
171                    poll_fn(|cx| stream.as_connection().poll_inbound_unpin(cx)).await?
172                };
173
174                // Conditionally authenticate depending on whether we specified an auth message
175                if let Some(auth_message) = auth_message.as_ref() {
176                    if outgoing {
177                        // If the connection is outgoing, authenticate with the remote peer first
178                        Self::authenticate_with_remote_peer(&mut substream, auth_message)
179                            .await
180                            .map_err(|e| {
181                                LogEvent::AuthFailure.record();
182                                debug!("Failed to authenticate with remote peer: {e:?}");
183                                IoError::other(e)
184                            })?;
185
186                        // Verify the remote peer's authentication
187                        Self::verify_peer_authentication(
188                            &mut substream,
189                            stream.as_peer_id(),
190                            consensus_key_to_pid_map,
191                        )
192                        .await
193                        .map_err(|e| {
194                            LogEvent::VerifyFailure.record();
195                            debug!("Failed to verify remote peer: {e:?}");
196                            IoError::other(e)
197                        })?;
198                    } else {
199                        // If it is incoming, verify the remote peer's authentication first
200                        Self::verify_peer_authentication(
201                            &mut substream,
202                            stream.as_peer_id(),
203                            consensus_key_to_pid_map,
204                        )
205                        .await
206                        .map_err(|e| {
207                            LogEvent::VerifyFailure.record();
208                            debug!("Failed to verify remote peer: {e:?}");
209                            IoError::other(e)
210                        })?;
211
212                        // Authenticate with the remote peer
213                        Self::authenticate_with_remote_peer(&mut substream, auth_message)
214                            .await
215                            .map_err(|e| {
216                                LogEvent::AuthFailure.record();
217                                debug!("Failed to authenticate with remote peer: {e:?}");
218                                IoError::other(e)
219                            })?;
220                    }
221                }
222
223                Ok(stream)
224            })
225            .await
226            .map_err(|e| {
227                LogEvent::AuthHandshakeTimeout.record();
228                debug!("Timed out performing authentication handshake: {e:?}");
229                IoError::new(IoErrorKind::TimedOut, e)
230            })?
231        })
232    }
233}
234
235/// The deserialized form of an authentication message that is sent to the remote peer
236#[derive(Clone, Serialize, Deserialize)]
237struct AuthMessage<S: SignatureKey> {
238    /// The encoded (stake table) public key of the sender. This, along with the peer ID, is
239    /// signed. It is still encoded here to enable easy verification.
240    public_key_bytes: Vec<u8>,
241
242    /// The encoded peer ID of the sender. This is appended to the public key before signing.
243    /// It is still encoded here to enable easy verification.
244    peer_id_bytes: Vec<u8>,
245
246    /// The signature on the public key
247    signature: S::PureAssembledSignatureType,
248}
249
250impl<S: SignatureKey> AuthMessage<S> {
251    /// Validate the signature on the public key and return it if valid
252    pub fn validate(&self) -> AnyhowResult<S> {
253        // Deserialize the stake table public key
254        let public_key = S::from_bytes(&self.public_key_bytes)
255            .with_context(|| "Failed to deserialize public key")?;
256
257        // Reconstruct the signed message from the public key and peer ID
258        let mut signed_message = public_key.to_bytes();
259        signed_message.extend(self.peer_id_bytes.clone());
260
261        // Check if the signature is valid across both
262        if !public_key.validate(&self.signature, &signed_message) {
263            return Err(anyhow::anyhow!("Invalid signature"));
264        }
265
266        Ok(public_key)
267    }
268}
269
270/// Create an sign an authentication message to be sent to the remote peer
271///
272/// # Errors
273/// - If we fail to sign the public key
274/// - If we fail to serialize the authentication message
275pub fn construct_auth_message<S: SignatureKey + 'static>(
276    public_key: &S,
277    peer_id: &PeerId,
278    private_key: &S::PrivateKey,
279) -> AnyhowResult<Vec<u8>> {
280    // Serialize the stake table public key
281    let mut public_key_bytes = public_key.to_bytes();
282
283    // Serialize the peer ID and append it
284    let peer_id_bytes = peer_id.to_bytes();
285    public_key_bytes.extend_from_slice(&peer_id_bytes);
286
287    // Sign our public key
288    let signature =
289        S::sign(private_key, &public_key_bytes).with_context(|| "Failed to sign public key")?;
290
291    // Create the auth message
292    let auth_message = AuthMessage::<S> {
293        public_key_bytes,
294        peer_id_bytes,
295        signature,
296    };
297
298    // Serialize the auth message
299    bincode::serialize(&auth_message).with_context(|| "Failed to serialize auth message")
300}
301
302impl<T: Transport, S: SignatureKey + 'static, C: StreamMuxer + Unpin> Transport
303    for ConsensusKeyAuthentication<T, S, C>
304where
305    T::Dial: Future<Output = Result<T::Output, T::Error>> + Send + 'static,
306    T::ListenerUpgrade: Send + 'static,
307    T::Output: AsOutput<C> + Send,
308    T::Error: From<<C as StreamMuxer>::Error> + From<IoError>,
309    C::Substream: Unpin + Send,
310{
311    // `Dial` is for connecting out, `ListenerUpgrade` is for accepting incoming connections
312    type Dial = Pin<Box<dyn Future<Output = Result<T::Output, T::Error>> + Send>>;
313    type ListenerUpgrade = Pin<Box<dyn Future<Output = Result<T::Output, T::Error>> + Send>>;
314
315    // These are just passed through
316    type Output = T::Output;
317    type Error = T::Error;
318
319    /// Dial a remote peer. This function is changed to perform an authentication handshake
320    /// on top.
321    fn dial(
322        &mut self,
323        addr: libp2p::Multiaddr,
324        opts: DialOpts,
325    ) -> Result<Self::Dial, libp2p::TransportError<Self::Error>> {
326        // Perform the inner dial
327        let res = self.inner.dial(addr, opts);
328
329        // Clone the necessary fields
330        let auth_message = Arc::clone(&self.auth_message);
331
332        // If the dial was successful, perform the authentication handshake on top
333        match res {
334            Ok(dial) => Ok(Self::gen_handshake(
335                dial,
336                true,
337                auth_message,
338                Arc::clone(&self.consensus_key_to_pid_map),
339            )),
340            Err(err) => Err(err),
341        }
342    }
343
344    /// This function is where we perform the authentication handshake for _incoming_ connections.
345    /// The flow in this case is the reverse of the `dial` function: we first verify the remote peer's
346    /// authentication, and then authenticate with them.
347    fn poll(
348        mut self: std::pin::Pin<&mut Self>,
349        cx: &mut std::task::Context<'_>,
350    ) -> std::task::Poll<libp2p::core::transport::TransportEvent<Self::ListenerUpgrade, Self::Error>>
351    {
352        match Transport::poll(self.as_mut().project().inner, cx) {
353            Poll::Ready(event) => Poll::Ready(match event {
354                // If we have an incoming connection, we need to perform the authentication handshake
355                TransportEvent::Incoming {
356                    listener_id,
357                    upgrade,
358                    local_addr,
359                    send_back_addr,
360                } => {
361                    // Clone the necessary fields
362                    let auth_message = Arc::clone(&self.auth_message);
363
364                    // Generate the handshake upgrade future (inbound)
365                    let auth_upgrade = Self::gen_handshake(
366                        upgrade,
367                        false,
368                        auth_message,
369                        Arc::clone(&self.consensus_key_to_pid_map),
370                    );
371
372                    // Return the new event
373                    TransportEvent::Incoming {
374                        listener_id,
375                        upgrade: auth_upgrade,
376                        local_addr,
377                        send_back_addr,
378                    }
379                },
380
381                // We need to re-map the other events because we changed the type of the upgrade
382                TransportEvent::AddressExpired {
383                    listener_id,
384                    listen_addr,
385                } => TransportEvent::AddressExpired {
386                    listener_id,
387                    listen_addr,
388                },
389                TransportEvent::ListenerClosed {
390                    listener_id,
391                    reason,
392                } => TransportEvent::ListenerClosed {
393                    listener_id,
394                    reason,
395                },
396                TransportEvent::ListenerError { listener_id, error } => {
397                    TransportEvent::ListenerError { listener_id, error }
398                },
399                TransportEvent::NewAddress {
400                    listener_id,
401                    listen_addr,
402                } => TransportEvent::NewAddress {
403                    listener_id,
404                    listen_addr,
405                },
406            }),
407
408            Poll::Pending => Poll::Pending,
409        }
410    }
411
412    /// The below functions just pass through to the inner transport, but we had
413    /// to define them
414    fn remove_listener(&mut self, id: libp2p::core::transport::ListenerId) -> bool {
415        self.inner.remove_listener(id)
416    }
417    fn listen_on(
418        &mut self,
419        id: libp2p::core::transport::ListenerId,
420        addr: libp2p::Multiaddr,
421    ) -> Result<(), libp2p::TransportError<Self::Error>> {
422        self.inner.listen_on(id, addr)
423    }
424}
425
426/// A helper trait that allows us to access the underlying connection
427/// and `PeerId` from a transport output
428trait AsOutput<C: StreamMuxer + Unpin> {
429    /// Get a mutable reference to the underlying connection
430    fn as_connection(&mut self) -> &mut C;
431
432    /// Get a mutable reference to the underlying `PeerId`
433    fn as_peer_id(&mut self) -> &mut PeerId;
434}
435
436/// The implementation of the `AsConnection` trait for a tuple of a `PeerId`
437/// and a connection.
438impl<C: StreamMuxer + Unpin> AsOutput<C> for (PeerId, C) {
439    /// Get a mutable reference to the underlying connection
440    fn as_connection(&mut self) -> &mut C {
441        &mut self.1
442    }
443
444    /// Get a mutable reference to the underlying `PeerId`
445    fn as_peer_id(&mut self) -> &mut PeerId {
446        &mut self.0
447    }
448}
449
450/// A helper function to read a length-delimited message from a stream. Takes into
451/// account the maximum message size.
452///
453/// # Errors
454/// - If the message is too big
455/// - If we fail to read from the stream
456pub async fn read_length_delimited<S: AsyncRead + Unpin>(
457    stream: &mut S,
458    max_size: usize,
459) -> AnyhowResult<Vec<u8>> {
460    // Receive the first 8 bytes of the message, which is the length
461    let mut len_bytes = [0u8; 4];
462    stream
463        .read_exact(&mut len_bytes)
464        .await
465        .with_context(|| "Failed to read message length")?;
466
467    // Parse the length of the message as a `u32`
468    let len = usize::try_from(u32::from_be_bytes(len_bytes))?;
469
470    // Quit if the message is too large
471    ensure!(len <= max_size, "Message too large");
472
473    // Read the actual message
474    let mut message = vec![0u8; len];
475    stream
476        .read_exact(&mut message)
477        .await
478        .with_context(|| "Failed to read message")?;
479
480    Ok(message)
481}
482
483/// A helper function to write a length-delimited message to a stream.
484///
485/// # Errors
486/// - If we fail to write to the stream
487pub async fn write_length_delimited<S: AsyncWrite + Unpin>(
488    stream: &mut S,
489    message: &[u8],
490) -> AnyhowResult<()> {
491    // Write the length of the message
492    stream
493        .write_all(&u32::try_from(message.len())?.to_be_bytes())
494        .await
495        .with_context(|| "Failed to write message length")?;
496
497    // Write the actual message
498    stream
499        .write_all(message)
500        .await
501        .with_context(|| "Failed to write message")?;
502
503    Ok(())
504}
505
#[cfg(test)]
mod test {
    use hotshot_types::{signature_key::BLSPubKey, traits::signature_key::SignatureKey};
    use libp2p::{core::transport::dummy::DummyTransport, quic::Connection};
    use rand::Rng;

    use super::*;

    /// A mock type to help with readability
    type MockStakeTableAuth = ConsensusKeyAuthentication<DummyTransport, BLSPubKey, Connection>;

    // Helper macro for generating a new identity and authentication message
    macro_rules! new_identity {
        () => {{
            // Generate a new seed
            let seed = rand::rngs::OsRng.r#gen::<[u8; 32]>();

            // Create a new keypair
            let keypair = BLSPubKey::generated_from_seed_indexed(seed, 1337);

            // Create a peer ID
            let peer_id = libp2p::identity::Keypair::generate_ed25519()
                .public()
                .to_peer_id();

            // Construct an authentication message
            let auth_message =
                super::construct_auth_message(&keypair.0, &peer_id, &keypair.1).unwrap();

            (keypair, peer_id, auth_message)
        }};
    }

    // Helper macro to generate a cursor from a length-delimited message
    macro_rules! cursor_from {
        ($auth_message:expr) => {{
            let mut stream = futures::io::Cursor::new(vec![]);
            write_length_delimited(&mut stream, &$auth_message)
                .await
                .expect("Failed to write message");
            stream.set_position(0);
            stream
        }};
    }

    /// Test valid construction and verification of an authentication message
    #[test]
    fn signature_verify() {
        // Create a new identity
        let (_, _, auth_message) = new_identity!();

        // Verify the authentication message
        let public_key = super::AuthMessage::<BLSPubKey>::validate(
            &bincode::deserialize(&auth_message).unwrap(),
        );
        assert!(public_key.is_ok());
    }

    /// Test invalid construction and verification of an authentication message with
    /// an invalid public key. This ensures we are signing over it correctly.
    #[test]
    fn signature_verify_invalid_public_key() {
        // Create a new identity
        let (_, _, auth_message) = new_identity!();

        // Deserialize the authentication message
        let mut auth_message: super::AuthMessage<BLSPubKey> =
            bincode::deserialize(&auth_message).unwrap();

        // Change the public key
        auth_message.public_key_bytes[0] ^= 0x01;

        // Serialize the message again
        let auth_message = bincode::serialize(&auth_message).unwrap();

        // Verify the authentication message
        let public_key = super::AuthMessage::<BLSPubKey>::validate(
            &bincode::deserialize(&auth_message).unwrap(),
        );
        assert!(public_key.is_err());
    }

    /// Test invalid construction and verification of an authentication message with
    /// an invalid peer ID. This ensures we are signing over it correctly.
    #[test]
    fn signature_verify_invalid_peer_id() {
        // Create a new identity
        let (_, _, auth_message) = new_identity!();

        // Deserialize the authentication message
        let mut auth_message: super::AuthMessage<BLSPubKey> =
            bincode::deserialize(&auth_message).unwrap();

        // Change the peer ID
        auth_message.peer_id_bytes[0] ^= 0x01;

        // Serialize the message again
        let auth_message = bincode::serialize(&auth_message).unwrap();

        // Verify the authentication message
        let public_key = super::AuthMessage::<BLSPubKey>::validate(
            &bincode::deserialize(&auth_message).unwrap(),
        );
        assert!(public_key.is_err());
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn valid_authentication() {
        // Create a new identity
        let (keypair, peer_id, auth_message) = new_identity!();

        // Create a stream and write the message to it
        let mut stream = cursor_from!(auth_message);

        // Create a map from consensus keys to peer IDs
        let consensus_key_to_pid_map = Arc::new(parking_lot::Mutex::new(BiMap::new()));

        // Verify the authentication message
        let result = MockStakeTableAuth::verify_peer_authentication(
            &mut stream,
            &peer_id,
            Arc::clone(&consensus_key_to_pid_map),
        )
        .await;

        // Check the result first so that a failure reports the actual error
        // instead of panicking on a missing map entry
        result.expect("Should have passed authentication but did not");

        // Make sure the map has the correct entry
        assert_eq!(
            consensus_key_to_pid_map.lock().get_by_left(&keypair.0),
            Some(&peer_id),
            "Map does not have the correct entry"
        );
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn peer_id_mismatch() {
        // Create a new identity and authentication message
        let (_, _, auth_message) = new_identity!();

        // Create a second (malicious) identity
        let (_, malicious_peer_id, _) = new_identity!();

        // Create a stream and write the message to it
        let mut stream = cursor_from!(auth_message);

        // Create a map from consensus keys to peer IDs
        let consensus_key_to_pid_map = Arc::new(parking_lot::Mutex::new(BiMap::new()));

        // Check against the malicious peer ID
        let result = MockStakeTableAuth::verify_peer_authentication(
            &mut stream,
            &malicious_peer_id,
            Arc::clone(&consensus_key_to_pid_map),
        )
        .await;

        // Make sure it errored for the right reason
        assert!(
            result
                .expect_err("Should have failed authentication but did not")
                .to_string()
                .contains("Peer ID mismatch"),
            "Did not fail with the correct error"
        );

        // Make sure the map does not have the malicious peer ID
        assert!(
            consensus_key_to_pid_map.lock().is_empty(),
            "Malicious peer ID should not be in the map"
        );
    }

    /// A tampered message must be rejected and must not be recorded in the map
    #[tokio::test(flavor = "multi_thread")]
    async fn tampered_message_is_not_recorded() {
        // Create a new identity and tamper with its public key
        let (_, peer_id, auth_message) = new_identity!();
        let mut auth_message: super::AuthMessage<BLSPubKey> =
            bincode::deserialize(&auth_message).unwrap();
        auth_message.public_key_bytes[0] ^= 0x01;
        let auth_message = bincode::serialize(&auth_message).unwrap();

        // Create a stream and write the tampered message to it
        let mut stream = cursor_from!(auth_message);

        // Create a map from consensus keys to peer IDs
        let consensus_key_to_pid_map = Arc::new(parking_lot::Mutex::new(BiMap::new()));

        let result = MockStakeTableAuth::verify_peer_authentication(
            &mut stream,
            &peer_id,
            Arc::clone(&consensus_key_to_pid_map),
        )
        .await;

        assert!(
            result.is_err(),
            "Tampered message should have failed authentication"
        );
        assert!(
            consensus_key_to_pid_map.lock().is_empty(),
            "Tampered message should not be recorded in the map"
        );
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn read_and_write_length_delimited() {
        // Create a message
        let message = b"Hello, world!";

        // Write the message to a buffer
        let mut buffer = Vec::new();
        write_length_delimited(&mut buffer, message).await.unwrap();

        // Read the message from the buffer
        let read_message = read_length_delimited(&mut buffer.as_slice(), 1024)
            .await
            .unwrap();

        // Check if the messages are the same
        assert_eq!(message, read_message.as_slice());
    }

    /// Messages longer than the allowed maximum must be rejected
    #[tokio::test(flavor = "multi_thread")]
    async fn read_length_delimited_rejects_oversized_message() {
        let message = vec![0u8; 2048];

        let mut buffer = Vec::new();
        write_length_delimited(&mut buffer, &message).await.unwrap();

        let err = read_length_delimited(&mut buffer.as_slice(), 1024)
            .await
            .expect_err("Oversized message should have been rejected");
        assert!(
            err.to_string().contains("Message too large"),
            "Did not fail with the correct error"
        );
    }

    /// A frame whose payload is shorter than its length prefix must be rejected
    #[tokio::test(flavor = "multi_thread")]
    async fn read_length_delimited_rejects_truncated_message() {
        let mut buffer = Vec::new();
        write_length_delimited(&mut buffer, b"Hello, world!")
            .await
            .unwrap();

        // Drop the final payload byte
        buffer.truncate(buffer.len() - 1);

        assert!(
            read_length_delimited(&mut buffer.as_slice(), 1024)
                .await
                .is_err(),
            "Truncated message should have been rejected"
        );
    }
}