1use std::{
2 future::Future,
3 io::{Error as IoError, ErrorKind as IoErrorKind},
4 pin::Pin,
5 sync::Arc,
6 task::Poll,
7};
8
9use anyhow::{Context, Result as AnyhowResult, ensure};
10use bimap::BiMap;
11use futures::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, future::poll_fn};
12use hotshot_types::traits::signature_key::SignatureKey;
13use libp2p::{
14 Transport,
15 core::{
16 StreamMuxer,
17 muxing::StreamMuxerExt,
18 transport::{DialOpts, TransportEvent},
19 },
20 identity::PeerId,
21};
22use parking_lot::Mutex;
23use pin_project::pin_project;
24use serde::{Deserialize, Serialize};
25use tokio::time::timeout;
26use tracing::debug;
27
28use crate::network::log_summary::LogEvent;
29
/// Largest authentication frame (in bytes) accepted from a remote peer; longer
/// frames are rejected before their payload buffer is allocated.
const MAX_AUTH_MESSAGE_SIZE: usize = 1 << 10;
33
/// Time budget for opening the handshake substream and exchanging
/// authentication messages. Waiting for the inner transport's connection is
/// not counted against it.
const AUTH_HANDSHAKE_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(5_000);
38
39#[pin_project]
42pub struct ConsensusKeyAuthentication<
43 T: Transport,
44 S: SignatureKey + 'static,
45 C: StreamMuxer + Unpin,
46> {
47 #[pin]
48 pub inner: T,
50
51 pub auth_message: Arc<Option<Vec<u8>>>,
53
54 pub consensus_key_to_pid_map: Arc<Mutex<BiMap<S, PeerId>>>,
56
57 pd: std::marker::PhantomData<(C, S)>,
59}
60
61type UpgradeFuture<T> =
63 Pin<Box<dyn Future<Output = Result<<T as Transport>::Output, <T as Transport>::Error>> + Send>>;
64
65impl<T: Transport, S: SignatureKey + 'static, C: StreamMuxer + Unpin>
66 ConsensusKeyAuthentication<T, S, C>
67{
68 pub fn new(
72 inner: T,
73 auth_message: Option<Vec<u8>>,
74 consensus_key_to_pid_map: Arc<Mutex<BiMap<S, PeerId>>>,
75 ) -> Self {
76 Self {
77 inner,
78 auth_message: Arc::from(auth_message),
79 consensus_key_to_pid_map,
80 pd: std::marker::PhantomData,
81 }
82 }
83
84 pub async fn authenticate_with_remote_peer<W: AsyncWrite + Unpin>(
90 stream: &mut W,
91 auth_message: &[u8],
92 ) -> AnyhowResult<()> {
93 write_length_delimited(stream, auth_message).await?;
95
96 Ok(())
97 }
98
99 pub async fn verify_peer_authentication<R: AsyncReadExt + Unpin>(
113 stream: &mut R,
114 required_peer_id: &PeerId,
115 consensus_key_to_pid_map: Arc<Mutex<BiMap<S, PeerId>>>,
116 ) -> AnyhowResult<()> {
117 let message = read_length_delimited(stream, MAX_AUTH_MESSAGE_SIZE).await?;
119
120 let auth_message: AuthMessage<S> =
122 bincode::deserialize(&message).with_context(|| "Failed to deserialize auth message")?;
123
124 let public_key = auth_message
126 .validate()
127 .with_context(|| "Failed to verify authentication message")?;
128
129 let peer_id = PeerId::from_bytes(&auth_message.peer_id_bytes)
131 .with_context(|| "Failed to deserialize peer ID")?;
132
133 if peer_id != *required_peer_id {
135 return Err(anyhow::anyhow!("Peer ID mismatch"));
136 }
137
138 consensus_key_to_pid_map.lock().insert(public_key, peer_id);
140
141 Ok(())
142 }
143
144 fn gen_handshake<F: Future<Output = Result<T::Output, T::Error>> + Send + 'static>(
149 original_future: F,
150 outgoing: bool,
151 auth_message: Arc<Option<Vec<u8>>>,
152 consensus_key_to_pid_map: Arc<Mutex<BiMap<S, PeerId>>>,
153 ) -> UpgradeFuture<T>
154 where
155 T::Error: From<<C as StreamMuxer>::Error> + From<IoError>,
156 T::Output: AsOutput<C> + Send,
157 C::Substream: Unpin + Send,
158 {
159 Box::pin(async move {
161 let mut stream = original_future.await?;
163
164 timeout(AUTH_HANDSHAKE_TIMEOUT, async {
166 let mut substream = if outgoing {
169 poll_fn(|cx| stream.as_connection().poll_outbound_unpin(cx)).await?
170 } else {
171 poll_fn(|cx| stream.as_connection().poll_inbound_unpin(cx)).await?
172 };
173
174 if let Some(auth_message) = auth_message.as_ref() {
176 if outgoing {
177 Self::authenticate_with_remote_peer(&mut substream, auth_message)
179 .await
180 .map_err(|e| {
181 LogEvent::AuthFailure.record();
182 debug!("Failed to authenticate with remote peer: {e:?}");
183 IoError::other(e)
184 })?;
185
186 Self::verify_peer_authentication(
188 &mut substream,
189 stream.as_peer_id(),
190 consensus_key_to_pid_map,
191 )
192 .await
193 .map_err(|e| {
194 LogEvent::VerifyFailure.record();
195 debug!("Failed to verify remote peer: {e:?}");
196 IoError::other(e)
197 })?;
198 } else {
199 Self::verify_peer_authentication(
201 &mut substream,
202 stream.as_peer_id(),
203 consensus_key_to_pid_map,
204 )
205 .await
206 .map_err(|e| {
207 LogEvent::VerifyFailure.record();
208 debug!("Failed to verify remote peer: {e:?}");
209 IoError::other(e)
210 })?;
211
212 Self::authenticate_with_remote_peer(&mut substream, auth_message)
214 .await
215 .map_err(|e| {
216 LogEvent::AuthFailure.record();
217 debug!("Failed to authenticate with remote peer: {e:?}");
218 IoError::other(e)
219 })?;
220 }
221 }
222
223 Ok(stream)
224 })
225 .await
226 .map_err(|e| {
227 LogEvent::AuthHandshakeTimeout.record();
228 debug!("Timed out performing authentication handshake: {e:?}");
229 IoError::new(IoErrorKind::TimedOut, e)
230 })?
231 })
232 }
233}
234
/// Handshake message proving that the holder of a consensus key controls a
/// given libp2p peer ID.
///
/// It travels bincode-serialized, and bincode encodes struct fields
/// positionally in declaration order. Reordering or retyping the fields below
/// would therefore change the wire format.
#[derive(Clone, Serialize, Deserialize)]
struct AuthMessage<S: SignatureKey> {
    /// Bytes that `validate` parses with `S::from_bytes` to recover the
    /// sender's consensus public key.
    public_key_bytes: Vec<u8>,

