request_response/
lib.rs

1//! This crate contains a general request-response protocol. It is used to send requests to
2//! a set of recipients and wait for responses.
3
4use std::{
5    any::Any,
6    collections::HashMap,
7    future::Future,
8    marker::PhantomData,
9    pin::Pin,
10    sync::{Arc, Weak},
11    time::{Duration, Instant},
12};
13
14use anyhow::{Context, Result, anyhow};
15use async_lock::RwLock;
16use data_source::DataSource;
17use derive_more::derive::Deref;
18use hotshot_types::traits::signature_key::SignatureKey;
19use message::{Message, RequestMessage, ResponseMessage};
20use network::{Bytes, Receiver, Sender};
21use rand::seq::SliceRandom;
22use recipient_source::RecipientSource;
23use request::Request;
24use tokio::{
25    spawn,
26    time::{sleep, timeout},
27};
28use tokio_util::task::AbortOnDropHandle;
29use tracing::{debug, error, info, trace, warn};
30use util::{BoundedVecDeque, NamedSemaphore, NamedSemaphoreError};
31
32use crate::util::NamedSemaphorePermit;
33
34/// The data source trait. Is what we use to derive the response data for a request
35pub mod data_source;
36/// The message type. Is the base type for all messages in the request-response protocol
37pub mod message;
38/// The network traits. Is what we use to send and receive messages over the network as
39/// the protocol
40pub mod network;
41/// The recipient source trait. Is what we use to get the recipients that a specific message should
42/// expect responses from
43pub mod recipient_source;
44/// The request trait. Is what we use to define a request and a corresponding response type
45pub mod request;
46/// Utility types and functions
47mod util;
48
/// The `Blake3` hash of a serialized request. Used as the unique identifier for a request
pub type RequestHash = blake3::Hash;

/// The map of currently in-flight outgoing requests, keyed by request hash. The values
/// are `Weak` references so that a map entry does not keep a request alive after every
/// waiter has dropped it (stale entries are upgraded-and-replaced on next use)
pub type OutgoingRequestsMap<Req> =
    Arc<RwLock<HashMap<RequestHash, Weak<OutgoingRequestInner<Req>>>>>;

/// The semaphore bounding the tasks that are responding to requests, keyed by the
/// requester's public key
pub type IncomingRequests<K> = NamedSemaphore<K>;

/// The semaphore bounding the tasks that are validating incoming responses, keyed by
/// the request hash they respond to
pub type IncomingResponses = NamedSemaphore<RequestHash>;
61
/// The type of request to make
// `Debug` is derived so the choice of request type can appear in logs and error output
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum RequestType {
    /// A request that can be satisfied by a single participant,
    /// and as such will be batched to a few participants at a time
    /// until one succeeds
    Batched,
    /// A request that needs most or all participants to respond,
    /// and as such will be broadcasted to all participants
    Broadcast,
}
73
74/// The errors that can occur when making a request for data
75#[derive(thiserror::Error, Debug)]
76pub enum RequestError {
77    /// The request timed out
78    #[error("request timed out")]
79    Timeout,
80    /// The request was invalid
81    #[error("request was invalid")]
82    InvalidRequest(anyhow::Error),
83    /// Other errors
84    #[error("other error")]
85    Other(anyhow::Error),
86}
87
/// A trait for serializing and deserializing a type to and from a byte array. [`Request`] types and
/// their associated [`Request::Response`] types will need to implement this trait
pub trait Serializable: Sized {
    /// Serialize the type to a byte array. If this is for a [`Request`] and your [`Request`] type
    /// is represented as an enum, please make sure that you serialize it with a unique type ID. Otherwise,
    /// you may end up with collisions as the request hash (the `Blake3` hash of these bytes) is used
    /// as a unique identifier
    ///
    /// # Errors
    /// - If the type cannot be serialized to a byte array
    fn to_bytes(&self) -> Result<Vec<u8>>;

    /// Deserialize the type from a byte array
    ///
    /// # Errors
    /// - If the byte array is not a valid representation of the type
    fn from_bytes(bytes: &[u8]) -> Result<Self>;
}
105
/// The underlying configuration for the request-response protocol
// `Debug` is derived (all fields are `Duration`/`usize`) so the active configuration
// can be logged for diagnostics
#[derive(Clone, Debug)]
pub struct RequestResponseConfig {
    /// The timeout for incoming requests. Do not respond to a request after this threshold
    /// has passed.
    pub incoming_request_ttl: Duration,
    /// The maximum amount of time we will spend trying to both derive a response for a request and
    /// send the response over the wire.
    pub incoming_request_timeout: Duration,
    /// The maximum amount of time we will spend trying to validate a response. This is used to prevent
    /// an attack where a malicious participant sends us a bunch of requests that take a long time to
    /// validate.
    pub incoming_response_timeout: Duration,
    /// The batch size for outgoing requests. This is the number of request messages that we will
    /// send out at a time for a single request before waiting for the [`request_batch_interval`].
    pub request_batch_size: usize,
    /// The time to wait (per request) between sending out batches of request messages
    pub request_batch_interval: Duration,
    /// The maximum (global) number of incoming requests that can be processed at any given time.
    pub max_incoming_requests: usize,
    /// The maximum number of incoming requests that can be processed for a single key at any given time.
    pub max_incoming_requests_per_key: usize,
    /// The maximum (global) number of incoming responses that can be processed at any given time.
    /// We need this because responses coming in need to be validated [asynchronously] that they
    /// satisfy the request they are responding to
    pub max_incoming_responses: usize,
}
133
/// A protocol that allows for request-response communication. Is cheaply cloneable, so there is no
/// need to wrap it in an `Arc`. Dereferences to [`RequestResponseInner`], which holds all of the
/// actual protocol state and methods
#[derive(Deref)]
pub struct RequestResponse<
    S: Sender<K>,
    R: Receiver,
    Req: Request,
    RS: RecipientSource<Req, K>,
    DS: DataSource<Req>,
    K: SignatureKey + 'static,
