Skip to main content

request_response/
lib.rs

1//! This crate contains a general request-response protocol. It is used to send requests to
2//! a set of recipients and wait for responses.
3
4use std::{
5    any::Any,
6    collections::HashMap,
7    future::Future,
8    marker::PhantomData,
9    pin::Pin,
10    sync::{Arc, Weak},
11    time::{Duration, Instant},
12};
13
14use anyhow::{Context, Result, anyhow};
15use async_lock::RwLock;
16use data_source::DataSource;
17use derive_more::derive::Deref;
18use hotshot_types::traits::signature_key::SignatureKey;
19use message::{Message, RequestMessage, ResponseMessage};
20use network::{Bytes, Receiver, Sender};
21use rand::seq::SliceRandom;
22use recipient_source::RecipientSource;
23use request::Request;
24use tokio::{
25    spawn,
26    time::{sleep, timeout},
27};
28use tokio_util::task::AbortOnDropHandle;
29use tracing::{debug, error, info, trace, warn};
30use util::{BoundedVecDeque, NamedSemaphore, NamedSemaphoreError};
31
32use crate::util::NamedSemaphorePermit;
33
34/// The data source trait. Is what we use to derive the response data for a request
35pub mod data_source;
36/// The message type. Is the base type for all messages in the request-response protocol
37pub mod message;
38/// The network traits. Is what we use to send and receive messages over the network as
39/// the protocol
40pub mod network;
41/// The recipient source trait. Is what we use to get the recipients that a specific message should
42/// expect responses from
43pub mod recipient_source;
44/// The request trait. Is what we use to define a request and a corresponding response type
45pub mod request;
46/// Utility types and functions
47mod util;
48
49/// A type alias for the hash of a request
50pub type RequestHash = blake3::Hash;
51
52/// A type alias for the outgoing requests map
53pub type OutgoingRequestsMap<Req> =
54    Arc<RwLock<HashMap<RequestHash, Weak<OutgoingRequestInner<Req>>>>>;
55
56/// A type alias for the list of tasks that are responding to requests
57pub type IncomingRequests<K> = NamedSemaphore<K>;
58
59/// A type alias for the list of tasks that are validating incoming responses
60pub type IncomingResponses = NamedSemaphore<RequestHash>;
61
/// The type of request to make
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum RequestType {
    /// A request that can be satisfied by a single participant,
    /// and as such will be batched to a few participants at a time
    /// until one succeeds
    Batched,
    /// A request that needs most or all participants to respond,
    /// and as such will be broadcasted to all participants
    Broadcast,
}
73
74/// The errors that can occur when making a request for data
75#[derive(thiserror::Error, Debug)]
76pub enum RequestError {
77    /// The request timed out
78    #[error("request timed out")]
79    Timeout,
80    /// The request was invalid
81    #[error("request was invalid")]
82    InvalidRequest(anyhow::Error),
83    /// Other errors
84    #[error("other error")]
85    Other(anyhow::Error),
86}
87
88/// A trait for serializing and deserializing a type to and from a byte array. [`Request`] types and
89/// [`Response`] types will need to implement this trait
90pub trait Serializable: Sized {
91    /// Serialize the type to a byte array. If this is for a [`Request`] and your [`Request`] type
92    /// is represented as an enum, please make sure that you serialize it with a unique type ID. Otherwise,
93    /// you may end up with collisions as the request hash is used as a unique identifier
94    ///
95    /// # Errors
96    /// - If the type cannot be serialized to a byte array
97    fn to_bytes(&self) -> Result<Vec<u8>>;
98
99    /// Deserialize the type from a byte array
100    ///
101    /// # Errors
102    /// - If the byte array is not a valid representation of the type
103    fn from_bytes(bytes: &[u8]) -> Result<Self>;
104}
105
/// The underlying configuration for the request-response protocol
#[derive(Clone, Debug)]
pub struct RequestResponseConfig {
    /// The maximum age of an incoming request. Do not respond to a request after this threshold
    /// has passed.
    pub incoming_request_ttl: Duration,
    /// The maximum amount of time we will spend trying to both derive a response for a request and
    /// send the response over the wire.
    pub incoming_request_timeout: Duration,
    /// The maximum amount of time we will spend trying to validate a response. This is used to
    /// prevent an attack where a malicious participant sends us a bunch of responses that take a
    /// long time to validate.
    pub incoming_response_timeout: Duration,
    /// The batch size for outgoing requests. This is the number of request messages that we will
    /// send out at a time for a single request before waiting for
    /// [`request_batch_interval`](Self::request_batch_interval).
    pub request_batch_size: usize,
    /// The time to wait (per request) between sending out batches of request messages
    pub request_batch_interval: Duration,
    /// The maximum (global) number of incoming requests that can be processed at any given time.
    pub max_incoming_requests: usize,
    /// The maximum number of incoming requests that can be processed for a single key at any given
    /// time.
    pub max_incoming_requests_per_key: usize,
    /// The maximum (global) number of incoming responses that can be processed at any given time.
    /// We need this because incoming responses are validated asynchronously to make sure they
    /// satisfy the request they are responding to
    pub max_incoming_responses: usize,
}
133
134/// A protocol that allows for request-response communication. Is cheaply cloneable, so there is no
135/// need to wrap it in an `Arc`
136#[derive(Deref)]
137pub struct RequestResponse<
138    S: Sender<K>,
139    R: Receiver,
140    Req: Request,
141    RS: RecipientSource<Req, K>,
142    DS: DataSource<Req>,
143    K: SignatureKey + 'static,
144> {
145    #[deref]
146    /// The inner implementation of the request-response protocol
147    pub inner: Arc<RequestResponseInner<S, R, Req, RS, DS, K>>,
148    /// A handle to the receiving task. This will automatically get cancelled when the protocol is dropped
149    _receiving_task_handle: Arc<AbortOnDropHandle<()>>,
150}
151
152/// We need to manually implement the `Clone` trait for this type because deriving
153/// `Deref` will cause an issue where it tries to clone the inner field instead
154impl<
155    S: Sender<K>,
156    R: Receiver,
157    Req: Request,
158    RS: RecipientSource<Req, K>,
159    DS: DataSource<Req>,
160    K: SignatureKey + 'static,
161> Clone for RequestResponse<S, R, Req, RS, DS, K>
162{
163    fn clone(&self) -> Self {
164        Self {
165            inner: Arc::clone(&self.inner),
166            _receiving_task_handle: Arc::clone(&self._receiving_task_handle),
167        }
168    }
169}
170
171impl<
172    S: Sender<K>,
173    R: Receiver,
174    Req: Request,
175    RS: RecipientSource<Req, K>,
176    DS: DataSource<Req>,
177    K: SignatureKey + 'static,
178> RequestResponse<S, R, Req, RS, DS, K>
179{
180    /// Create a new [`RequestResponseProtocol`]
181    pub fn new(
182        // The configuration for the protocol
183        config: RequestResponseConfig,
184        // The network sender that [`RequestResponseProtocol`] will use to send messages
185        sender: S,
186        // The network receiver that [`RequestResponseProtocol`] will use to receive messages
187        receiver: R,
188        // The recipient source that [`RequestResponseProtocol`] will use to get the recipients
189        // that a specific message should expect responses from
190        recipient_source: RS,
191        // The [response] data source that [`RequestResponseProtocol`] will use to derive the
192        // response data for a specific request
193        data_source: DS,
194    ) -> Self {
195        // Create the outgoing requests map
196        let outgoing_requests = OutgoingRequestsMap::default();
197
198        // Create the inner implementation
199        let inner = Arc::new(RequestResponseInner {
200            config,
201            sender,
202            recipient_source,
203            data_source,
204            outgoing_requests,
205            phantom_data: PhantomData,
206        });
207
208        // Start the task that receives messages and handles them. This will automatically get cancelled
209        // when the protocol is dropped
210        let inner_clone = Arc::clone(&inner);
211        let receive_task_handle =
212            AbortOnDropHandle::new(tokio::spawn(inner_clone.receiving_task(receiver)));
213
214        // Return the protocol
215        Self {
216            inner,
217            _receiving_task_handle: Arc::new(receive_task_handle),
218        }
219    }
220}
221
222/// A type alias for an `Arc<dyn Any + Send + Sync + 'static>`
223type ThreadSafeAny = Arc<dyn Any + Send + Sync + 'static>;
224
225/// A type alias for the future that validates a response
226type ResponseValidationFuture =
227    Pin<Box<dyn Future<Output = Result<ThreadSafeAny, anyhow::Error>> + Send + Sync + 'static>>;
228
229/// A type alias for the function that returns the above future
230type ResponseValidationFn<R> =
231    Box<dyn Fn(&R, <R as Request>::Response) -> ResponseValidationFuture + Send + Sync + 'static>;
232
233/// The inner implementation for the request-response protocol
234pub struct RequestResponseInner<
235    S: Sender<K>,
236    R: Receiver,
237    Req: Request,
238    RS: RecipientSource<Req, K>,
239    DS: DataSource<Req>,
240    K: SignatureKey + 'static,
241> {
242    /// The configuration of the protocol
243    config: RequestResponseConfig,
244    /// The sender to use for the protocol
245    pub sender: S,
246    /// The recipient source to use for the protocol
247    pub recipient_source: RS,
248    /// The data source to use for the protocol
249    data_source: DS,
250    /// The map of currently active, outgoing requests
251    outgoing_requests: OutgoingRequestsMap<Req>,
252    /// Phantom data to help with type inference
253    phantom_data: PhantomData<(K, R, Req, DS)>,
254}
255impl<
256    S: Sender<K>,
257    R: Receiver,
258    Req: Request,
259    RS: RecipientSource<Req, K>,
260    DS: DataSource<Req>,
261    K: SignatureKey + 'static,
262> RequestResponseInner<S, R, Req, RS, DS, K>
263{
264    /// Request something from the protocol indefinitely until we get a response
265    /// or there was a critical error (e.g. the request could not be signed)
266    ///
267    /// # Errors
268    /// - If the request was invalid
269    /// - If there was a critical error (e.g. the channel was closed)
270    pub async fn request_indefinitely<F, Fut, O>(
271        self: &Arc<Self>,
272        public_key: &K,
273        private_key: &K::PrivateKey,
274        // The type of request to make
275        request_type: RequestType,
276        // The estimated TTL of other participants. This is used to decide when to
277        // stop making requests and sign a new one
278        estimated_request_ttl: Duration,
279        // The request to make
280        request: Req,
281        // The response validation function
282        response_validation_fn: F,
283    ) -> std::result::Result<O, RequestError>
284    where
285        F: Fn(&Req, Req::Response) -> Fut + Send + Sync + 'static + Clone,
286        Fut: Future<Output = anyhow::Result<O>> + Send + Sync + 'static,
287        O: Send + Sync + 'static + Clone,
288    {
289        loop {
290            // Sign a request message
291            let request_message = RequestMessage::new_signed(public_key, private_key, &request)
292                .map_err(|e| {
293                    RequestError::InvalidRequest(anyhow::anyhow!(
294                        "failed to sign request message: {e}"
295                    ))
296                })?;
297
298            // Request the data, handling the errors appropriately
299            match self
300                .request(
301                    request_message,
302                    request_type,
303                    estimated_request_ttl,
304                    response_validation_fn.clone(),
305                )
306                .await
307            {
308                Ok(response) => return Ok(response),
309                Err(RequestError::Timeout) => continue,
310                Err(e) => return Err(e),
311            }
312        }
313    }
314
315    /// Request something from the protocol and wait for the response. This function
316    /// will join with an existing request for the same data (determined by `Blake3` hash),
317    /// however both will make requests until the timeout is reached
318    ///
319    /// # Errors
320    /// - If the request times out
321    /// - If the channel is closed (this is an internal error)
322    /// - If the request we sign is invalid
323    pub async fn request<F, Fut, O>(
324        self: &Arc<Self>,
325        request_message: RequestMessage<Req, K>,
326        request_type: RequestType,
327        timeout_duration: Duration,
328        response_validation_fn: F,
329    ) -> std::result::Result<O, RequestError>
330    where
331        F: Fn(&Req, Req::Response) -> Fut + Send + Sync + 'static + Clone,
332        Fut: Future<Output = anyhow::Result<O>> + Send + Sync + 'static,
333        O: Send + Sync + 'static + Clone,
334    {
335        timeout(timeout_duration, async move {
336            // Calculate the hash of the request
337            let request_hash = blake3::hash(&request_message.request.to_bytes().map_err(|e| {
338                RequestError::InvalidRequest(anyhow::anyhow!(
339                    "failed to serialize request message: {e}"
340                ))
341            })?);
342
343            let request = {
344                // Get a write lock on the outgoing requests map
345                let mut outgoing_requests_write = self.outgoing_requests.write().await;
346
347                // Conditionally get the outgoing request, creating a new one if it doesn't exist or if
348                // the existing one has been dropped and not yet removed
349                if let Some(outgoing_request) = outgoing_requests_write
350                    .get(&request_hash)
351                    .and_then(Weak::upgrade)
352                {
353                    OutgoingRequest(outgoing_request)
354                } else {
355                    // Create a new broadcast channel for the response
356                    let (sender, receiver) = async_broadcast::broadcast(1);
357
358                    // Modify the response validation function to return an `Arc<dyn Any>`
359                    let response_validation_fn =
360                        Box::new(move |request: &Req, response: Req::Response| {
361                            let fut = response_validation_fn(request, response);
362                            Box::pin(
363                                async move { fut.await.map(|ok| Arc::new(ok) as ThreadSafeAny) },
364                            ) as ResponseValidationFuture
365                        });
366
367                    // Create a new outgoing request
368                    let outgoing_request = OutgoingRequest(Arc::new(OutgoingRequestInner {
369                        sender,
370                        receiver,
371                        response_validation_fn,
372                        request: request_message.request.clone(),
373                    }));
374
375                    // Write the new outgoing request to the map
376                    outgoing_requests_write
377                        .insert(request_hash, Arc::downgrade(&outgoing_request.0));
378
379                    // Return the new outgoing request
380                    outgoing_request
381                }
382            };
383
384            // Create a request message and serialize it
385            let message = Bytes::from(
386                Message::Request(request_message.clone())
387                    .to_bytes()
388                    .map_err(|e| {
389                        RequestError::InvalidRequest(anyhow::anyhow!(
390                            "failed to serialize request message: {e}"
391                        ))
392                    })?,
393            );
394
395            // Create a place to put the handle for the batched sending task. We need this because
396            // otherwise it gets dropped when the closure goes out of scope, instead of when the function
397            // gets cancelled or returns
398            let mut _batched_sending_task = None;
399
400            // Match on the type of request
401            if request_type == RequestType::Broadcast {
402                trace!("Sending request {request_message:?} to all participants");
403
404                // If the message is a broadcast request, just send it to all participants
405                self.sender
406                    .send_broadcast_message(&message)
407                    .await
408                    .map_err(|e| {
409                        RequestError::Other(anyhow::anyhow!(
410                            "failed to send broadcast message: {e}"
411                        ))
412                    })?;
413            } else {
414                // If the message is a batched request, we need to batch it with other requests
415
416                // Get the recipients that the request should expect responses from. Shuffle them so
417                // that we don't always send to the same recipients in the same order
418                let mut recipients = self
419                    .recipient_source
420                    .get_expected_responders(&request_message.request)
421                    .await
422                    .map_err(|e| {
423                        RequestError::InvalidRequest(anyhow::anyhow!(
424                            "failed to get expected responders for request: {e}"
425                        ))
426                    })?;
427                recipients.shuffle(&mut rand::thread_rng());
428
429                // Get the current time so we can check when the timeout has elapsed
430                let start_time = Instant::now();
431
432                // Spawn a task that sends out requests to the network
433                let self_clone = Arc::clone(self);
434                let batched_sending_handle = AbortOnDropHandle::new(spawn(async move {
435                    // Create a bounded queue for the outgoing requests. We use this to make sure
436                    // we have less than [`config.request_batch_size`] requests in flight at any time.
437                    //
438                    // When newer requests are added, older ones are removed from the queue. Because we use
439                    // `AbortOnDropHandle`, the older ones will automatically get cancelled
440                    let mut outgoing_requests =
441                        BoundedVecDeque::new(self_clone.config.request_batch_size);
442
443                    // While the timeout hasn't elapsed, send out requests to the network
444                    while start_time.elapsed() < timeout_duration {
445                        // Send out requests to the network in their own separate tasks
446                        for recipient_batch in
447                            recipients.chunks(self_clone.config.request_batch_size)
448                        {
449                            for recipient in recipient_batch {
450                                // Clone ourselves, the message, and the recipient so they can be moved
451                                let self_clone = Arc::clone(&self_clone);
452                                let request_message_clone = request_message.clone();
453                                let recipient_clone = recipient.clone();
454                                let message_clone = Arc::clone(&message);
455
456                                // Spawn the task that sends the request to the participant
457                                let individual_sending_task = spawn(async move {
458                                    trace!(
459                                        "Sending request {request_message_clone:?} to \
460                                         {recipient_clone:?}"
461                                    );
462
463                                    let _ = self_clone
464                                        .sender
465                                        .send_direct_message(&message_clone, recipient_clone)
466                                        .await;
467                                });
468
469                                // Add the sending task to the queue
470                                outgoing_requests
471                                    .push(AbortOnDropHandle::new(individual_sending_task));
472                            }
473
474                            // After we send the batch out, wait the [`config.request_batch_interval`]
475                            // before sending the next one
476                            sleep(self_clone.config.request_batch_interval).await;
477                        }
478                    }
479                }));
480
481                // Store the handle so it doesn't get dropped
482                _batched_sending_task = Some(batched_sending_handle);
483            }
484
485            // Wait for a response on the channel
486            request
487                .receiver
488                .clone()
489                .recv()
490                .await
491                .map_err(|_| RequestError::Other(anyhow!("channel was closed")))
492        })
493        .await
494        .map_err(|_| RequestError::Timeout)
495        .and_then(|result| result)
496        .and_then(|result| {
497            result.downcast::<O>().map_err(|e| {
498                RequestError::Other(anyhow::anyhow!(
499                    "failed to downcast response to expected type: {e:?}"
500                ))
501            })
502        })
503        .map(|result| Arc::unwrap_or_clone(result))
504    }
505
506    /// The task responsible for receiving messages from the receiver and handling them
507    async fn receiving_task(self: Arc<Self>, mut receiver: R) {
508        // Upper bound the number of outgoing and incoming responses
509        let mut incoming_requests = NamedSemaphore::new(
510            self.config.max_incoming_requests_per_key,
511            Some(self.config.max_incoming_requests),
512        );
513        // process 1 response per key, with a global maximum of [`config.max_incoming_responses`] responses
514        let mut incoming_responses =
515            NamedSemaphore::new(1, Some(self.config.max_incoming_responses));
516
517        // While the receiver is open, we receive messages and handle them
518        loop {
519            // Try to receive a message
520            match receiver.receive_message().await {
521                Ok(message) => {
522                    // Deserialize the message, warning if it fails
523                    let message = match Message::from_bytes(&message) {
524                        Ok(message) => message,
525                        Err(e) => {
526                            warn!("Received invalid message: {e:#}");
527                            continue;
528                        },
529                    };
530
531                    // Handle the message based on its type
532                    match message {
533                        Message::Request(request_message) => {
534                            self.handle_request(request_message, &mut incoming_requests);
535                        },
536                        Message::Response(response_message) => {
537                            self.handle_response(response_message, &mut incoming_responses);
538                        },
539                    }
540                },
541                // An error here means the receiver will _NEVER_ receive any more messages
542                Err(e) => {
543                    error!("Request/response receive task exited: {e:#}");
544                    return;
545                },
546            }
547        }
548    }
549
550    /// Handle a request sent to us
551    fn handle_request(
552        self: &Arc<Self>,
553        request_message: RequestMessage<Req, K>,
554        incoming_requests: &mut IncomingRequests<K>,
555    ) {
556        trace!("Handling request {:?}", request_message);
557
558        // Spawn a task to:
559        // - Validate the request
560        // - Derive the response data (check if we have it)
561        // - Send the response to the requester
562        let self_clone = Arc::clone(self);
563
564        // Attempt to acquire a permit for the request. Warn if there are too many requests currently being processed
565        // either globally or per-key
566        let permit = incoming_requests.try_acquire(request_message.public_key.clone());
567        match permit {
568            Ok(ref permit) => permit,
569            Err(NamedSemaphoreError::PerKeyLimitReached) => {
570                info!(
571                    "Failed to process request from {}: too many requests from the same key are \
572                     already being processed",
573                    request_message.public_key
574                );
575                return;
576            },
577            Err(NamedSemaphoreError::GlobalLimitReached) => {
578                info!(
579                    "Failed to process request from {}: too many requests are already being \
580                     processed",
581                    request_message.public_key
582                );
583                return;
584            },
585        };
586
587        tokio::spawn(async move {
588            let result = timeout(self_clone.config.incoming_request_timeout, async move {
589                // Validate the request message. This includes:
590                // - Checking the signature and making sure it's valid
591                // - Checking the timestamp and making sure it's not too old
592                // - Calling the request's application-specific validation function
593                request_message
594                    .validate(self_clone.config.incoming_request_ttl)
595                    .with_context(|| "failed to validate request")?;
596
597                // Try to fetch the response data from the data source
598                let response = self_clone
599                    .data_source
600                    .derive_response_for(&request_message.request)
601                    .await
602                    .with_context(|| "failed to derive response for request")?;
603
604                // Create the response message and serialize it
605                let response = Bytes::from(
606                    Message::Response::<Req, K>(ResponseMessage {
607                        request_hash: blake3::hash(&request_message.request.to_bytes()?),
608                        response,
609                    })
610                    .to_bytes()
611                    .with_context(|| "failed to serialize response message")?,
612                );
613
614                // Send the response to the requester
615                self_clone
616                    .sender
617                    .send_direct_message(&response, request_message.public_key)
618                    .await
619                    .with_context(|| "failed to send response to requester")?;
620
621                // Drop the permit
622                _ = permit;
623                drop(permit);
624
625                Ok::<(), anyhow::Error>(())
626            })
627            .await
628            .map_err(|_| anyhow::anyhow!("timed out while sending response"))
629            .and_then(|result| result);
630
631            if let Err(e) = result {
632                debug!("Failed to send response to requester: {e:#}");
633            }
634        });
635    }
636
637    async fn wait_for_permit(
638        incoming_responses: &mut IncomingResponses,
639        request_hash: RequestHash,
640        outgoing_requests: &OutgoingRequestsMap<Req>,
641    ) -> Result<NamedSemaphorePermit<RequestHash>, NamedSemaphoreError> {
642        while outgoing_requests.read().await.contains_key(&request_hash) {
643            let permit = incoming_responses.try_acquire(request_hash);
644            let permit = match permit {
645                Ok(permit) => permit,
646                Err(NamedSemaphoreError::PerKeyLimitReached) => {
647                    tokio::time::sleep(Duration::from_millis(100)).await;
648                    continue;
649                },
650                Err(NamedSemaphoreError::GlobalLimitReached) => {
651                    warn!(
652                        "Failed to process response: too many responses are already being \
653                         processed"
654                    );
655                    return Err(NamedSemaphoreError::GlobalLimitReached);
656                },
657            };
658            return Ok(permit);
659        }
660        Err(NamedSemaphoreError::GlobalLimitReached)
661    }
662
663    /// Handle a response sent to us
664    fn handle_response(
665        self: &Arc<Self>,
666        response: ResponseMessage<Req>,
667        incoming_responses: &mut IncomingResponses,
668    ) {
669        trace!("Handling response {response:?}");
670
671        let outgoing_requests_clone = self.outgoing_requests.clone();
672        let mut incoming_responses_clone = incoming_responses.clone();
673
674        // Spawn a task to validate the response and send it to the requester (us)
675        let response_validate_timeout = self.config.incoming_response_timeout;
676        tokio::spawn(async move {
677            if timeout(response_validate_timeout, async move {
678                // Attempt to acquire a permit for the request. Warn if there are too many responses currently being processed
679                let permit = Self::wait_for_permit(
680                    &mut incoming_responses_clone,
681                    response.request_hash,
682                    &outgoing_requests_clone,
683                )
684                .await;
685
686                let Ok(permit) = permit else {
687                    warn!(
688                        "Failed to process response: too many responses are already being \
689                         processed"
690                    );
691                    return;
692                };
693                // Get the entry in the map, ignoring it if it doesn't exist
694                let Some(outgoing_request) = outgoing_requests_clone
695                    .read()
696                    .await
697                    .get(&response.request_hash)
698                    .cloned()
699                    .and_then(|r| r.upgrade())
700                else {
701                    return;
702                };
703                // Make sure the response is valid for the given request
704                let validation_result = match (outgoing_request.response_validation_fn)(
705                    &outgoing_request.request,
706                    response.response,
707                )
708                .await
709                {
710                    Ok(validation_result) => validation_result,
711                    Err(e) => {
712                        debug!("Received invalid response: {e:#}");
713                        return;
714                    },
715                };
716
717                // Send the response to the requester (the user of [`RequestResponse::request`])
718                let _ = outgoing_request.sender.try_broadcast(validation_result);
719
720                outgoing_requests_clone
721                    .write()
722                    .await
723                    .remove(&response.request_hash);
724
725                // Drop the permit
726                _ = permit;
727                drop(permit);
728            })
729            .await
730            .is_err()
731            {
732                warn!("Timed out while validating response");
733            }
734        });
735    }
736}
737
/// An outgoing request. This is what we use to track a request and its corresponding response
/// in the protocol.
///
/// This is a cheaply clonable handle: every clone shares the same [`OutgoingRequestInner`]
/// through the `Arc`, and the derived `Deref` exposes it directly. The outgoing requests map
/// only stores [`Weak`] references to the inner state (see [`OutgoingRequestsMap`]), so an
/// incoming response can only be matched to this request while at least one handle is alive.
#[derive(Clone, Deref)]
pub struct OutgoingRequest<R: Request>(Arc<OutgoingRequestInner<R>>);
742
/// The inner implementation of an outgoing request, shared by all [`OutgoingRequest`] handles
pub struct OutgoingRequestInner<R: Request> {
    /// Broadcast sender on which a validated response is published (via `try_broadcast`) so
    /// that everyone waiting on this request receives it
    sender: async_broadcast::Sender<ThreadSafeAny>,
    /// Receiving half of the same broadcast channel.
    /// NOTE(review): presumably cloned by each caller waiting on this request; confirm against
    /// `RequestResponse::request`
    receiver: async_broadcast::Receiver<ThreadSafeAny>,