    /// The sender's peer ID, in the encoding accepted by `PeerId::from_bytes`.
    peer_id_bytes: Vec<u8>,

    /// Signature, checked by `validate`, over the re-serialized public key
    /// followed by `peer_id_bytes`.
    signature: S::PureAssembledSignatureType,
}
249
250impl<S: SignatureKey> AuthMessage<S> {
251 pub fn validate(&self) -> AnyhowResult<S> {
253 let public_key = S::from_bytes(&self.public_key_bytes)
255 .with_context(|| "Failed to deserialize public key")?;
256
257 let mut signed_message = public_key.to_bytes();
259 signed_message.extend(self.peer_id_bytes.clone());
260
261 if !public_key.validate(&self.signature, &signed_message) {
263 return Err(anyhow::anyhow!("Invalid signature"));
264 }
265
266 Ok(public_key)
267 }
268}
269
270pub fn construct_auth_message<S: SignatureKey + 'static>(
276 public_key: &S,
277 peer_id: &PeerId,
278 private_key: &S::PrivateKey,
279) -> AnyhowResult<Vec<u8>> {
280 let mut public_key_bytes = public_key.to_bytes();
282
283 let peer_id_bytes = peer_id.to_bytes();
285 public_key_bytes.extend_from_slice(&peer_id_bytes);
286
287 let signature =
289 S::sign(private_key, &public_key_bytes).with_context(|| "Failed to sign public key")?;
290
291 let auth_message = AuthMessage::<S> {
293 public_key_bytes,
294 peer_id_bytes,
295 signature,
296 };
297
298 bincode::serialize(&auth_message).with_context(|| "Failed to serialize auth message")
300}
301
302impl<T: Transport, S: SignatureKey + 'static, C: StreamMuxer + Unpin> Transport
303 for ConsensusKeyAuthentication<T, S, C>
304where
305 T::Dial: Future<Output = Result<T::Output, T::Error>> + Send + 'static,
306 T::ListenerUpgrade: Send + 'static,
307 T::Output: AsOutput<C> + Send,
308 T::Error: From<<C as StreamMuxer>::Error> + From<IoError>,
309 C::Substream: Unpin + Send,
310{
311 type Dial = Pin<Box<dyn Future<Output = Result<T::Output, T::Error>> + Send>>;
313 type ListenerUpgrade = Pin<Box<dyn Future<Output = Result<T::Output, T::Error>> + Send>>;
314
315 type Output = T::Output;
317 type Error = T::Error;
318
319 fn dial(
322 &mut self,
323 addr: libp2p::Multiaddr,
324 opts: DialOpts,
325 ) -> Result<Self::Dial, libp2p::TransportError<Self::Error>> {
326 let res = self.inner.dial(addr, opts);
328
329 let auth_message = Arc::clone(&self.auth_message);
331
332 match res {
334 Ok(dial) => Ok(Self::gen_handshake(
335 dial,
336 true,
337 auth_message,
338 Arc::clone(&self.consensus_key_to_pid_map),
339 )),
340 Err(err) => Err(err),
341 }
342 }
343
344 fn poll(
348 mut self: std::pin::Pin<&mut Self>,
349 cx: &mut std::task::Context<'_>,
350 ) -> std::task::Poll<libp2p::core::transport::TransportEvent<Self::ListenerUpgrade, Self::Error>>
351 {
352 match Transport::poll(self.as_mut().project().inner, cx) {
353 Poll::Ready(event) => Poll::Ready(match event {
354 TransportEvent::Incoming {
356 listener_id,
357 upgrade,
358 local_addr,
359 send_back_addr,
360 } => {
361 let auth_message = Arc::clone(&self.auth_message);
363
364 let auth_upgrade = Self::gen_handshake(
366 upgrade,
367 false,
368 auth_message,
369 Arc::clone(&self.consensus_key_to_pid_map),
370 );
371
372 TransportEvent::Incoming {
374 listener_id,
375 upgrade: auth_upgrade,
376 local_addr,
377 send_back_addr,
378 }
379 },
380
381 TransportEvent::AddressExpired {
383 listener_id,
384 listen_addr,
385 } => TransportEvent::AddressExpired {
386 listener_id,
387 listen_addr,
388 },
389 TransportEvent::ListenerClosed {
390 listener_id,
391 reason,
392 } => TransportEvent::ListenerClosed {
393 listener_id,
394 reason,
395 },
396 TransportEvent::ListenerError { listener_id, error } => {
397 TransportEvent::ListenerError { listener_id, error }
398 },
399 TransportEvent::NewAddress {
400 listener_id,
401 listen_addr,
402 } => TransportEvent::NewAddress {
403 listener_id,
404 listen_addr,
405 },
406 }),
407
408 Poll::Pending => Poll::Pending,
409 }
410 }
411
412 fn remove_listener(&mut self, id: libp2p::core::transport::ListenerId) -> bool {
415 self.inner.remove_listener(id)
416 }
417 fn listen_on(
418 &mut self,
419 id: libp2p::core::transport::ListenerId,
420 addr: libp2p::Multiaddr,
421 ) -> Result<(), libp2p::TransportError<Self::Error>> {
422 self.inner.listen_on(id, addr)
423 }
424}
425
426trait AsOutput<C: StreamMuxer + Unpin> {
429 fn as_connection(&mut self) -> &mut C;
431
432 fn as_peer_id(&mut self) -> &mut PeerId;
434}
435
436impl<C: StreamMuxer + Unpin> AsOutput<C> for (PeerId, C) {
439 fn as_connection(&mut self) -> &mut C {
441 &mut self.1
442 }
443
444 fn as_peer_id(&mut self) -> &mut PeerId {
446 &mut self.0
447 }
448}
449
450pub async fn read_length_delimited<S: AsyncRead + Unpin>(
457 stream: &mut S,
458 max_size: usize,
459) -> AnyhowResult<Vec<u8>> {
460 let mut len_bytes = [0u8; 4];
462 stream
463 .read_exact(&mut len_bytes)
464 .await
465 .with_context(|| "Failed to read message length")?;
466
467 let len = usize::try_from(u32::from_be_bytes(len_bytes))?;
469
470 ensure!(len <= max_size, "Message too large");
472
473 let mut message = vec![0u8; len];
475 stream
476 .read_exact(&mut message)
477 .await
478 .with_context(|| "Failed to read message")?;
479
480 Ok(message)
481}
482
483pub async fn write_length_delimited<S: AsyncWrite + Unpin>(
488 stream: &mut S,
489 message: &[u8],
490) -> AnyhowResult<()> {
491 stream
493 .write_all(&u32::try_from(message.len())?.to_be_bytes())
494 .await
495 .with_context(|| "Failed to write message length")?;
496
497 stream
499 .write_all(message)
500 .await
501 .with_context(|| "Failed to write message")?;
502
503 Ok(())
504}
505
#[cfg(test)]
mod test {
    use hotshot_types::{signature_key::BLSPubKey, traits::signature_key::SignatureKey};
    use libp2p::{core::transport::dummy::DummyTransport, quic::Connection};
    use rand::Rng;