> {
    #[deref]
    /// The inner implementation of the request-response protocol
    pub inner: Arc<RequestResponseInner<S, R, Req, RS, DS, K>>,
    /// A handle to the receiving task. Because it is an [`AbortOnDropHandle`], the task is
    /// automatically aborted when the last clone of the protocol is dropped
    _receiving_task_handle: Arc<AbortOnDropHandle<()>>,
}
151
152/// We need to manually implement the `Clone` trait for this type because deriving
153/// `Deref` will cause an issue where it tries to clone the inner field instead
154impl<
155    S: Sender<K>,
156    R: Receiver,
157    Req: Request,
158    RS: RecipientSource<Req, K>,
159    DS: DataSource<Req>,
160    K: SignatureKey + 'static,
161> Clone for RequestResponse<S, R, Req, RS, DS, K>
162{
163    fn clone(&self) -> Self {
164        Self {
165            inner: Arc::clone(&self.inner),
166            _receiving_task_handle: Arc::clone(&self._receiving_task_handle),
167        }
168    }
169}
170
171impl<
172    S: Sender<K>,
173    R: Receiver,
174    Req: Request,
175    RS: RecipientSource<Req, K>,
176    DS: DataSource<Req>,
177    K: SignatureKey + 'static,
178> RequestResponse<S, R, Req, RS, DS, K>
179{
180    /// Create a new [`RequestResponseProtocol`]
181    pub fn new(
182        // The configuration for the protocol
183        config: RequestResponseConfig,
184        // The network sender that [`RequestResponseProtocol`] will use to send messages
185        sender: S,
186        // The network receiver that [`RequestResponseProtocol`] will use to receive messages
187        receiver: R,
188        // The recipient source that [`RequestResponseProtocol`] will use to get the recipients
189        // that a specific message should expect responses from
190        recipient_source: RS,
191        // The [response] data source that [`RequestResponseProtocol`] will use to derive the
192        // response data for a specific request
193        data_source: DS,
194    ) -> Self {
195        // Create the outgoing requests map
196        let outgoing_requests = OutgoingRequestsMap::default();
197
198        // Create the inner implementation
199        let inner = Arc::new(RequestResponseInner {
200            config,
201            sender,
202            recipient_source,
203            data_source,
204            outgoing_requests,
205            phantom_data: PhantomData,
206        });
207
208        // Start the task that receives messages and handles them. This will automatically get cancelled
209        // when the protocol is dropped
210        let inner_clone = Arc::clone(&inner);
211        let receive_task_handle =
212            AbortOnDropHandle::new(tokio::spawn(inner_clone.receiving_task(receiver)));
213
214        // Return the protocol
215        Self {
216            inner,
217            _receiving_task_handle: Arc::new(receive_task_handle),
218        }
219    }
220}
221
/// A type alias for an `Arc<dyn Any + Send + Sync + 'static>`. Used to carry the
/// type-erased, validated response value through the broadcast channel so requests
/// with different concrete output types can share the same machinery
type ThreadSafeAny = Arc<dyn Any + Send + Sync + 'static>;

/// A type alias for the boxed future that validates a response, yielding the
/// type-erased validated value on success
type ResponseValidationFuture =
    Pin<Box<dyn Future<Output = Result<ThreadSafeAny, anyhow::Error>> + Send + Sync + 'static>>;

/// A type alias for the function that, given the original request and a received
/// response, returns the validation future above
type ResponseValidationFn<R> =
    Box<dyn Fn(&R, <R as Request>::Response) -> ResponseValidationFuture + Send + Sync + 'static>;
232
/// The inner implementation for the request-response protocol. Holds all protocol state;
/// shared behind an `Arc` by [`RequestResponse`] and the background receiving task
pub struct RequestResponseInner<
    S: Sender<K>,
    R: Receiver,
    Req: Request,
    RS: RecipientSource<Req, K>,
    DS: DataSource<Req>,
    K: SignatureKey + 'static,