    /// The request that we are waiting for a response to; handed to `response_validation_fn`
    /// alongside each candidate response
    request: R,

    /// The function used to validate a raw response against `request`. Its `Ok` output is what
    /// gets broadcast to waiters; an `Err` causes the response to be discarded
    response_validation_fn: ResponseValidationFn<R>,
}
756
757#[cfg(test)]
758mod tests {
759    use std::{
760        collections::HashMap,
761        sync::{Mutex, atomic::AtomicBool},
762    };
763
764    use async_trait::async_trait;
765    use hotshot_types::signature_key::{BLSPrivKey, BLSPubKey};
766    use rand::Rng;
767    use tokio::{sync::mpsc, task::JoinSet};
768
769    use super::*;
770
    /// A test sender that knows every participant in the (fully connected) test network.
    ///
    /// Cloning is cheap: all clones share the same participant map through the `Arc`.
    #[derive(Clone)]
    pub struct TestSender {
        /// Maps each participant's public key to the channel that delivers messages to it
        network: Arc<HashMap<BLSPubKey, mpsc::Sender<Bytes>>>,
    }
776
777    /// An implementation of the [`Sender`] trait for the [`TestSender`] type
778    #[async_trait]
779    impl Sender<BLSPubKey> for TestSender {
780        async fn send_direct_message(&self, message: &Bytes, recipient: BLSPubKey) -> Result<()> {
781            self.network
782                .get(&recipient)
783                .ok_or(anyhow::anyhow!("recipient not found"))?
784                .send(Arc::clone(message))
785                .await
786                .map_err(|_| anyhow::anyhow!("failed to send message"))?;
787
788            Ok(())
789        }
790
791        async fn send_broadcast_message(&self, message: &Bytes) -> Result<()> {
792            for sender in self.network.values() {
793                sender
794                    .send(Arc::clone(message))
795                    .await
796                    .map_err(|_| anyhow::anyhow!("failed to send message"))?;
797            }
798            Ok(())
799        }
800    }
801
802    // Implement the [`RecipientSource`] trait for the [`TestSender`] type
803    #[async_trait]
804    impl RecipientSource<TestRequest, BLSPubKey> for TestSender {
805        async fn get_expected_responders(&self, _request: &TestRequest) -> Result<Vec<BLSPubKey>> {
806            // Get all the participants in the network
807            Ok(self.network.keys().copied().collect())
808        }
809    }
810
    /// A test request that is just an opaque byte payload.
    ///
    /// Test responders answer it with the blake3 hash of these bytes.
    #[derive(Clone, Debug)]
    struct TestRequest(Vec<u8>);
814
815    // Implement the [`Serializable`] trait for the [`TestRequest`] type
816    impl Serializable for TestRequest {
817        fn to_bytes(&self) -> Result<Vec<u8>> {
818            Ok(self.0.clone())
819        }
820
821        fn from_bytes(bytes: &[u8]) -> Result<Self> {
822            Ok(TestRequest(bytes.to_vec()))
823        }
824    }
825
826    // Implement the [`Request`] trait for the [`TestRequest`] type
827    impl Request for TestRequest {
828        type Response = Vec<u8>;
829        fn validate(&self) -> Result<()> {
830            Ok(())
831        }
832    }
833
    /// A test data source that pretends to have the data or not.
    ///
    /// Clones share the same `taken` flag because it lives behind an `Arc`.
    #[derive(Clone)]
    struct TestDataSource {
        /// Whether this source has the data at all
        has_data: bool,
        /// The earliest time at which the data can be served, if `has_data` is set
        data_available_time: Instant,

        /// Whether or not the data will be taken once served
        take_data: bool,
        /// Whether or not the data has been taken (shared between clones via the `Arc`)
        taken: Arc<AtomicBool>,
    }
847
848    #[async_trait]
849    impl DataSource<TestRequest> for TestDataSource {
850        async fn derive_response_for(&self, request: &TestRequest) -> Result<Vec<u8>> {
851            // Return a response if we hit the hit rate
852            if self.has_data && Instant::now() >= self.data_available_time {
853                if self.take_data && !self.taken.swap(true, std::sync::atomic::Ordering::Relaxed) {
854                    return Err(anyhow::anyhow!("data already taken"));
855                }
856                Ok(blake3::hash(&request.0).as_bytes().to_vec())
857            } else {
858                Err(anyhow::anyhow!("did not have the data"))
859            }
860        }
861    }
862
863    /// Create and return a default protocol configuration
864    fn default_protocol_config() -> RequestResponseConfig {
865        RequestResponseConfig {
866            incoming_request_ttl: Duration::from_secs(40),
867            incoming_request_timeout: Duration::from_secs(40),
868            request_batch_size: 10,
869            request_batch_interval: Duration::from_millis(100),
870            max_incoming_requests: 10,
871            max_incoming_requests_per_key: 1,
872            incoming_response_timeout: Duration::from_secs(1),
873            max_incoming_responses: 5,
874        }
875    }
876
877    /// Create fully connected test networks with `num_participants` participants
878    fn create_participants(
879        num: usize,
880    ) -> Vec<(TestSender, mpsc::Receiver<Bytes>, (BLSPubKey, BLSPrivKey))> {
881        // The entire network
882        let mut network = HashMap::new();
883
884        // All receivers in the network
885        let mut receivers = Vec::new();
886
887        // All keypairs in the network
888        let mut keypairs = Vec::new();
889
890        // For each participant,
891        for i in 0..num {
892            // Create a unique `BLSPubKey`
893            let (public_key, private_key) =
894                BLSPubKey::generated_from_seed_indexed([2; 32], i.try_into().unwrap());
895
896            // Add the keypair to the list
897            keypairs.push((public_key, private_key));
898
899            // Create a channel for sending and receiving messages
900            let (sender, receiver) = mpsc::channel::<Bytes>(100);
901
902            // Add the participant to the network
903            network.insert(public_key, sender);
904
905            // Add the receiver to the list of receivers
906            receivers.push(receiver);
907        }
908
909        // Create a test sender from the network
910        let sender = TestSender {
911            network: Arc::new(network),
912        };
913
914        // Return all senders and receivers
915        receivers
916            .into_iter()
917            .zip(keypairs)
918            .map(|(r, k)| (sender.clone(), r, k))
919            .collect()
920    }
921
    /// The configuration for an integration test (see `run_integration_test`)
    #[derive(Clone)]
    struct IntegrationTestConfig {
        /// The request response protocol configuration used by every participant
        request_response_config: RequestResponseConfig,
        /// The number of participants in the network
        num_participants: usize,
        /// The number of participants that have the data; these are the first
        /// `num_participants_with_data` participants by creation index
        num_participants_with_data: usize,
        /// The timeout passed to each participant's request
        request_timeout: Duration,
        /// The delay, after each participant's protocol is created, before its data is available
        data_available_delay: Duration,
    }
936
    /// The result of an integration test
    struct IntegrationTestResult {
        /// The number of participants whose request completed without error and returned the
        /// expected response (the blake3 hash of the request bytes)
        num_succeeded: usize,
    }
942
943    /// Run an integration test with the given parameters
944    async fn run_integration_test(config: IntegrationTestConfig) -> IntegrationTestResult {
945        // Create a fully connected network with `num_participants` participants
946        let participants = create_participants(config.num_participants);
947
948        // Create a join set to wait for all the tasks to finish
949        let mut join_set = JoinSet::new();
950
951        // We need to keep these here so they don't get dropped
952        let handles = Arc::new(Mutex::new(Vec::new()));
953
954        // For each one, create a new [`RequestResponse`] protocol
955        for (i, (sender, receiver, (public_key, private_key))) in
956            participants.into_iter().enumerate()
957        {
958            let config_clone = config.request_response_config.clone();
959            let handles_clone = Arc::clone(&handles);
960            join_set.spawn(async move {
961                let protocol = RequestResponse::new(
962                    config_clone,
963                    sender.clone(),
964                    receiver,
965                    sender,
966                    TestDataSource {
967                        has_data: i < config.num_participants_with_data,
968                        data_available_time: Instant::now() + config.data_available_delay,
969                        take_data: false,
970                        taken: Arc::new(AtomicBool::new(false)),
971                    },
972                );
973
974                // Add the handle to the handles list so it doesn't get dropped and
975                // cancelled
976                #[allow(clippy::used_underscore_binding)]
977                handles_clone
978                    .lock()
979                    .unwrap()
980                    .push(Arc::clone(&protocol._receiving_task_handle));
981
982                // Create a random request
983                let request = TestRequest(vec![rand::thread_rng().r#gen(); 100]);
984
985                // Get the hash of the request
986                let request_hash = blake3::hash(&request.0).as_bytes().to_vec();
987
988                // Create a new request message
989                let request = RequestMessage::new_signed(&public_key, &private_key, &request)
990                    .expect("failed to create request message");
991
992                // Request the data from the protocol
993                let response = protocol
994                    .request(
995                        request,
996                        RequestType::Batched,
997                        config.request_timeout,
998                        |_request, response| async move { Ok(response) },
999                    )
1000                    .await?;
1001
1002                // Make sure the response is the hash of the request
1003                assert_eq!(response, request_hash);
1004
1005                Ok::<(), anyhow::Error>(())
1006            });
1007        }
1008
1009        // Wait for all the tasks to finish
1010        let mut num_succeeded = config.num_participants;
1011        while let Some(result) = join_set.join_next().await {
1012            if result.is_err() || result.unwrap().is_err() {
1013                num_succeeded -= 1;
1014            }
1015        }
1016
1017        IntegrationTestResult { num_succeeded }
1018    }
1019
1020    /// Test the integration of the protocol with 50% of the participants having the data
1021    #[tokio::test(flavor = "multi_thread")]
1022    async fn test_integration_50_0s() {
1023        // Build a config
1024        let config = IntegrationTestConfig {
1025            request_response_config: default_protocol_config(),
1026            num_participants: 100,
1027            num_participants_with_data: 50,
1028            request_timeout: Duration::from_secs(40),
1029            data_available_delay: Duration::from_secs(0),
1030        };
1031
1032        // Run the test, making sure all the requests succeed
1033        let result = run_integration_test(config).await;
1034        assert_eq!(result.num_succeeded, 100);
1035    }
1036
1037    /// Test the integration of the protocol when nobody has the data. Make sure we don't
1038    /// get any responses
1039    #[tokio::test(flavor = "multi_thread")]
1040    async fn test_integration_0() {
1041        // Build a config
1042        let config = IntegrationTestConfig {
1043            request_response_config: default_protocol_config(),
1044            num_participants: 100,
1045            num_participants_with_data: 0,
1046            request_timeout: Duration::from_secs(40),
1047            data_available_delay: Duration::from_secs(0),
1048        };
1049
1050        // Run the test
1051        let result = run_integration_test(config).await;
1052
1053        // Make sure all the requests succeeded
1054        assert_eq!(result.num_succeeded, 0);
1055    }
1056
1057    /// Test the integration of the protocol when one node has the data after
1058    /// a delay of 1s
1059    #[tokio::test(flavor = "multi_thread")]
1060    async fn test_integration_1_1s() {
1061        // Build a config
1062        let config = IntegrationTestConfig {
1063            request_response_config: default_protocol_config(),
1064            num_participants: 100,
1065            num_participants_with_data: 1,
1066            request_timeout: Duration::from_secs(40),
1067            data_available_delay: Duration::from_secs(2),
1068        };
1069
1070        // Run the test
1071        let result = run_integration_test(config).await;
1072
1073        // Make sure all the requests succeeded
1074        assert_eq!(result.num_succeeded, 100);
1075    }
1076
1077    /// Test that we can join an existing request for the same data and get the same (single) response
1078    #[tokio::test(flavor = "multi_thread")]
1079    async fn test_join_existing_request() {
1080        // Build a config
1081        let config = default_protocol_config();
1082
1083        // Create two participants
1084        let mut participants = Vec::new();
1085
1086        for (sender, receiver, (public_key, private_key)) in create_participants(2) {
1087            // For each, create a new [`RequestResponse`] protocol
1088            let protocol = RequestResponse::new(
1089                config.clone(),
1090                sender.clone(),
1091                receiver,
1092                sender,
1093                TestDataSource {
1094                    take_data: true,
1095                    has_data: true,
1096                    data_available_time: Instant::now() + Duration::from_secs(2),
1097                    taken: Arc::new(AtomicBool::new(false)),
1098                },
1099            );
1100
1101            // Add the participants to the list
1102            participants.push((protocol, public_key, private_key));
1103        }
1104
1105        // Take the first participant
1106        let one = Arc::new(participants.remove(0));
1107
1108        // Create the request that they should all be able to join on
1109        let request = TestRequest(vec![rand::thread_rng().r#gen(); 100]);
1110
1111        // Create a join set to wait for all the tasks to finish
1112        let mut join_set = JoinSet::new();
1113
1114        // Make 10 requests with the same hash
1115        for _ in 0..10 {
1116            // Clone the first participant
1117            let one_clone = Arc::clone(&one);
1118
1119            // Clone the request
1120            let request_clone = request.clone();
1121
1122            // Spawn a task to request the data
1123            join_set.spawn(async move {
1124                // Create a new, signed request message
1125                let request_message =
1126                    RequestMessage::new_signed(&one_clone.1, &one_clone.2, &request_clone)?;
1127
1128                // Start requesting it
1129                one_clone
1130                    .0
1131                    .request(
1132                        request_message,
1133                        RequestType::Batched,
1134                        Duration::from_secs(20),
1135                        |_request, response| async move { Ok(response) },
1136                    )
1137                    .await?;
1138
1139                Ok::<(), anyhow::Error>(())
1140            });
1141        }
1142
1143        // Wait for all the tasks to finish, making sure they all succeed
1144        while let Some(result) = join_set.join_next().await {
1145            result
1146                .expect("failed to join task")
1147                .expect("failed to request data");
1148        }
1149    }
1150}