    use super::*;

    /// Concrete instantiation of the wrapper, so its associated functions can be
    /// called directly in tests.
    type MockStakeTableAuth = ConsensusKeyAuthentication<DummyTransport, BLSPubKey, Connection>;

    /// Creates a random BLS keypair and an ed25519 peer ID, and returns
    /// `(keypair, peer_id, serialized_auth_message)`.
    macro_rules! new_identity {
        () => {{
            let seed = rand::rngs::OsRng.r#gen::<[u8; 32]>();

            let keypair = BLSPubKey::generated_from_seed_indexed(seed, 1337);

            let peer_id = libp2p::identity::Keypair::generate_ed25519()
                .public()
                .to_peer_id();

            let auth_message =
                super::construct_auth_message(&keypair.0, &peer_id, &keypair.1).unwrap();

            (keypair, peer_id, auth_message)
        }};
    }

    /// Writes `$auth_message` as one length-delimited frame into an in-memory
    /// cursor and rewinds the cursor for reading. Must be used in an async
    /// context.
    macro_rules! cursor_from {
        ($auth_message:expr) => {{
            let mut stream = futures::io::Cursor::new(vec![]);
            write_length_delimited(&mut stream, &$auth_message)
                .await
                .expect("Failed to write message");
            stream.set_position(0);
            stream
        }};
    }