> {
    /// The configuration of the protocol
    config: RequestResponseConfig,
    /// The sender to use for the protocol
    pub sender: S,
    /// The recipient source to use for the protocol
    pub recipient_source: RS,
    /// The data source to use for the protocol
    data_source: DS,
    /// The map of currently active, outgoing requests
    outgoing_requests: OutgoingRequestsMap<Req>,
    /// Phantom data to tie the otherwise-unused `K` and `R` parameters to the struct.
    /// NOTE(review): `Req` and `DS` already appear in field types above, so including
    /// them here looks redundant (though harmless) — confirm before slimming
    phantom_data: PhantomData<(K, R, Req, DS)>,
}
255impl<
256    S: Sender<K>,
257    R: Receiver,
258    Req: Request,
259    RS: RecipientSource<Req, K>,
260    DS: DataSource<Req>,
261    K: SignatureKey + 'static,
262> RequestResponseInner<S, R, Req, RS, DS, K>
263{
264    /// Request something from the protocol indefinitely until we get a response
265    /// or there was a critical error (e.g. the request could not be signed)
266    ///
267    /// # Errors
268    /// - If the request was invalid
269    /// - If there was a critical error (e.g. the channel was closed)
270    pub async fn request_indefinitely<F, Fut, O>(
271        self: &Arc<Self>,
272        public_key: &K,
273        private_key: &K::PrivateKey,
274        // The type of request to make
275        request_type: RequestType,
276        // The estimated TTL of other participants. This is used to decide when to
277        // stop making requests and sign a new one
278        estimated_request_ttl: Duration,
279        // The request to make
280        request: Req,
281        // The response validation function
282        response_validation_fn: F,
283    ) -> std::result::Result<O, RequestError>
284    where
285        F: Fn(&Req, Req::Response) -> Fut + Send + Sync + 'static + Clone,
286        Fut: Future<Output = anyhow::Result<O>> + Send + Sync + 'static,
287        O: Send + Sync + 'static + Clone,
288    {
289        loop {
290            // Sign a request message
291            let request_message = RequestMessage::new_signed(public_key, private_key, &request)
292                .map_err(|e| {
293                    RequestError::InvalidRequest(anyhow::anyhow!(
294                        "failed to sign request message: {e}"
295                    ))
296                })?;
297
298            // Request the data, handling the errors appropriately
299            match self
300                .request(
301                    request_message,
302                    request_type,
303                    estimated_request_ttl,
304                    response_validation_fn.clone(),
305                )
306                .await
307            {
308                Ok(response) => return Ok(response),
309                Err(RequestError::Timeout) => continue,
310                Err(e) => return Err(e),
311            }
312        }
313    }
314
    /// Request something from the protocol and wait for the response. This function
    /// will join with an existing request for the same data (determined by `Blake3` hash),
    /// however both will make requests until the timeout is reached
    ///
    /// # Errors
    /// - If the request times out
    /// - If the channel is closed (this is an internal error)
    /// - If the request we sign is invalid
    pub async fn request<F, Fut, O>(
        self: &Arc<Self>,
        request_message: RequestMessage<Req, K>,
        request_type: RequestType,
        timeout_duration: Duration,
        response_validation_fn: F,
    ) -> std::result::Result<O, RequestError>
    where
        F: Fn(&Req, Req::Response) -> Fut + Send + Sync + 'static + Clone,
        Fut: Future<Output = anyhow::Result<O>> + Send + Sync + 'static,
        O: Send + Sync + 'static + Clone,
    {
        // Everything below races against `timeout_duration`
        timeout(timeout_duration, async move {
            // Calculate the hash of the request. This is what identifies a request, so
            // concurrent callers asking for the same data share a single map entry
            let request_hash = blake3::hash(&request_message.request.to_bytes().map_err(|e| {
                RequestError::InvalidRequest(anyhow::anyhow!(
                    "failed to serialize request message: {e}"
                ))
            })?);

            let request = {
                // Get a write lock on the outgoing requests map
                let mut outgoing_requests_write = self.outgoing_requests.write().await;

                // Conditionally get the outgoing request, creating a new one if it doesn't exist or if
                // the existing one has been dropped and not yet removed
                if let Some(outgoing_request) = outgoing_requests_write
                    .get(&request_hash)
                    .and_then(Weak::upgrade)
                {
                    OutgoingRequest(outgoing_request)
                } else {
                    // Create a new broadcast channel for the response. Capacity 1 suffices
                    // because only one (validated) response is ever published
                    let (sender, receiver) = async_broadcast::broadcast(1);

                    // Modify the response validation function to return an `Arc<dyn Any>`,
                    // erasing the caller's concrete output type so different callers can
                    // share the same map
                    let response_validation_fn =
                        Box::new(move |request: &Req, response: Req::Response| {
                            let fut = response_validation_fn(request, response);
                            Box::pin(
                                async move { fut.await.map(|ok| Arc::new(ok) as ThreadSafeAny) },
                            ) as ResponseValidationFuture
                        });

                    // Create a new outgoing request
                    let outgoing_request = OutgoingRequest(Arc::new(OutgoingRequestInner {
                        sender,
                        receiver,
                        response_validation_fn,
                        request: request_message.request.clone(),
                    }));

                    // Write the new outgoing request to the map. Only a `Weak` reference is
                    // stored so abandoned requests can be detected and replaced later
                    outgoing_requests_write
                        .insert(request_hash, Arc::downgrade(&outgoing_request.0));

                    // Return the new outgoing request
                    outgoing_request
                }
            };

            // Create a request message and serialize it
            let message = Bytes::from(
                Message::Request(request_message.clone())
                    .to_bytes()
                    .map_err(|e| {
                        RequestError::InvalidRequest(anyhow::anyhow!(
                            "failed to serialize request message: {e}"
                        ))
                    })?,
            );

            // Create a place to put the handle for the batched sending task. We need this because
            // otherwise it gets dropped when the closure goes out of scope, instead of when the function
            // gets cancelled or returns
            let mut _batched_sending_task = None;

            // Match on the type of request
            if request_type == RequestType::Broadcast {
                trace!("Sending request {request_message:?} to all participants");

                // If the message is a broadcast request, just send it to all participants
                self.sender
                    .send_broadcast_message(&message)
                    .await
                    .map_err(|e| {
                        RequestError::Other(anyhow::anyhow!(
                            "failed to send broadcast message: {e}"
                        ))
                    })?;
            } else {
                // If the message is a batched request, we need to batch it with other requests

                // Get the recipients that the request should expect responses from. Shuffle them so
                // that we don't always send to the same recipients in the same order
                let mut recipients = self
                    .recipient_source
                    .get_expected_responders(&request_message.request)
                    .await
                    .map_err(|e| {
                        RequestError::InvalidRequest(anyhow::anyhow!(
                            "failed to get expected responders for request: {e}"
                        ))
                    })?;
                recipients.shuffle(&mut rand::thread_rng());

                // Get the current time so we can check when the timeout has elapsed
                let start_time = Instant::now();

                // Spawn a task that sends out requests to the network
                let self_clone = Arc::clone(self);
                let batched_sending_handle = AbortOnDropHandle::new(spawn(async move {
                    // Create a bounded queue for the outgoing requests. We use this to make sure
                    // we have less than [`config.request_batch_size`] requests in flight at any time.
                    //
                    // When newer requests are added, older ones are removed from the queue. Because we use
                    // `AbortOnDropHandle`, the older ones will automatically get cancelled
                    let mut outgoing_requests =
                        BoundedVecDeque::new(self_clone.config.request_batch_size);

                    // While the timeout hasn't elapsed, send out requests to the network.
                    // Note: the elapsed check only runs between full passes over all
                    // recipients, so a pass may overrun slightly; the outer `timeout`
                    // still bounds the caller
                    while start_time.elapsed() < timeout_duration {
                        // Send out requests to the network in their own separate tasks
                        for recipient_batch in
                            recipients.chunks(self_clone.config.request_batch_size)
                        {
                            for recipient in recipient_batch {
                                // Clone ourselves, the message, and the recipient so they can be moved
                                let self_clone = Arc::clone(&self_clone);
                                let request_message_clone = request_message.clone();
                                let recipient_clone = recipient.clone();
                                let message_clone = Arc::clone(&message);

                                // Spawn the task that sends the request to the participant
                                let individual_sending_task = spawn(async move {
                                    trace!(
                                        "Sending request {request_message_clone:?} to \
                                         {recipient_clone:?}"
                                    );

                                    // Best-effort send; a failure is ignored because we
                                    // will retry on the next pass anyway
                                    let _ = self_clone
                                        .sender
                                        .send_direct_message(&message_clone, recipient_clone)
                                        .await;
                                });

                                // Add the sending task to the queue
                                outgoing_requests
                                    .push(AbortOnDropHandle::new(individual_sending_task));
                            }

                            // After we send the batch out, wait the [`config.request_batch_interval`]
                            // before sending the next one
                            sleep(self_clone.config.request_batch_interval).await;
                        }
                    }
                }));

                // Store the handle so it doesn't get dropped
                _batched_sending_task = Some(batched_sending_handle);
            }

            // Wait for a response on the channel. The receiver is cloned because every
            // joined caller needs its own cursor into the broadcast channel
            request
                .receiver
                .clone()
                .recv()
                .await
                .map_err(|_| RequestError::Other(anyhow!("channel was closed")))
        })
        .await
        // An elapsed outer timeout maps to [`RequestError::Timeout`]
        .map_err(|_| RequestError::Timeout)
        .and_then(|result| result)
        // Downcast the type-erased response back to the caller's concrete type
        .and_then(|result| {
            result.downcast::<O>().map_err(|e| {
                RequestError::Other(anyhow::anyhow!(
                    "failed to downcast response to expected type: {e:?}"
                ))
            })
        })
        // Unwrap the `Arc`, cloning only if another waiter still holds a reference
        .map(|result| Arc::unwrap_or_clone(result))
    }
505
506    /// The task responsible for receiving messages from the receiver and handling them
507    async fn receiving_task(self: Arc<Self>, mut receiver: R) {
508        // Upper bound the number of outgoing and incoming responses
509        let mut incoming_requests = NamedSemaphore::new(
510            self.config.max_incoming_requests_per_key,
511            Some(self.config.max_incoming_requests),
512        );
513        // process 1 response per key, with a global maximum of [`config.max_incoming_responses`] responses
514        let mut incoming_responses =
515            NamedSemaphore::new(1, Some(self.config.max_incoming_responses));
516
517        // While the receiver is open, we receive messages and handle them
518        loop {
519            // Try to receive a message
520            match receiver.receive_message().await {
521                Ok(message) => {
522                    // Deserialize the message, warning if it fails
523                    let message = match Message::from_bytes(&message) {
524                        Ok(message) => message,
525                        Err(e) => {
526                            warn!("Received invalid message: {e:#}");
527                            continue;
528                        },
529                    };
530
531                    // Handle the message based on its type
532                    match message {
533                        Message::Request(request_message) => {
534                            self.handle_request(request_message, &mut incoming_requests);
535                        },
536                        Message::Response(response_message) => {
537                            self.handle_response(response_message, &mut incoming_responses);
538                        },
539                    }
540                },
541                // An error here means the receiver will _NEVER_ receive any more messages
542                Err(e) => {
543                    error!("Request/response receive task exited: {e:#}");
544                    return;
545                },
546            }
547        }
548    }
549
550    /// Handle a request sent to us
551    fn handle_request(
552        self: &Arc<Self>,
553        request_message: RequestMessage<Req, K>,
554        incoming_requests: &mut IncomingRequests<K>,
555    ) {
556        trace!("Handling request {:?}", request_message);
557
558        // Spawn a task to:
559        // - Validate the request
560        // - Derive the response data (check if we have it)
561        // - Send the response to the requester
562        let self_clone = Arc::clone(self);
563
564        // Attempt to acquire a permit for the request. Warn if there are too many requests currently being processed
565        // either globally or per-key
566        let permit = incoming_requests.try_acquire(request_message.public_key.clone());
567        match permit {
568            Ok(ref permit) => permit,
569            Err(NamedSemaphoreError::PerKeyLimitReached) => {
570                info!(
571                    "Failed to process request from {}: too many requests from the same key are \
572                     already being processed",
573                    request_message.public_key
574                );
575                return;
576            },
577            Err(NamedSemaphoreError::GlobalLimitReached) => {
578                info!(
579                    "Failed to process request from {}: too many requests are already being \
580                     processed",
581                    request_message.public_key
582                );
583                return;
584            },
585        };
586
587        tokio::spawn(async move {
588            let result = timeout(self_clone.config.incoming_request_timeout, async move {
589                // Validate the request message. This includes:
590                // - Checking the signature and making sure it's valid
591                // - Checking the timestamp and making sure it's not too old
592                // - Calling the request's application-specific validation function
593                request_message
594                    .validate(self_clone.config.incoming_request_ttl)
595                    .await
596                    .with_context(|| "failed to validate request")?;
597
598                // Try to fetch the response data from the data source
599                let response = self_clone
600                    .data_source
601                    .derive_response_for(&request_message.request)
602                    .await
603                    .with_context(|| "failed to derive response for request")?;
604
605                // Create the response message and serialize it
606                let response = Bytes::from(
607                    Message::Response::<Req, K>(ResponseMessage {
608                        request_hash: blake3::hash(&request_message.request.to_bytes()?),
609                        response,
610                    })
611                    .to_bytes()
612                    .with_context(|| "failed to serialize response message")?,
613                );
614
615                // Send the response to the requester
616                self_clone
617                    .sender
618                    .send_direct_message(&response, request_message.public_key)
619                    .await
620                    .with_context(|| "failed to send response to requester")?;
621
622                // Drop the permit
623                _ = permit;
624                drop(permit);
625
626                Ok::<(), anyhow::Error>(())
627            })
628            .await
629            .map_err(|_| anyhow::anyhow!("timed out while sending response"))
630            .and_then(|result| result);
631
632            if let Err(e) = result {
633                debug!("Failed to send response to requester: {e:#}");
634            }
635        });
636    }
637
    /// Wait until a validation permit can be acquired for the given request hash,
    /// polling for as long as the request is still outstanding (i.e. present in
    /// `outgoing_requests`).
    ///
    /// NOTE(review): `Err(GlobalLimitReached)` is returned both when the global response
    /// limit is hit AND when the request is no longer outstanding (e.g. it was already
    /// answered while we waited) — callers cannot distinguish the two from the error alone.
    ///
    /// # Errors
    /// - If the global limit on concurrently processed responses has been reached
    /// - If the request is no longer in the outgoing requests map
    async fn wait_for_permit(
        incoming_responses: &mut IncomingResponses,
        request_hash: RequestHash,
        outgoing_requests: &OutgoingRequestsMap<Req>,
    ) -> Result<NamedSemaphorePermit<RequestHash>, NamedSemaphoreError> {
        // Keep trying only while somebody is still waiting on this request
        while outgoing_requests.read().await.contains_key(&request_hash) {
            let permit = incoming_responses.try_acquire(request_hash);
            let permit = match permit {
                Ok(permit) => permit,
                // Another response for the same request hash is already being validated;
                // back off briefly and re-check whether the request is still outstanding
                Err(NamedSemaphoreError::PerKeyLimitReached) => {
                    tokio::time::sleep(Duration::from_millis(100)).await;
                    continue;
                },
                Err(NamedSemaphoreError::GlobalLimitReached) => {
                    warn!(
                        "Failed to process response: too many responses are already being \
                         processed"
                    );
                    return Err(NamedSemaphoreError::GlobalLimitReached);
                },
            };
            return Ok(permit);
        }
        // The request was answered (or abandoned) while we were waiting
        Err(NamedSemaphoreError::GlobalLimitReached)
    }
663
    /// Handle a response sent to us: spawn a task that (under a permit and a timeout)
    /// validates the response against the corresponding outstanding request and, if
    /// valid, broadcasts it to every waiter
    fn handle_response(
        self: &Arc<Self>,
        response: ResponseMessage<Req>,
        incoming_responses: &mut IncomingResponses,
    ) {
        trace!("Handling response {response:?}");

        // Clone the handles we need to move into the spawned task
        let outgoing_requests_clone = self.outgoing_requests.clone();
        let mut incoming_responses_clone = incoming_responses.clone();

        // Spawn a task to validate the response and send it to the requester (us)
        let response_validate_timeout = self.config.incoming_response_timeout;
        tokio::spawn(async move {
            if timeout(response_validate_timeout, async move {
                // Attempt to acquire a permit for the request. Warn if there are too many responses currently being processed
                let permit = Self::wait_for_permit(
                    &mut incoming_responses_clone,
                    response.request_hash,
                    &outgoing_requests_clone,
                )
                .await;

                // NOTE(review): this warning also fires when the request was answered
                // while we waited, not only on the global limit — `wait_for_permit`
                // reuses the same error value for both cases
                let Ok(permit) = permit else {
                    warn!(
                        "Failed to process response: too many responses are already being \
                         processed"
                    );
                    return;
                };
                // Get the entry in the map, ignoring it if it doesn't exist (the `Weak`
                // upgrade fails if every requester has already given up)
                let Some(outgoing_request) = outgoing_requests_clone
                    .read()
                    .await
                    .get(&response.request_hash)
                    .cloned()
                    .and_then(|r| r.upgrade())
                else {
                    return;
                };
                // Make sure the response is valid for the given request
                let validation_result = match (outgoing_request.response_validation_fn)(
                    &outgoing_request.request,
                    response.response,
                )
                .await
                {
                    Ok(validation_result) => validation_result,
                    Err(e) => {
                        debug!("Received invalid response: {e:#}");
                        return;
                    },
                };

                // Send the response to the requester (the user of [`RequestResponse::request`])
                let _ = outgoing_request.sender.try_broadcast(validation_result);

                // The request has been answered; remove it from the outstanding map
                outgoing_requests_clone
                    .write()
                    .await
                    .remove(&response.request_hash);

                // Drop the permit
                _ = permit;
                drop(permit);
            })
            .await
            .is_err()
            {
                warn!("Timed out while validating response");
            }
        });
    }
737}
738
/// An outgoing request. This is what we use to track a request and its corresponding response
/// in the protocol. It is a cheaply-cloneable handle (the `Clone` derive clones the inner
/// [`Arc`]) that derefs to the shared [`OutgoingRequestInner`] state.
#[derive(Clone, Deref)]
pub struct OutgoingRequest<R: Request>(Arc<OutgoingRequestInner<R>>);
743
/// The inner implementation of an outgoing request. Holds the in-flight request, the
/// broadcast channel over which validated responses are delivered, and the validation
/// function applied to each incoming response.
pub struct OutgoingRequestInner<R: Request> {
    /// The sender to use for the protocol. Validated responses are broadcast over this
    /// channel (type-erased as a `ThreadSafeAny`) to whoever is awaiting the request.
    sender: async_broadcast::Sender<ThreadSafeAny>,
    /// The receiver to use for the protocol. NOTE(review): presumably kept here so the
    /// broadcast channel stays open and new waiters can subscribe — confirm against
    /// the `request` implementation.
    receiver: async_broadcast::Receiver<ThreadSafeAny>,

