request_response/lib.rs
1//! This crate contains a general request-response protocol. It is used to send requests to
2//! a set of recipients and wait for responses.
3
4use std::{
5 any::Any,
6 collections::HashMap,
7 future::Future,
8 marker::PhantomData,
9 pin::Pin,
10 sync::{Arc, Weak},
11 time::{Duration, Instant},
12};
13
14use anyhow::{Context, Result, anyhow};
15use async_lock::RwLock;
16use data_source::DataSource;
17use derive_more::derive::Deref;
18use hotshot_types::traits::signature_key::SignatureKey;
19use message::{Message, RequestMessage, ResponseMessage};
20use network::{Bytes, Receiver, Sender};
21use rand::seq::SliceRandom;
22use recipient_source::RecipientSource;
23use request::Request;
24use tokio::{
25 spawn,
26 time::{sleep, timeout},
27};
28use tokio_util::task::AbortOnDropHandle;
29use tracing::{debug, error, info, trace, warn};
30use util::{BoundedVecDeque, NamedSemaphore, NamedSemaphoreError};
31
32use crate::util::NamedSemaphorePermit;
33
34/// The data source trait. Is what we use to derive the response data for a request
35pub mod data_source;
36/// The message type. Is the base type for all messages in the request-response protocol
37pub mod message;
38/// The network traits. Is what we use to send and receive messages over the network as
39/// the protocol
40pub mod network;
41/// The recipient source trait. Is what we use to get the recipients that a specific message should
42/// expect responses from
43pub mod recipient_source;
44/// The request trait. Is what we use to define a request and a corresponding response type
45pub mod request;
46/// Utility types and functions
47mod util;
48
49/// A type alias for the hash of a request
50pub type RequestHash = blake3::Hash;
51
52/// A type alias for the outgoing requests map
53pub type OutgoingRequestsMap<Req> =
54 Arc<RwLock<HashMap<RequestHash, Weak<OutgoingRequestInner<Req>>>>>;
55
56/// A type alias for the list of tasks that are responding to requests
57pub type IncomingRequests<K> = NamedSemaphore<K>;
58
59/// A type alias for the list of tasks that are validating incoming responses
60pub type IncomingResponses = NamedSemaphore<RequestHash>;
61
/// How an outgoing request is distributed to the other participants
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum RequestType {
    /// A request that can be satisfied by a single participant,
    /// and as such will be batched to a few participants at a time
    /// until one succeeds
    Batched,
    /// A request that needs most or all participants to respond,
    /// and as such will be broadcasted to all participants
    Broadcast,
}
73
74/// The errors that can occur when making a request for data
75#[derive(thiserror::Error, Debug)]
76pub enum RequestError {
77 /// The request timed out
78 #[error("request timed out")]
79 Timeout,
80 /// The request was invalid
81 #[error("request was invalid")]
82 InvalidRequest(anyhow::Error),
83 /// Other errors
84 #[error("other error")]
85 Other(anyhow::Error),
86}
87
88/// A trait for serializing and deserializing a type to and from a byte array. [`Request`] types and
89/// [`Response`] types will need to implement this trait
90pub trait Serializable: Sized {
91 /// Serialize the type to a byte array. If this is for a [`Request`] and your [`Request`] type
92 /// is represented as an enum, please make sure that you serialize it with a unique type ID. Otherwise,
93 /// you may end up with collisions as the request hash is used as a unique identifier
94 ///
95 /// # Errors
96 /// - If the type cannot be serialized to a byte array
97 fn to_bytes(&self) -> Result<Vec<u8>>;
98
99 /// Deserialize the type from a byte array
100 ///
101 /// # Errors
102 /// - If the byte array is not a valid representation of the type
103 fn from_bytes(bytes: &[u8]) -> Result<Self>;
104}
105
/// The underlying configuration for the request-response protocol
#[derive(Clone, Debug)]
pub struct RequestResponseConfig {
    /// The timeout for incoming requests. Do not respond to a request after this threshold
    /// has passed.
    pub incoming_request_ttl: Duration,
    /// The maximum amount of time we will spend trying to both derive a response for a request and
    /// send the response over the wire.
    pub incoming_request_timeout: Duration,
    /// The maximum amount of time we will spend trying to validate a response. This is used to prevent
    /// an attack where a malicious participant sends us a bunch of responses that take a long time to
    /// validate.
    pub incoming_response_timeout: Duration,
    /// The batch size for outgoing requests. This is the number of request messages that we will
    /// send out at a time for a single request before waiting for the
    /// [`request_batch_interval`](Self::request_batch_interval).
    pub request_batch_size: usize,
    /// The time to wait (per request) between sending out batches of request messages
    pub request_batch_interval: Duration,
    /// The maximum (global) number of incoming requests that can be processed at any given time.
    pub max_incoming_requests: usize,
    /// The maximum number of incoming requests that can be processed for a single key at any given time.
    pub max_incoming_requests_per_key: usize,
    /// The maximum (global) number of incoming responses that can be processed at any given time.
    /// We need this because responses coming in need to be validated (asynchronously) to check
    /// that they satisfy the request they are responding to
    pub max_incoming_responses: usize,
}
133
134/// A protocol that allows for request-response communication. Is cheaply cloneable, so there is no
135/// need to wrap it in an `Arc`
136#[derive(Deref)]
137pub struct RequestResponse<
138 S: Sender<K>,
139 R: Receiver,
140 Req: Request,
141 RS: RecipientSource<Req, K>,
142 DS: DataSource<Req>,
143 K: SignatureKey + 'static,
144> {
145 #[deref]
146 /// The inner implementation of the request-response protocol
147 pub inner: Arc<RequestResponseInner<S, R, Req, RS, DS, K>>,
148 /// A handle to the receiving task. This will automatically get cancelled when the protocol is dropped
149 _receiving_task_handle: Arc<AbortOnDropHandle<()>>,
150}
151
152/// We need to manually implement the `Clone` trait for this type because deriving
153/// `Deref` will cause an issue where it tries to clone the inner field instead
154impl<
155 S: Sender<K>,
156 R: Receiver,
157 Req: Request,
158 RS: RecipientSource<Req, K>,
159 DS: DataSource<Req>,
160 K: SignatureKey + 'static,
161> Clone for RequestResponse<S, R, Req, RS, DS, K>
162{
163 fn clone(&self) -> Self {
164 Self {
165 inner: Arc::clone(&self.inner),
166 _receiving_task_handle: Arc::clone(&self._receiving_task_handle),
167 }
168 }
169}
170
171impl<
172 S: Sender<K>,
173 R: Receiver,
174 Req: Request,
175 RS: RecipientSource<Req, K>,
176 DS: DataSource<Req>,
177 K: SignatureKey + 'static,
178> RequestResponse<S, R, Req, RS, DS, K>
179{
180 /// Create a new [`RequestResponseProtocol`]
181 pub fn new(
182 // The configuration for the protocol
183 config: RequestResponseConfig,
184 // The network sender that [`RequestResponseProtocol`] will use to send messages
185 sender: S,
186 // The network receiver that [`RequestResponseProtocol`] will use to receive messages
187 receiver: R,
188 // The recipient source that [`RequestResponseProtocol`] will use to get the recipients
189 // that a specific message should expect responses from
190 recipient_source: RS,
191 // The [response] data source that [`RequestResponseProtocol`] will use to derive the
192 // response data for a specific request
193 data_source: DS,
194 ) -> Self {
195 // Create the outgoing requests map
196 let outgoing_requests = OutgoingRequestsMap::default();
197
198 // Create the inner implementation
199 let inner = Arc::new(RequestResponseInner {
200 config,
201 sender,
202 recipient_source,
203 data_source,
204 outgoing_requests,
205 phantom_data: PhantomData,
206 });
207
208 // Start the task that receives messages and handles them. This will automatically get cancelled
209 // when the protocol is dropped
210 let inner_clone = Arc::clone(&inner);
211 let receive_task_handle =
212 AbortOnDropHandle::new(tokio::spawn(inner_clone.receiving_task(receiver)));
213
214 // Return the protocol
215 Self {
216 inner,
217 _receiving_task_handle: Arc::new(receive_task_handle),
218 }
219 }
220}
221
222/// A type alias for an `Arc<dyn Any + Send + Sync + 'static>`
223type ThreadSafeAny = Arc<dyn Any + Send + Sync + 'static>;
224
225/// A type alias for the future that validates a response
226type ResponseValidationFuture =
227 Pin<Box<dyn Future<Output = Result<ThreadSafeAny, anyhow::Error>> + Send + Sync + 'static>>;
228
229/// A type alias for the function that returns the above future
230type ResponseValidationFn<R> =
231 Box<dyn Fn(&R, <R as Request>::Response) -> ResponseValidationFuture + Send + Sync + 'static>;
232
233/// The inner implementation for the request-response protocol
234pub struct RequestResponseInner<
235 S: Sender<K>,
236 R: Receiver,
237 Req: Request,
238 RS: RecipientSource<Req, K>,
239 DS: DataSource<Req>,
240 K: SignatureKey + 'static,
241> {
242 /// The configuration of the protocol
243 config: RequestResponseConfig,
244 /// The sender to use for the protocol
245 pub sender: S,
246 /// The recipient source to use for the protocol
247 pub recipient_source: RS,
248 /// The data source to use for the protocol
249 data_source: DS,
250 /// The map of currently active, outgoing requests
251 outgoing_requests: OutgoingRequestsMap<Req>,
252 /// Phantom data to help with type inference
253 phantom_data: PhantomData<(K, R, Req, DS)>,
254}
255impl<
256 S: Sender<K>,
257 R: Receiver,
258 Req: Request,
259 RS: RecipientSource<Req, K>,
260 DS: DataSource<Req>,
261 K: SignatureKey + 'static,
262> RequestResponseInner<S, R, Req, RS, DS, K>
263{
264 /// Request something from the protocol indefinitely until we get a response
265 /// or there was a critical error (e.g. the request could not be signed)
266 ///
267 /// # Errors
268 /// - If the request was invalid
269 /// - If there was a critical error (e.g. the channel was closed)
270 pub async fn request_indefinitely<F, Fut, O>(
271 self: &Arc<Self>,
272 public_key: &K,
273 private_key: &K::PrivateKey,
274 // The type of request to make
275 request_type: RequestType,
276 // The estimated TTL of other participants. This is used to decide when to
277 // stop making requests and sign a new one
278 estimated_request_ttl: Duration,
279 // The request to make
280 request: Req,
281 // The response validation function
282 response_validation_fn: F,
283 ) -> std::result::Result<O, RequestError>
284 where
285 F: Fn(&Req, Req::Response) -> Fut + Send + Sync + 'static + Clone,
286 Fut: Future<Output = anyhow::Result<O>> + Send + Sync + 'static,
287 O: Send + Sync + 'static + Clone,
288 {
289 loop {
290 // Sign a request message
291 let request_message = RequestMessage::new_signed(public_key, private_key, &request)
292 .map_err(|e| {
293 RequestError::InvalidRequest(anyhow::anyhow!(
294 "failed to sign request message: {e}"
295 ))
296 })?;
297
298 // Request the data, handling the errors appropriately
299 match self
300 .request(
301 request_message,
302 request_type,
303 estimated_request_ttl,
304 response_validation_fn.clone(),
305 )
306 .await
307 {
308 Ok(response) => return Ok(response),
309 Err(RequestError::Timeout) => continue,
310 Err(e) => return Err(e),
311 }
312 }
313 }
314
315 /// Request something from the protocol and wait for the response. This function
316 /// will join with an existing request for the same data (determined by `Blake3` hash),
317 /// however both will make requests until the timeout is reached
318 ///
319 /// # Errors
320 /// - If the request times out
321 /// - If the channel is closed (this is an internal error)
322 /// - If the request we sign is invalid
323 pub async fn request<F, Fut, O>(
324 self: &Arc<Self>,
325 request_message: RequestMessage<Req, K>,
326 request_type: RequestType,
327 timeout_duration: Duration,
328 response_validation_fn: F,
329 ) -> std::result::Result<O, RequestError>
330 where
331 F: Fn(&Req, Req::Response) -> Fut + Send + Sync + 'static + Clone,
332 Fut: Future<Output = anyhow::Result<O>> + Send + Sync + 'static,
333 O: Send + Sync + 'static + Clone,
334 {
335 timeout(timeout_duration, async move {
336 // Calculate the hash of the request
337 let request_hash = blake3::hash(&request_message.request.to_bytes().map_err(|e| {
338 RequestError::InvalidRequest(anyhow::anyhow!(
339 "failed to serialize request message: {e}"
340 ))
341 })?);
342
343 let request = {
344 // Get a write lock on the outgoing requests map
345 let mut outgoing_requests_write = self.outgoing_requests.write().await;
346
347 // Conditionally get the outgoing request, creating a new one if it doesn't exist or if
348 // the existing one has been dropped and not yet removed
349 if let Some(outgoing_request) = outgoing_requests_write
350 .get(&request_hash)
351 .and_then(Weak::upgrade)
352 {
353 OutgoingRequest(outgoing_request)
354 } else {
355 // Create a new broadcast channel for the response
356 let (sender, receiver) = async_broadcast::broadcast(1);
357
358 // Modify the response validation function to return an `Arc<dyn Any>`
359 let response_validation_fn =
360 Box::new(move |request: &Req, response: Req::Response| {
361 let fut = response_validation_fn(request, response);
362 Box::pin(
363 async move { fut.await.map(|ok| Arc::new(ok) as ThreadSafeAny) },
364 ) as ResponseValidationFuture
365 });
366
367 // Create a new outgoing request
368 let outgoing_request = OutgoingRequest(Arc::new(OutgoingRequestInner {
369 sender,
370 receiver,
371 response_validation_fn,
372 request: request_message.request.clone(),
373 }));
374
375 // Write the new outgoing request to the map
376 outgoing_requests_write
377 .insert(request_hash, Arc::downgrade(&outgoing_request.0));
378
379 // Return the new outgoing request
380 outgoing_request
381 }
382 };
383
384 // Create a request message and serialize it
385 let message = Bytes::from(
386 Message::Request(request_message.clone())
387 .to_bytes()
388 .map_err(|e| {
389 RequestError::InvalidRequest(anyhow::anyhow!(
390 "failed to serialize request message: {e}"
391 ))
392 })?,
393 );
394
395 // Create a place to put the handle for the batched sending task. We need this because
396 // otherwise it gets dropped when the closure goes out of scope, instead of when the function
397 // gets cancelled or returns
398 let mut _batched_sending_task = None;
399
400 // Match on the type of request
401 if request_type == RequestType::Broadcast {
402 trace!("Sending request {request_message:?} to all participants");
403
404 // If the message is a broadcast request, just send it to all participants
405 self.sender
406 .send_broadcast_message(&message)
407 .await
408 .map_err(|e| {
409 RequestError::Other(anyhow::anyhow!(
410 "failed to send broadcast message: {e}"
411 ))
412 })?;
413 } else {
414 // If the message is a batched request, we need to batch it with other requests
415
416 // Get the recipients that the request should expect responses from. Shuffle them so
417 // that we don't always send to the same recipients in the same order
418 let mut recipients = self
419 .recipient_source
420 .get_expected_responders(&request_message.request)
421 .await
422 .map_err(|e| {
423 RequestError::InvalidRequest(anyhow::anyhow!(
424 "failed to get expected responders for request: {e}"
425 ))
426 })?;
427 recipients.shuffle(&mut rand::thread_rng());
428
429 // Get the current time so we can check when the timeout has elapsed
430 let start_time = Instant::now();
431
432 // Spawn a task that sends out requests to the network
433 let self_clone = Arc::clone(self);
434 let batched_sending_handle = AbortOnDropHandle::new(spawn(async move {
435 // Create a bounded queue for the outgoing requests. We use this to make sure
436 // we have less than [`config.request_batch_size`] requests in flight at any time.
437 //
438 // When newer requests are added, older ones are removed from the queue. Because we use
439 // `AbortOnDropHandle`, the older ones will automatically get cancelled
440 let mut outgoing_requests =
441 BoundedVecDeque::new(self_clone.config.request_batch_size);
442
443 // While the timeout hasn't elapsed, send out requests to the network
444 while start_time.elapsed() < timeout_duration {
445 // Send out requests to the network in their own separate tasks
446 for recipient_batch in
447 recipients.chunks(self_clone.config.request_batch_size)
448 {
449 for recipient in recipient_batch {
450 // Clone ourselves, the message, and the recipient so they can be moved
451 let self_clone = Arc::clone(&self_clone);
452 let request_message_clone = request_message.clone();
453 let recipient_clone = recipient.clone();
454 let message_clone = Arc::clone(&message);
455
456 // Spawn the task that sends the request to the participant
457 let individual_sending_task = spawn(async move {
458 trace!(
459 "Sending request {request_message_clone:?} to \
460 {recipient_clone:?}"
461 );
462
463 let _ = self_clone
464 .sender
465 .send_direct_message(&message_clone, recipient_clone)
466 .await;
467 });
468
469 // Add the sending task to the queue
470 outgoing_requests
471 .push(AbortOnDropHandle::new(individual_sending_task));
472 }
473
474 // After we send the batch out, wait the [`config.request_batch_interval`]
475 // before sending the next one
476 sleep(self_clone.config.request_batch_interval).await;
477 }
478 }
479 }));
480
481 // Store the handle so it doesn't get dropped
482 _batched_sending_task = Some(batched_sending_handle);
483 }
484
485 // Wait for a response on the channel
486 request
487 .receiver
488 .clone()
489 .recv()
490 .await
491 .map_err(|_| RequestError::Other(anyhow!("channel was closed")))
492 })
493 .await
494 .map_err(|_| RequestError::Timeout)
495 .and_then(|result| result)
496 .and_then(|result| {
497 result.downcast::<O>().map_err(|e| {
498 RequestError::Other(anyhow::anyhow!(
499 "failed to downcast response to expected type: {e:?}"
500 ))
501 })
502 })
503 .map(|result| Arc::unwrap_or_clone(result))
504 }
505
506 /// The task responsible for receiving messages from the receiver and handling them
507 async fn receiving_task(self: Arc<Self>, mut receiver: R) {
508 // Upper bound the number of outgoing and incoming responses
509 let mut incoming_requests = NamedSemaphore::new(
510 self.config.max_incoming_requests_per_key,
511 Some(self.config.max_incoming_requests),
512 );
513 // process 1 response per key, with a global maximum of [`config.max_incoming_responses`] responses
514 let mut incoming_responses =
515 NamedSemaphore::new(1, Some(self.config.max_incoming_responses));
516
517 // While the receiver is open, we receive messages and handle them
518 loop {
519 // Try to receive a message
520 match receiver.receive_message().await {
521 Ok(message) => {
522 // Deserialize the message, warning if it fails
523 let message = match Message::from_bytes(&message) {
524 Ok(message) => message,
525 Err(e) => {
526 warn!("Received invalid message: {e:#}");
527 continue;
528 },
529 };
530
531 // Handle the message based on its type
532 match message {
533 Message::Request(request_message) => {
534 self.handle_request(request_message, &mut incoming_requests);
535 },
536 Message::Response(response_message) => {
537 self.handle_response(response_message, &mut incoming_responses);
538 },
539 }
540 },
541 // An error here means the receiver will _NEVER_ receive any more messages
542 Err(e) => {
543 error!("Request/response receive task exited: {e:#}");
544 return;
545 },
546 }
547 }
548 }
549
550 /// Handle a request sent to us
551 fn handle_request(
552 self: &Arc<Self>,
553 request_message: RequestMessage<Req, K>,
554 incoming_requests: &mut IncomingRequests<K>,
555 ) {
556 trace!("Handling request {:?}", request_message);
557
558 // Spawn a task to:
559 // - Validate the request
560 // - Derive the response data (check if we have it)
561 // - Send the response to the requester
562 let self_clone = Arc::clone(self);
563
564 // Attempt to acquire a permit for the request. Warn if there are too many requests currently being processed
565 // either globally or per-key
566 let permit = incoming_requests.try_acquire(request_message.public_key.clone());
567 match permit {
568 Ok(ref permit) => permit,
569 Err(NamedSemaphoreError::PerKeyLimitReached) => {
570 info!(
571 "Failed to process request from {}: too many requests from the same key are \
572 already being processed",
573 request_message.public_key
574 );
575 return;
576 },
577 Err(NamedSemaphoreError::GlobalLimitReached) => {
578 info!(
579 "Failed to process request from {}: too many requests are already being \
580 processed",
581 request_message.public_key
582 );
583 return;
584 },
585 };
586
587 tokio::spawn(async move {
588 let result = timeout(self_clone.config.incoming_request_timeout, async move {
589 // Validate the request message. This includes:
590 // - Checking the signature and making sure it's valid
591 // - Checking the timestamp and making sure it's not too old
592 // - Calling the request's application-specific validation function
593 request_message
594 .validate(self_clone.config.incoming_request_ttl)
595 .with_context(|| "failed to validate request")?;
596
597 // Try to fetch the response data from the data source
598 let response = self_clone
599 .data_source
600 .derive_response_for(&request_message.request)
601 .await
602 .with_context(|| "failed to derive response for request")?;
603
604 // Create the response message and serialize it
605 let response = Bytes::from(
606 Message::Response::<Req, K>(ResponseMessage {
607 request_hash: blake3::hash(&request_message.request.to_bytes()?),
608 response,
609 })
610 .to_bytes()
611 .with_context(|| "failed to serialize response message")?,
612 );
613
614 // Send the response to the requester
615 self_clone
616 .sender
617 .send_direct_message(&response, request_message.public_key)
618 .await
619 .with_context(|| "failed to send response to requester")?;
620
621 // Drop the permit
622 _ = permit;
623 drop(permit);
624
625 Ok::<(), anyhow::Error>(())
626 })
627 .await
628 .map_err(|_| anyhow::anyhow!("timed out while sending response"))
629 .and_then(|result| result);
630
631 if let Err(e) = result {
632 debug!("Failed to send response to requester: {e:#}");
633 }
634 });
635 }
636
637 async fn wait_for_permit(
638 incoming_responses: &mut IncomingResponses,
639 request_hash: RequestHash,
640 outgoing_requests: &OutgoingRequestsMap<Req>,
641 ) -> Result<NamedSemaphorePermit<RequestHash>, NamedSemaphoreError> {
642 while outgoing_requests.read().await.contains_key(&request_hash) {
643 let permit = incoming_responses.try_acquire(request_hash);
644 let permit = match permit {
645 Ok(permit) => permit,
646 Err(NamedSemaphoreError::PerKeyLimitReached) => {
647 tokio::time::sleep(Duration::from_millis(100)).await;
648 continue;
649 },
650 Err(NamedSemaphoreError::GlobalLimitReached) => {
651 warn!(
652 "Failed to process response: too many responses are already being \
653 processed"
654 );
655 return Err(NamedSemaphoreError::GlobalLimitReached);
656 },
657 };
658 return Ok(permit);
659 }
660 Err(NamedSemaphoreError::GlobalLimitReached)
661 }
662
663 /// Handle a response sent to us
664 fn handle_response(
665 self: &Arc<Self>,
666 response: ResponseMessage<Req>,
667 incoming_responses: &mut IncomingResponses,
668 ) {
669 trace!("Handling response {response:?}");
670
671 let outgoing_requests_clone = self.outgoing_requests.clone();
672 let mut incoming_responses_clone = incoming_responses.clone();
673
674 // Spawn a task to validate the response and send it to the requester (us)
675 let response_validate_timeout = self.config.incoming_response_timeout;
676 tokio::spawn(async move {
677 if timeout(response_validate_timeout, async move {
678 // Attempt to acquire a permit for the request. Warn if there are too many responses currently being processed
679 let permit = Self::wait_for_permit(
680 &mut incoming_responses_clone,
681 response.request_hash,
682 &outgoing_requests_clone,
683 )
684 .await;
685
686 let Ok(permit) = permit else {
687 warn!(
688 "Failed to process response: too many responses are already being \
689 processed"
690 );
691 return;
692 };
693 // Get the entry in the map, ignoring it if it doesn't exist
694 let Some(outgoing_request) = outgoing_requests_clone
695 .read()
696 .await
697 .get(&response.request_hash)
698 .cloned()
699 .and_then(|r| r.upgrade())
700 else {
701 return;
702 };
703 // Make sure the response is valid for the given request
704 let validation_result = match (outgoing_request.response_validation_fn)(
705 &outgoing_request.request,
706 response.response,
707 )
708 .await
709 {
710 Ok(validation_result) => validation_result,
711 Err(e) => {
712 debug!("Received invalid response: {e:#}");
713 return;
714 },
715 };
716
717 // Send the response to the requester (the user of [`RequestResponse::request`])
718 let _ = outgoing_request.sender.try_broadcast(validation_result);
719
720 outgoing_requests_clone
721 .write()
722 .await
723 .remove(&response.request_hash);
724
725 // Drop the permit
726 _ = permit;
727 drop(permit);
728 })
729 .await
730 .is_err()
731 {
732 warn!("Timed out while validating response");
733 }
734 });
735 }
736}
737
738/// An outgoing request. This is what we use to track a request and its corresponding response
739/// in the protocol
740#[derive(Clone, Deref)]
741pub struct OutgoingRequest<R: Request>(Arc<OutgoingRequestInner<R>>);
742
743/// The inner implementation of an outgoing request
744pub struct OutgoingRequestInner<R: Request> {
745 /// The sender to use for the protocol
746 sender: async_broadcast::Sender<ThreadSafeAny>,
747 /// The receiver to use for the protocol
748 receiver: async_broadcast::Receiver<ThreadSafeAny>,
749
750 /// The request that we are waiting for a response to
751 request: R,
752
753 /// The function used to validate the response
754 response_validation_fn: ResponseValidationFn<R>,
755}
756
757#[cfg(test)]
758mod tests {
759 use std::{
760 collections::HashMap,
761 sync::{Mutex, atomic::AtomicBool},
762 };
763
764 use async_trait::async_trait;
765 use hotshot_types::signature_key::{BLSPrivKey, BLSPubKey};
766 use rand::Rng;
767 use tokio::{sync::mpsc, task::JoinSet};
768
769 use super::*;
770
771 /// A test sender that has a list of all the participants in the network
772 #[derive(Clone)]
773 pub struct TestSender {
774 network: Arc<HashMap<BLSPubKey, mpsc::Sender<Bytes>>>,
775 }
776
777 /// An implementation of the [`Sender`] trait for the [`TestSender`] type
778 #[async_trait]
779 impl Sender<BLSPubKey> for TestSender {
780 async fn send_direct_message(&self, message: &Bytes, recipient: BLSPubKey) -> Result<()> {
781 self.network
782 .get(&recipient)
783 .ok_or(anyhow::anyhow!("recipient not found"))?
784 .send(Arc::clone(message))
785 .await
786 .map_err(|_| anyhow::anyhow!("failed to send message"))?;
787
788 Ok(())
789 }
790
791 async fn send_broadcast_message(&self, message: &Bytes) -> Result<()> {
792 for sender in self.network.values() {
793 sender
794 .send(Arc::clone(message))
795 .await
796 .map_err(|_| anyhow::anyhow!("failed to send message"))?;
797 }
798 Ok(())
799 }
800 }
801
802 // Implement the [`RecipientSource`] trait for the [`TestSender`] type
803 #[async_trait]
804 impl RecipientSource<TestRequest, BLSPubKey> for TestSender {
805 async fn get_expected_responders(&self, _request: &TestRequest) -> Result<Vec<BLSPubKey>> {
806 // Get all the participants in the network
807 Ok(self.network.keys().copied().collect())
808 }
809 }
810
811 // Create a test request that is just some bytes
812 #[derive(Clone, Debug)]
813 struct TestRequest(Vec<u8>);
814
815 // Implement the [`Serializable`] trait for the [`TestRequest`] type
816 impl Serializable for TestRequest {
817 fn to_bytes(&self) -> Result<Vec<u8>> {
818 Ok(self.0.clone())
819 }
820
821 fn from_bytes(bytes: &[u8]) -> Result<Self> {
822 Ok(TestRequest(bytes.to_vec()))
823 }
824 }
825
826 // Implement the [`Request`] trait for the [`TestRequest`] type
827 impl Request for TestRequest {
828 type Response = Vec<u8>;
829 fn validate(&self) -> Result<()> {
830 Ok(())
831 }
832 }
833
834 // Create a test data source that pretends to have the data or not
835 #[derive(Clone)]
836 struct TestDataSource {
837 /// Whether we have the data or not
838 has_data: bool,
839 /// The time at which the data will be available if we have it
840 data_available_time: Instant,
841
842 /// Whether or not the data will be taken once served
843 take_data: bool,
844 /// Whether or not the data has been taken
845 taken: Arc<AtomicBool>,
846 }
847
848 #[async_trait]
849 impl DataSource<TestRequest> for TestDataSource {
850 async fn derive_response_for(&self, request: &TestRequest) -> Result<Vec<u8>> {
851 // Return a response if we hit the hit rate
852 if self.has_data && Instant::now() >= self.data_available_time {
853 if self.take_data && !self.taken.swap(true, std::sync::atomic::Ordering::Relaxed) {
854 return Err(anyhow::anyhow!("data already taken"));
855 }
856 Ok(blake3::hash(&request.0).as_bytes().to_vec())
857 } else {
858 Err(anyhow::anyhow!("did not have the data"))
859 }
860 }
861 }
862
863 /// Create and return a default protocol configuration
864 fn default_protocol_config() -> RequestResponseConfig {
865 RequestResponseConfig {
866 incoming_request_ttl: Duration::from_secs(40),
867 incoming_request_timeout: Duration::from_secs(40),
868 request_batch_size: 10,
869 request_batch_interval: Duration::from_millis(100),
870 max_incoming_requests: 10,
871 max_incoming_requests_per_key: 1,
872 incoming_response_timeout: Duration::from_secs(1),
873 max_incoming_responses: 5,
874 }
875 }
876
877 /// Create fully connected test networks with `num_participants` participants
878 fn create_participants(
879 num: usize,
880 ) -> Vec<(TestSender, mpsc::Receiver<Bytes>, (BLSPubKey, BLSPrivKey))> {
881 // The entire network
882 let mut network = HashMap::new();
883
884 // All receivers in the network
885 let mut receivers = Vec::new();
886
887 // All keypairs in the network
888 let mut keypairs = Vec::new();
889
890 // For each participant,
891 for i in 0..num {
892 // Create a unique `BLSPubKey`
893 let (public_key, private_key) =
894 BLSPubKey::generated_from_seed_indexed([2; 32], i.try_into().unwrap());
895
896 // Add the keypair to the list
897 keypairs.push((public_key, private_key));
898
899 // Create a channel for sending and receiving messages
900 let (sender, receiver) = mpsc::channel::<Bytes>(100);
901
902 // Add the participant to the network
903 network.insert(public_key, sender);
904
905 // Add the receiver to the list of receivers
906 receivers.push(receiver);
907 }
908
909 // Create a test sender from the network
910 let sender = TestSender {
911 network: Arc::new(network),
912 };
913
914 // Return all senders and receivers
915 receivers
916 .into_iter()
917 .zip(keypairs)
918 .map(|(r, k)| (sender.clone(), r, k))
919 .collect()
920 }
921
922 /// The configuration for an integration test
923 #[derive(Clone)]
924 struct IntegrationTestConfig {
925 /// The request response protocol configuration
926 request_response_config: RequestResponseConfig,
927 /// The number of participants in the network
928 num_participants: usize,
929 /// The number of participants that have the data
930 num_participants_with_data: usize,
931 /// The timeout for the requests
932 request_timeout: Duration,
933 /// The delay before the nodes have the data available
934 data_available_delay: Duration,
935 }
936
/// The outcome of a single integration test run
struct IntegrationTestResult {
    /// How many nodes received a response
    num_succeeded: usize,
}
942
943 /// Run an integration test with the given parameters
944 async fn run_integration_test(config: IntegrationTestConfig) -> IntegrationTestResult {
945 // Create a fully connected network with `num_participants` participants
946 let participants = create_participants(config.num_participants);
947
948 // Create a join set to wait for all the tasks to finish
949 let mut join_set = JoinSet::new();
950
951 // We need to keep these here so they don't get dropped
952 let handles = Arc::new(Mutex::new(Vec::new()));
953
954 // For each one, create a new [`RequestResponse`] protocol
955 for (i, (sender, receiver, (public_key, private_key))) in
956 participants.into_iter().enumerate()
957 {
958 let config_clone = config.request_response_config.clone();
959 let handles_clone = Arc::clone(&handles);
960 join_set.spawn(async move {
961 let protocol = RequestResponse::new(
962 config_clone,
963 sender.clone(),
964 receiver,
965 sender,
966 TestDataSource {
967 has_data: i < config.num_participants_with_data,
968 data_available_time: Instant::now() + config.data_available_delay,
969 take_data: false,
970 taken: Arc::new(AtomicBool::new(false)),
971 },
972 );
973
974 // Add the handle to the handles list so it doesn't get dropped and
975 // cancelled
976 #[allow(clippy::used_underscore_binding)]
977 handles_clone
978 .lock()
979 .unwrap()
980 .push(Arc::clone(&protocol._receiving_task_handle));
981
982 // Create a random request
983 let request = TestRequest(vec![rand::thread_rng().r#gen(); 100]);
984
985 // Get the hash of the request
986 let request_hash = blake3::hash(&request.0).as_bytes().to_vec();
987
988 // Create a new request message
989 let request = RequestMessage::new_signed(&public_key, &private_key, &request)
990 .expect("failed to create request message");
991
992 // Request the data from the protocol
993 let response = protocol
994 .request(
995 request,
996 RequestType::Batched,
997 config.request_timeout,
998 |_request, response| async move { Ok(response) },
999 )
1000 .await?;
1001
1002 // Make sure the response is the hash of the request
1003 assert_eq!(response, request_hash);
1004
1005 Ok::<(), anyhow::Error>(())
1006 });
1007 }
1008
1009 // Wait for all the tasks to finish
1010 let mut num_succeeded = config.num_participants;
1011 while let Some(result) = join_set.join_next().await {
1012 if result.is_err() || result.unwrap().is_err() {
1013 num_succeeded -= 1;
1014 }
1015 }
1016
1017 IntegrationTestResult { num_succeeded }
1018 }
1019
1020 /// Test the integration of the protocol with 50% of the participants having the data
1021 #[tokio::test(flavor = "multi_thread")]
1022 async fn test_integration_50_0s() {
1023 // Build a config
1024 let config = IntegrationTestConfig {
1025 request_response_config: default_protocol_config(),
1026 num_participants: 100,
1027 num_participants_with_data: 50,
1028 request_timeout: Duration::from_secs(40),
1029 data_available_delay: Duration::from_secs(0),
1030 };
1031
1032 // Run the test, making sure all the requests succeed
1033 let result = run_integration_test(config).await;
1034 assert_eq!(result.num_succeeded, 100);
1035 }
1036
1037 /// Test the integration of the protocol when nobody has the data. Make sure we don't
1038 /// get any responses
1039 #[tokio::test(flavor = "multi_thread")]
1040 async fn test_integration_0() {
1041 // Build a config
1042 let config = IntegrationTestConfig {
1043 request_response_config: default_protocol_config(),
1044 num_participants: 100,
1045 num_participants_with_data: 0,
1046 request_timeout: Duration::from_secs(40),
1047 data_available_delay: Duration::from_secs(0),
1048 };
1049
1050 // Run the test
1051 let result = run_integration_test(config).await;
1052
1053 // Make sure all the requests succeeded
1054 assert_eq!(result.num_succeeded, 0);
1055 }
1056
1057 /// Test the integration of the protocol when one node has the data after
1058 /// a delay of 1s
1059 #[tokio::test(flavor = "multi_thread")]
1060 async fn test_integration_1_1s() {
1061 // Build a config
1062 let config = IntegrationTestConfig {
1063 request_response_config: default_protocol_config(),
1064 num_participants: 100,
1065 num_participants_with_data: 1,
1066 request_timeout: Duration::from_secs(40),
1067 data_available_delay: Duration::from_secs(2),
1068 };
1069
1070 // Run the test
1071 let result = run_integration_test(config).await;
1072
1073 // Make sure all the requests succeeded
1074 assert_eq!(result.num_succeeded, 100);
1075 }
1076
1077 /// Test that we can join an existing request for the same data and get the same (single) response
1078 #[tokio::test(flavor = "multi_thread")]
1079 async fn test_join_existing_request() {
1080 // Build a config
1081 let config = default_protocol_config();
1082
1083 // Create two participants
1084 let mut participants = Vec::new();
1085
1086 for (sender, receiver, (public_key, private_key)) in create_participants(2) {
1087 // For each, create a new [`RequestResponse`] protocol
1088 let protocol = RequestResponse::new(
1089 config.clone(),
1090 sender.clone(),
1091 receiver,
1092 sender,
1093 TestDataSource {
1094 take_data: true,
1095 has_data: true,
1096 data_available_time: Instant::now() + Duration::from_secs(2),
1097 taken: Arc::new(AtomicBool::new(false)),
1098 },
1099 );
1100
1101 // Add the participants to the list
1102 participants.push((protocol, public_key, private_key));
1103 }
1104
1105 // Take the first participant
1106 let one = Arc::new(participants.remove(0));
1107
1108 // Create the request that they should all be able to join on
1109 let request = TestRequest(vec![rand::thread_rng().r#gen(); 100]);
1110
1111 // Create a join set to wait for all the tasks to finish
1112 let mut join_set = JoinSet::new();
1113
1114 // Make 10 requests with the same hash
1115 for _ in 0..10 {
1116 // Clone the first participant
1117 let one_clone = Arc::clone(&one);
1118
1119 // Clone the request
1120 let request_clone = request.clone();
1121
1122 // Spawn a task to request the data
1123 join_set.spawn(async move {
1124 // Create a new, signed request message
1125 let request_message =
1126 RequestMessage::new_signed(&one_clone.1, &one_clone.2, &request_clone)?;
1127
1128 // Start requesting it
1129 one_clone
1130 .0
1131 .request(
1132 request_message,
1133 RequestType::Batched,
1134 Duration::from_secs(20),
1135 |_request, response| async move { Ok(response) },
1136 )
1137 .await?;
1138
1139 Ok::<(), anyhow::Error>(())
1140 });
1141 }
1142
1143 // Wait for all the tasks to finish, making sure they all succeed
1144 while let Some(result) = join_set.join_next().await {
1145 result
1146 .expect("failed to join task")
1147 .expect("failed to request data");
1148 }
1149 }
1150}