    /// A freshly constructed message passes validation.
    #[test]
    fn signature_verify() {
        let (_, _, auth_message) = new_identity!();

        let public_key = super::AuthMessage::<BLSPubKey>::validate(
            &bincode::deserialize(&auth_message).unwrap(),
        );
        assert!(public_key.is_ok());
    }

    /// Tampering with the public key bytes makes validation fail.
    #[test]
    fn signature_verify_invalid_public_key() {
        let (_, _, auth_message) = new_identity!();

        let mut auth_message: super::AuthMessage<BLSPubKey> =
            bincode::deserialize(&auth_message).unwrap();

        auth_message.public_key_bytes[0] ^= 0x01;

        let auth_message = bincode::serialize(&auth_message).unwrap();

        let public_key = super::AuthMessage::<BLSPubKey>::validate(
            &bincode::deserialize(&auth_message).unwrap(),
        );
        assert!(public_key.is_err());
    }

    /// Tampering with the peer ID bytes invalidates the signature.
    #[test]
    fn signature_verify_invalid_peer_id() {
        let (_, _, auth_message) = new_identity!();

        let mut auth_message: super::AuthMessage<BLSPubKey> =
            bincode::deserialize(&auth_message).unwrap();

        auth_message.peer_id_bytes[0] ^= 0x01;

        let auth_message = bincode::serialize(&auth_message).unwrap();

        let public_key = super::AuthMessage::<BLSPubKey>::validate(
            &bincode::deserialize(&auth_message).unwrap(),
        );
        assert!(public_key.is_err());
    }