    /// The request that we are waiting for a response to
    request: R,

    /// The function used to validate each response before it is broadcast to waiters
    response_validation_fn: ResponseValidationFn<R>,
}
757
758#[cfg(test)]
759mod tests {
760    use std::{
761        collections::HashMap,
762        sync::{Mutex, atomic::AtomicBool},
763    };
764
765    use async_trait::async_trait;
766    use hotshot_types::signature_key::{BLSPrivKey, BLSPubKey};
767    use rand::Rng;
768    use tokio::{sync::mpsc, task::JoinSet};
769
770    use super::*;
771
    /// A test sender that has a list of all the participants in the network
    #[derive(Clone)]
    pub struct TestSender {
        /// Map from each participant's public key to the mpsc channel on which that
        /// participant receives raw message bytes
        network: Arc<HashMap<BLSPubKey, mpsc::Sender<Bytes>>>,
    }
777
778    /// An implementation of the [`Sender`] trait for the [`TestSender`] type
779    #[async_trait]
780    impl Sender<BLSPubKey> for TestSender {
781        async fn send_direct_message(&self, message: &Bytes, recipient: BLSPubKey) -> Result<()> {
782            self.network
783                .get(&recipient)
784                .ok_or(anyhow::anyhow!("recipient not found"))?
785                .send(Arc::clone(message))
786                .await
787                .map_err(|_| anyhow::anyhow!("failed to send message"))?;
788
789            Ok(())
790        }
791
792        async fn send_broadcast_message(&self, message: &Bytes) -> Result<()> {
793            for sender in self.network.values() {
794                sender
795                    .send(Arc::clone(message))
796                    .await
797                    .map_err(|_| anyhow::anyhow!("failed to send message"))?;
798            }
799            Ok(())
800        }
801    }
802
803    // Implement the [`RecipientSource`] trait for the [`TestSender`] type
804    #[async_trait]
805    impl RecipientSource<TestRequest, BLSPubKey> for TestSender {
806        async fn get_expected_responders(&self, _request: &TestRequest) -> Result<Vec<BLSPubKey>> {
807            // Get all the participants in the network
808            Ok(self.network.keys().copied().collect())
809        }
810    }
811
    // Create a test request that is just some bytes. The payload doubles as the
    // request's identity: the test data source responds with its blake3 hash.
    #[derive(Clone, Debug)]
    struct TestRequest(Vec<u8>);
815
816    // Implement the [`Serializable`] trait for the [`TestRequest`] type
817    impl Serializable for TestRequest {
818        fn to_bytes(&self) -> Result<Vec<u8>> {
819            Ok(self.0.clone())
820        }
821
822        fn from_bytes(bytes: &[u8]) -> Result<Self> {
823            Ok(TestRequest(bytes.to_vec()))
824        }
825    }
826
    // Implement the [`Request`] trait for the [`TestRequest`] type
    #[async_trait]
    impl Request for TestRequest {
        // Responses are raw bytes (in these tests, the blake3 hash of the payload)
        type Response = Vec<u8>;
        /// All test requests are considered valid; no structural checks are performed
        async fn validate(&self) -> Result<()> {
            Ok(())
        }
    }
835
    // Create a test data source that pretends to have the data or not
    #[derive(Clone)]
    struct TestDataSource {
        /// Whether we have the data or not
        has_data: bool,
        /// The time at which the data will be available if we have it
        data_available_time: Instant,

        /// Whether or not the data will be taken once served (i.e. it can only be
        /// served a single time)
        take_data: bool,
        /// Whether or not the data has been taken. Behind an `Arc` so all clones of
        /// this data source observe a single shared "taken" state.
        taken: Arc<AtomicBool>,
    }
849
850    #[async_trait]
851    impl DataSource<TestRequest> for TestDataSource {
852        async fn derive_response_for(&self, request: &TestRequest) -> Result<Vec<u8>> {
853            // Return a response if we hit the hit rate
854            if self.has_data && Instant::now() >= self.data_available_time {
855                if self.take_data && !self.taken.swap(true, std::sync::atomic::Ordering::Relaxed) {
856                    return Err(anyhow::anyhow!("data already taken"));
857                }
858                Ok(blake3::hash(&request.0).as_bytes().to_vec())
859            } else {
860                Err(anyhow::anyhow!("did not have the data"))
861            }
862        }
863    }
864
865    /// Create and return a default protocol configuration
866    fn default_protocol_config() -> RequestResponseConfig {
867        RequestResponseConfig {
868            incoming_request_ttl: Duration::from_secs(40),
869            incoming_request_timeout: Duration::from_secs(40),
870            request_batch_size: 10,
871            request_batch_interval: Duration::from_millis(100),
872            max_incoming_requests: 10,
873            max_incoming_requests_per_key: 1,
874            incoming_response_timeout: Duration::from_secs(1),
875            max_incoming_responses: 5,
876        }
877    }
878
879    /// Create fully connected test networks with `num_participants` participants
880    fn create_participants(
881        num: usize,
882    ) -> Vec<(TestSender, mpsc::Receiver<Bytes>, (BLSPubKey, BLSPrivKey))> {
883        // The entire network
884        let mut network = HashMap::new();
885
886        // All receivers in the network
887        let mut receivers = Vec::new();
888
889        // All keypairs in the network
890        let mut keypairs = Vec::new();
891
892        // For each participant,
893        for i in 0..num {
894            // Create a unique `BLSPubKey`
895            let (public_key, private_key) =
896                BLSPubKey::generated_from_seed_indexed([2; 32], i.try_into().unwrap());
897
898            // Add the keypair to the list
899            keypairs.push((public_key, private_key));
900
901            // Create a channel for sending and receiving messages
902            let (sender, receiver) = mpsc::channel::<Bytes>(100);
903
904            // Add the participant to the network
905            network.insert(public_key, sender);
906
907            // Add the receiver to the list of receivers
908            receivers.push(receiver);
909        }
910
911        // Create a test sender from the network
912        let sender = TestSender {
913            network: Arc::new(network),
914        };
915
916        // Return all senders and receivers
917        receivers
918            .into_iter()
919            .zip(keypairs)
920            .map(|(r, k)| (sender.clone(), r, k))
921            .collect()
922    }
923
    /// The configuration for an integration test
    #[derive(Clone)]
    struct IntegrationTestConfig {
        /// The request response protocol configuration
        request_response_config: RequestResponseConfig,
        /// The number of participants in the network
        num_participants: usize,
        /// The number of participants that have the data. The first N participants
        /// (by index) are the ones given the data; assumed <= `num_participants`
        num_participants_with_data: usize,
        /// The timeout for the requests
        request_timeout: Duration,
        /// The delay before the nodes have the data available
        data_available_delay: Duration,
    }
938
    /// The result of an integration test
    struct IntegrationTestResult {
        /// The number of nodes that received a response (i.e. whose request task
        /// completed without panicking or returning an error)
        num_succeeded: usize,
    }
944
945    /// Run an integration test with the given parameters
946    async fn run_integration_test(config: IntegrationTestConfig) -> IntegrationTestResult {
947        // Create a fully connected network with `num_participants` participants
948        let participants = create_participants(config.num_participants);
949
950        // Create a join set to wait for all the tasks to finish
951        let mut join_set = JoinSet::new();
952
953        // We need to keep these here so they don't get dropped
954        let handles = Arc::new(Mutex::new(Vec::new()));
955
956        // For each one, create a new [`RequestResponse`] protocol
957        for (i, (sender, receiver, (public_key, private_key))) in
958            participants.into_iter().enumerate()
959        {
960            let config_clone = config.request_response_config.clone();
961            let handles_clone = Arc::clone(&handles);
962            join_set.spawn(async move {
963                let protocol = RequestResponse::new(
964                    config_clone,
965                    sender.clone(),
966                    receiver,
967                    sender,
968                    TestDataSource {
969                        has_data: i < config.num_participants_with_data,
970                        data_available_time: Instant::now() + config.data_available_delay,
971                        take_data: false,
972                        taken: Arc::new(AtomicBool::new(false)),
973                    },
974                );
975
976                // Add the handle to the handles list so it doesn't get dropped and
977                // cancelled
978                #[allow(clippy::used_underscore_binding)]
979                handles_clone
980                    .lock()
981                    .unwrap()
982                    .push(Arc::clone(&protocol._receiving_task_handle));
983
984                // Create a random request
985                let request = TestRequest(vec![rand::thread_rng().r#gen(); 100]);
986
987                // Get the hash of the request
988                let request_hash = blake3::hash(&request.0).as_bytes().to_vec();
989
990                // Create a new request message
991                let request = RequestMessage::new_signed(&public_key, &private_key, &request)
992                    .expect("failed to create request message");
993
994                // Request the data from the protocol
995                let response = protocol
996                    .request(
997                        request,
998                        RequestType::Batched,
999                        config.request_timeout,
1000                        |_request, response| async move { Ok(response) },
1001                    )
1002                    .await?;
1003
1004                // Make sure the response is the hash of the request
1005                assert_eq!(response, request_hash);
1006
1007                Ok::<(), anyhow::Error>(())
1008            });
1009        }
1010
1011        // Wait for all the tasks to finish
1012        let mut num_succeeded = config.num_participants;
1013        while let Some(result) = join_set.join_next().await {
1014            if result.is_err() || result.unwrap().is_err() {
1015                num_succeeded -= 1;
1016            }
1017        }
1018
1019        IntegrationTestResult { num_succeeded }
1020    }
1021
1022    /// Test the integration of the protocol with 50% of the participants having the data
1023    #[tokio::test(flavor = "multi_thread")]
1024    async fn test_integration_50_0s() {
1025        // Build a config
1026        let config = IntegrationTestConfig {
1027            request_response_config: default_protocol_config(),
1028            num_participants: 100,
1029            num_participants_with_data: 50,
1030            request_timeout: Duration::from_secs(40),
1031            data_available_delay: Duration::from_secs(0),
1032        };
1033
1034        // Run the test, making sure all the requests succeed
1035        let result = run_integration_test(config).await;
1036        assert_eq!(result.num_succeeded, 100);
1037    }
1038
1039    /// Test the integration of the protocol when nobody has the data. Make sure we don't
1040    /// get any responses
1041    #[tokio::test(flavor = "multi_thread")]
1042    async fn test_integration_0() {
1043        // Build a config
1044        let config = IntegrationTestConfig {
1045            request_response_config: default_protocol_config(),
1046            num_participants: 100,
1047            num_participants_with_data: 0,
1048            request_timeout: Duration::from_secs(40),
1049            data_available_delay: Duration::from_secs(0),
1050        };
1051
1052        // Run the test
1053        let result = run_integration_test(config).await;
1054
1055        // Make sure all the requests succeeded
1056        assert_eq!(result.num_succeeded, 0);
1057    }
1058
1059    /// Test the integration of the protocol when one node has the data after
1060    /// a delay of 1s
1061    #[tokio::test(flavor = "multi_thread")]
1062    async fn test_integration_1_1s() {
1063        // Build a config
1064        let config = IntegrationTestConfig {
1065            request_response_config: default_protocol_config(),
1066            num_participants: 100,
1067            num_participants_with_data: 1,
1068            request_timeout: Duration::from_secs(40),
1069            data_available_delay: Duration::from_secs(2),
1070        };
1071
1072        // Run the test
1073        let result = run_integration_test(config).await;
1074
1075        // Make sure all the requests succeeded
1076        assert_eq!(result.num_succeeded, 100);
1077    }
1078
1079    /// Test that we can join an existing request for the same data and get the same (single) response
1080    #[tokio::test(flavor = "multi_thread")]
1081    async fn test_join_existing_request() {
1082        // Build a config
1083        let config = default_protocol_config();
1084
1085        // Create two participants
1086        let mut participants = Vec::new();
1087
1088        for (sender, receiver, (public_key, private_key)) in create_participants(2) {
1089            // For each, create a new [`RequestResponse`] protocol
1090            let protocol = RequestResponse::new(
1091                config.clone(),
1092                sender.clone(),
1093                receiver,
1094                sender,
1095                TestDataSource {
1096                    take_data: true,
1097                    has_data: true,
1098                    data_available_time: Instant::now() + Duration::from_secs(2),
1099                    taken: Arc::new(AtomicBool::new(false)),
1100                },
1101            );
1102
1103            // Add the participants to the list
1104            participants.push((protocol, public_key, private_key));
1105        }
1106
1107        // Take the first participant
1108        let one = Arc::new(participants.remove(0));
1109
1110        // Create the request that they should all be able to join on
1111        let request = TestRequest(vec![rand::thread_rng().r#gen(); 100]);
1112
1113        // Create a join set to wait for all the tasks to finish
1114        let mut join_set = JoinSet::new();
1115
1116        // Make 10 requests with the same hash
1117        for _ in 0..10 {
1118            // Clone the first participant
1119            let one_clone = Arc::clone(&one);
1120
1121            // Clone the request
1122            let request_clone = request.clone();
1123
1124            // Spawn a task to request the data
1125            join_set.spawn(async move {
1126                // Create a new, signed request message
1127                let request_message =
1128                    RequestMessage::new_signed(&one_clone.1, &one_clone.2, &request_clone)?;
1129
1130                // Start requesting it
1131                one_clone
1132                    .0
1133                    .request(
1134                        request_message,
1135                        RequestType::Batched,
1136                        Duration::from_secs(20),
1137                        |_request, response| async move { Ok(response) },
1138                    )
1139                    .await?;
1140
1141                Ok::<(), anyhow::Error>(())
1142            });
1143        }
1144
1145        // Wait for all the tasks to finish, making sure they all succeed
1146        while let Some(result) = join_set.join_next().await {
1147            result
1148                .expect("failed to join task")
1149                .expect("failed to request data");
1150        }
1151    }
1152}