    /// A valid message for the expected peer is accepted and recorded in the map.
    #[tokio::test(flavor = "multi_thread")]
    async fn valid_authentication() {
        let (keypair, peer_id, auth_message) = new_identity!();

        let mut stream = cursor_from!(auth_message);

        let consensus_key_to_pid_map = Arc::new(parking_lot::Mutex::new(BiMap::new()));

        let result = MockStakeTableAuth::verify_peer_authentication(
            &mut stream,
            &peer_id,
            Arc::clone(&consensus_key_to_pid_map),
        )
        .await;

        // Check the result first so a failure reports the real error instead of
        // a panic on a missing map entry.
        result.expect("Should have passed authentication but did not");

        assert_eq!(
            consensus_key_to_pid_map.lock().get_by_left(&keypair.0),
            Some(&peer_id),
            "Map does not have the correct entry"
        );
    }

    /// A message signed for a different peer ID is rejected and not recorded.
    #[tokio::test(flavor = "multi_thread")]
    async fn peer_id_mismatch() {
        let (_, _, auth_message) = new_identity!();

        let (_, malicious_peer_id, _) = new_identity!();

        let mut stream = cursor_from!(auth_message);

        let consensus_key_to_pid_map = Arc::new(parking_lot::Mutex::new(BiMap::new()));

        let result = MockStakeTableAuth::verify_peer_authentication(
            &mut stream,
            &malicious_peer_id,
            Arc::clone(&consensus_key_to_pid_map),
        )
        .await;

        assert!(
            result
                .expect_err("Should have failed authentication but did not")
                .to_string()
                .contains("Peer ID mismatch"),
            "Did not fail with the correct error"
        );

        assert!(
            consensus_key_to_pid_map.lock().is_empty(),
            "Malicious peer ID should not be in the map"
        );
    }

    /// Writing then reading a frame round-trips the payload.
    #[tokio::test(flavor = "multi_thread")]
    async fn read_and_write_length_delimited() {
        let message = b"Hello, world!";

        let mut buffer = Vec::new();
        write_length_delimited(&mut buffer, message).await.unwrap();

        let read_message = read_length_delimited(&mut buffer.as_slice(), 1024)
            .await
            .unwrap();

        assert_eq!(message, read_message.as_slice());
    }

    /// Frames whose declared length exceeds the limit are rejected.
    #[tokio::test(flavor = "multi_thread")]
    async fn read_length_delimited_rejects_oversized_frame() {
        let message = vec![0u8; MAX_AUTH_MESSAGE_SIZE + 1];

        let mut buffer = Vec::new();
        write_length_delimited(&mut buffer, &message).await.unwrap();

        let err = read_length_delimited(&mut buffer.as_slice(), MAX_AUTH_MESSAGE_SIZE)
            .await
            .expect_err("Oversized frame should have been rejected");

        assert!(
            err.to_string().contains("Message too large"),
            "Did not fail with the correct error"
        );
    }
}