// request_response/lib.rs
//! This crate contains a general request-response protocol. It is used to send requests to
//! a set of recipients and wait for responses.

4use std::{
5 any::Any,
6 collections::HashMap,
7 future::Future,
8 marker::PhantomData,
9 pin::Pin,
10 sync::{Arc, Weak},
11 time::{Duration, Instant},
12};
13
14use anyhow::{Context, Result, anyhow};
15use async_lock::RwLock;
16use data_source::DataSource;
17use derive_more::derive::Deref;
18use hotshot_types::traits::signature_key::SignatureKey;
19use message::{Message, RequestMessage, ResponseMessage};
20use network::{Bytes, Receiver, Sender};
21use rand::seq::SliceRandom;
22use recipient_source::RecipientSource;
23use request::Request;
24use tokio::{
25 spawn,
26 time::{sleep, timeout},
27};
28use tokio_util::task::AbortOnDropHandle;
29use tracing::{debug, error, info, trace, warn};
30use util::{BoundedVecDeque, NamedSemaphore, NamedSemaphoreError};
31
32use crate::util::NamedSemaphorePermit;
33
/// The data source trait. Is what we use to derive the response data for a request
pub mod data_source;
/// The message type. Is the base type for all messages in the request-response protocol
pub mod message;
/// The network traits. Is what we use to send and receive messages over the network as
/// the protocol
pub mod network;
/// The recipient source trait. Is what we use to get the recipients that a specific message should
/// expect responses from
pub mod recipient_source;
/// The request trait. Is what we use to define a request and a corresponding response type
pub mod request;
/// Utility types and functions (bounded queues and named semaphores)
mod util;
48
/// A type alias for the hash of a request. The hash doubles as the unique
/// identifier used to join duplicate in-flight requests.
pub type RequestHash = blake3::Hash;

/// A type alias for the outgoing requests map. Values are `Weak` so an entry
/// does not keep a request alive after every waiter has dropped it.
pub type OutgoingRequestsMap<Req> =
    Arc<RwLock<HashMap<RequestHash, Weak<OutgoingRequestInner<Req>>>>>;

/// A type alias for the list of tasks that are responding to requests,
/// keyed by the requester's public key
pub type IncomingRequests<K> = NamedSemaphore<K>;

/// A type alias for the list of tasks that are validating incoming responses,
/// keyed by the hash of the request being answered
pub type IncomingResponses = NamedSemaphore<RequestHash>;
61
/// The type of request to make. Controls how request messages are fanned out
/// to the network (see [`RequestResponseInner::request`]).
#[derive(PartialEq, Eq, Clone, Copy)]
pub enum RequestType {
    /// A request that can be satisfied by a single participant,
    /// and as such will be batched to a few participants at a time
    /// until one succeeds
    Batched,
    /// A request that needs most or all participants to respond,
    /// and as such will be broadcasted to all participants
    Broadcast,
}
73
74/// The errors that can occur when making a request for data
75#[derive(thiserror::Error, Debug)]
76pub enum RequestError {
77 /// The request timed out
78 #[error("request timed out")]
79 Timeout,
80 /// The request was invalid
81 #[error("request was invalid")]
82 InvalidRequest(anyhow::Error),
83 /// Other errors
84 #[error("other error")]
85 Other(anyhow::Error),
86}
87
/// A trait for serializing and deserializing a type to and from a byte array. [`Request`] types and
/// [`Response`] types will need to implement this trait
pub trait Serializable: Sized {
    /// Serialize the type to a byte array. If this is for a [`Request`] and your [`Request`] type
    /// is represented as an enum, please make sure that you serialize it with a unique type ID. Otherwise,
    /// you may end up with collisions as the request hash is used as a unique identifier
    ///
    /// # Errors
    /// - If the type cannot be serialized to a byte array
    fn to_bytes(&self) -> Result<Vec<u8>>;

    /// Deserialize the type from a byte array
    ///
    /// # Errors
    /// - If the byte array is not a valid representation of the type
    fn from_bytes(bytes: &[u8]) -> Result<Self>;
}
105
/// The underlying configuration for the request-response protocol
#[derive(Clone)]
pub struct RequestResponseConfig {
    /// The timeout for incoming requests. Do not respond to a request after this threshold
    /// has passed (checked during request validation).
    pub incoming_request_ttl: Duration,
    /// The maximum amount of time we will spend trying to both derive a response for a request and
    /// send the response over the wire.
    pub incoming_request_timeout: Duration,
    /// The maximum amount of time we will spend trying to validate a response. This is used to prevent
    /// an attack where a malicious participant sends us a bunch of requests that take a long time to
    /// validate.
    pub incoming_response_timeout: Duration,
    /// The batch size for outgoing requests. This is the number of request messages that we will
    /// send out at a time for a single request before waiting for the [`request_batch_interval`].
    pub request_batch_size: usize,
    /// The time to wait (per request) between sending out batches of request messages
    pub request_batch_interval: Duration,
    /// The maximum (global) number of incoming requests that can be processed at any given time.
    pub max_incoming_requests: usize,
    /// The maximum number of incoming requests that can be processed for a single key at any given time.
    pub max_incoming_requests_per_key: usize,
    /// The maximum (global) number of incoming responses that can be processed at any given time.
    /// We need this because responses coming in need to be validated [asynchronously] that they
    /// satisfy the request they are responding to
    pub max_incoming_responses: usize,
}
133
/// A protocol that allows for request-response communication. Is cheaply cloneable, so there is no
/// need to wrap it in an `Arc`
#[derive(Deref)]
pub struct RequestResponse<
    S: Sender<K>,
    R: Receiver,
    Req: Request,
    RS: RecipientSource<Req, K>,
    DS: DataSource<Req>,
    K: SignatureKey + 'static,
> {
    #[deref]
    /// The inner implementation of the request-response protocol. `Deref`
    /// exposes its methods (e.g. `request`) directly on this type.
    pub inner: Arc<RequestResponseInner<S, R, Req, RS, DS, K>>,
    /// A handle to the receiving task. This will automatically get cancelled when the protocol
    /// (i.e. its last clone) is dropped
    _receiving_task_handle: Arc<AbortOnDropHandle<()>>,
}
151
152/// We need to manually implement the `Clone` trait for this type because deriving
153/// `Deref` will cause an issue where it tries to clone the inner field instead
154impl<
155 S: Sender<K>,
156 R: Receiver,
157 Req: Request,
158 RS: RecipientSource<Req, K>,
159 DS: DataSource<Req>,
160 K: SignatureKey + 'static,
161> Clone for RequestResponse<S, R, Req, RS, DS, K>
162{
163 fn clone(&self) -> Self {
164 Self {
165 inner: Arc::clone(&self.inner),
166 _receiving_task_handle: Arc::clone(&self._receiving_task_handle),
167 }
168 }
169}
170
171impl<
172 S: Sender<K>,
173 R: Receiver,
174 Req: Request,
175 RS: RecipientSource<Req, K>,
176 DS: DataSource<Req>,
177 K: SignatureKey + 'static,
178> RequestResponse<S, R, Req, RS, DS, K>
179{
180 /// Create a new [`RequestResponseProtocol`]
181 pub fn new(
182 // The configuration for the protocol
183 config: RequestResponseConfig,
184 // The network sender that [`RequestResponseProtocol`] will use to send messages
185 sender: S,
186 // The network receiver that [`RequestResponseProtocol`] will use to receive messages
187 receiver: R,
188 // The recipient source that [`RequestResponseProtocol`] will use to get the recipients
189 // that a specific message should expect responses from
190 recipient_source: RS,
191 // The [response] data source that [`RequestResponseProtocol`] will use to derive the
192 // response data for a specific request
193 data_source: DS,
194 ) -> Self {
195 // Create the outgoing requests map
196 let outgoing_requests = OutgoingRequestsMap::default();
197
198 // Create the inner implementation
199 let inner = Arc::new(RequestResponseInner {
200 config,
201 sender,
202 recipient_source,
203 data_source,
204 outgoing_requests,
205 phantom_data: PhantomData,
206 });
207
208 // Start the task that receives messages and handles them. This will automatically get cancelled
209 // when the protocol is dropped
210 let inner_clone = Arc::clone(&inner);
211 let receive_task_handle =
212 AbortOnDropHandle::new(tokio::spawn(inner_clone.receiving_task(receiver)));
213
214 // Return the protocol
215 Self {
216 inner,
217 _receiving_task_handle: Arc::new(receive_task_handle),
218 }
219 }
220}
221
/// A type alias for an `Arc<dyn Any + Send + Sync + 'static>`. Validated
/// responses are type-erased so requests with different output types can
/// share one map and broadcast channel.
type ThreadSafeAny = Arc<dyn Any + Send + Sync + 'static>;

/// A type alias for the future that validates a response
type ResponseValidationFuture =
    Pin<Box<dyn Future<Output = Result<ThreadSafeAny, anyhow::Error>> + Send + Sync + 'static>>;

/// A type alias for the function that returns the above future
type ResponseValidationFn<R> =
    Box<dyn Fn(&R, <R as Request>::Response) -> ResponseValidationFuture + Send + Sync + 'static>;
232
/// The inner implementation for the request-response protocol. Shared (via
/// `Arc`) between every clone of [`RequestResponse`] and its receiving task.
pub struct RequestResponseInner<
    S: Sender<K>,
    R: Receiver,
    Req: Request,
    RS: RecipientSource<Req, K>,
    DS: DataSource<Req>,
    K: SignatureKey + 'static,
> {
    /// The configuration of the protocol
    config: RequestResponseConfig,
    /// The sender to use for the protocol
    pub sender: S,
    /// The recipient source to use for the protocol
    pub recipient_source: RS,
    /// The data source to use for the protocol
    data_source: DS,
    /// The map of currently active, outgoing requests
    outgoing_requests: OutgoingRequestsMap<Req>,
    /// Phantom data to tie down the otherwise-unused generic parameters
    phantom_data: PhantomData<(K, R, Req, DS)>,
}
255impl<
256 S: Sender<K>,
257 R: Receiver,
258 Req: Request,
259 RS: RecipientSource<Req, K>,
260 DS: DataSource<Req>,
261 K: SignatureKey + 'static,
262> RequestResponseInner<S, R, Req, RS, DS, K>
263{
264 /// Request something from the protocol indefinitely until we get a response
265 /// or there was a critical error (e.g. the request could not be signed)
266 ///
267 /// # Errors
268 /// - If the request was invalid
269 /// - If there was a critical error (e.g. the channel was closed)
270 pub async fn request_indefinitely<F, Fut, O>(
271 self: &Arc<Self>,
272 public_key: &K,
273 private_key: &K::PrivateKey,
274 // The type of request to make
275 request_type: RequestType,
276 // The estimated TTL of other participants. This is used to decide when to
277 // stop making requests and sign a new one
278 estimated_request_ttl: Duration,
279 // The request to make
280 request: Req,
281 // The response validation function
282 response_validation_fn: F,
283 ) -> std::result::Result<O, RequestError>
284 where
285 F: Fn(&Req, Req::Response) -> Fut + Send + Sync + 'static + Clone,
286 Fut: Future<Output = anyhow::Result<O>> + Send + Sync + 'static,
287 O: Send + Sync + 'static + Clone,
288 {
289 loop {
290 // Sign a request message
291 let request_message = RequestMessage::new_signed(public_key, private_key, &request)
292 .map_err(|e| {
293 RequestError::InvalidRequest(anyhow::anyhow!(
294 "failed to sign request message: {e}"
295 ))
296 })?;
297
298 // Request the data, handling the errors appropriately
299 match self
300 .request(
301 request_message,
302 request_type,
303 estimated_request_ttl,
304 response_validation_fn.clone(),
305 )
306 .await
307 {
308 Ok(response) => return Ok(response),
309 Err(RequestError::Timeout) => continue,
310 Err(e) => return Err(e),
311 }
312 }
313 }
314
    /// Request something from the protocol and wait for the response. This function
    /// will join with an existing request for the same data (determined by `Blake3` hash),
    /// however both will make requests until the timeout is reached
    ///
    /// # Errors
    /// - If the request times out
    /// - If the channel is closed (this is an internal error)
    /// - If the request we sign is invalid
    pub async fn request<F, Fut, O>(
        self: &Arc<Self>,
        request_message: RequestMessage<Req, K>,
        request_type: RequestType,
        timeout_duration: Duration,
        response_validation_fn: F,
    ) -> std::result::Result<O, RequestError>
    where
        F: Fn(&Req, Req::Response) -> Fut + Send + Sync + 'static + Clone,
        Fut: Future<Output = anyhow::Result<O>> + Send + Sync + 'static,
        O: Send + Sync + 'static + Clone,
    {
        // Everything below runs under the caller-supplied timeout
        timeout(timeout_duration, async move {
            // Calculate the hash of the request. This is what lets concurrent
            // callers asking for the same data share one outgoing entry
            let request_hash = blake3::hash(&request_message.request.to_bytes().map_err(|e| {
                RequestError::InvalidRequest(anyhow::anyhow!(
                    "failed to serialize request message: {e}"
                ))
            })?);

            let request = {
                // Get a write lock on the outgoing requests map
                let mut outgoing_requests_write = self.outgoing_requests.write().await;

                // Conditionally get the outgoing request, creating a new one if it doesn't exist or if
                // the existing one has been dropped and not yet removed
                if let Some(outgoing_request) = outgoing_requests_write
                    .get(&request_hash)
                    .and_then(Weak::upgrade)
                {
                    OutgoingRequest(outgoing_request)
                } else {
                    // Create a new broadcast channel for the response (capacity 1:
                    // only the single validated response is ever broadcast)
                    let (sender, receiver) = async_broadcast::broadcast(1);

                    // Modify the response validation function to return an `Arc<dyn Any>`
                    // so differently-typed requests can share the same map
                    let response_validation_fn =
                        Box::new(move |request: &Req, response: Req::Response| {
                            let fut = response_validation_fn(request, response);
                            Box::pin(
                                async move { fut.await.map(|ok| Arc::new(ok) as ThreadSafeAny) },
                            ) as ResponseValidationFuture
                        });

                    // Create a new outgoing request
                    let outgoing_request = OutgoingRequest(Arc::new(OutgoingRequestInner {
                        sender,
                        receiver,
                        response_validation_fn,
                        request: request_message.request.clone(),
                    }));

                    // Write the new outgoing request to the map (weakly, so it is
                    // cleaned up once every waiter has dropped it)
                    outgoing_requests_write
                        .insert(request_hash, Arc::downgrade(&outgoing_request.0));

                    // Return the new outgoing request
                    outgoing_request
                }
            };

            // Create a request message and serialize it
            let message = Bytes::from(
                Message::Request(request_message.clone())
                    .to_bytes()
                    .map_err(|e| {
                        RequestError::InvalidRequest(anyhow::anyhow!(
                            "failed to serialize request message: {e}"
                        ))
                    })?,
            );

            // Create a place to put the handle for the batched sending task. We need this because
            // otherwise it gets dropped when the closure goes out of scope, instead of when the function
            // gets cancelled or returns
            let mut _batched_sending_task = None;

            // Match on the type of request
            if request_type == RequestType::Broadcast {
                trace!("Sending request {request_message:?} to all participants");

                // If the message is a broadcast request, just send it to all participants
                self.sender
                    .send_broadcast_message(&message)
                    .await
                    .map_err(|e| {
                        RequestError::Other(anyhow::anyhow!(
                            "failed to send broadcast message: {e}"
                        ))
                    })?;
            } else {
                // If the message is a batched request, we need to batch it with other requests

                // Get the recipients that the request should expect responses from. Shuffle them so
                // that we don't always send to the same recipients in the same order
                let mut recipients = self
                    .recipient_source
                    .get_expected_responders(&request_message.request)
                    .await
                    .map_err(|e| {
                        RequestError::InvalidRequest(anyhow::anyhow!(
                            "failed to get expected responders for request: {e}"
                        ))
                    })?;
                recipients.shuffle(&mut rand::thread_rng());

                // Get the current time so we can check when the timeout has elapsed
                let start_time = Instant::now();

                // Spawn a task that sends out requests to the network
                let self_clone = Arc::clone(self);
                let batched_sending_handle = AbortOnDropHandle::new(spawn(async move {
                    // Create a bounded queue for the outgoing requests. We use this to make sure
                    // we have less than [`config.request_batch_size`] requests in flight at any time.
                    //
                    // When newer requests are added, older ones are removed from the queue. Because we use
                    // `AbortOnDropHandle`, the older ones will automatically get cancelled
                    let mut outgoing_requests =
                        BoundedVecDeque::new(self_clone.config.request_batch_size);

                    // While the timeout hasn't elapsed, send out requests to the network
                    // (recipients are re-contacted on each pass until someone answers)
                    while start_time.elapsed() < timeout_duration {
                        // Send out requests to the network in their own separate tasks
                        for recipient_batch in
                            recipients.chunks(self_clone.config.request_batch_size)
                        {
                            for recipient in recipient_batch {
                                // Clone ourselves, the message, and the recipient so they can be moved
                                let self_clone = Arc::clone(&self_clone);
                                let request_message_clone = request_message.clone();
                                let recipient_clone = recipient.clone();
                                let message_clone = Arc::clone(&message);

                                // Spawn the task that sends the request to the participant
                                let individual_sending_task = spawn(async move {
                                    trace!(
                                        "Sending request {request_message_clone:?} to \
                                         {recipient_clone:?}"
                                    );

                                    // Best-effort send; failures are intentionally ignored
                                    let _ = self_clone
                                        .sender
                                        .send_direct_message(&message_clone, recipient_clone)
                                        .await;
                                });

                                // Add the sending task to the queue
                                outgoing_requests
                                    .push(AbortOnDropHandle::new(individual_sending_task));
                            }

                            // After we send the batch out, wait the [`config.request_batch_interval`]
                            // before sending the next one
                            sleep(self_clone.config.request_batch_interval).await;
                        }
                    }
                }));

                // Store the handle so it doesn't get dropped
                _batched_sending_task = Some(batched_sending_handle);
            }

            // Wait for a response on the channel
            request
                .receiver
                .clone()
                .recv()
                .await
                .map_err(|_| RequestError::Other(anyhow!("channel was closed")))
        })
        .await
        // Flatten the timeout result, then downcast the type-erased response
        // back to the caller's concrete output type
        .map_err(|_| RequestError::Timeout)
        .and_then(|result| result)
        .and_then(|result| {
            result.downcast::<O>().map_err(|e| {
                RequestError::Other(anyhow::anyhow!(
                    "failed to downcast response to expected type: {e:?}"
                ))
            })
        })
        .map(|result| Arc::unwrap_or_clone(result))
    }
505
506 /// The task responsible for receiving messages from the receiver and handling them
507 async fn receiving_task(self: Arc<Self>, mut receiver: R) {
508 // Upper bound the number of outgoing and incoming responses
509 let mut incoming_requests = NamedSemaphore::new(
510 self.config.max_incoming_requests_per_key,
511 Some(self.config.max_incoming_requests),
512 );
513 // process 1 response per key, with a global maximum of [`config.max_incoming_responses`] responses
514 let mut incoming_responses =
515 NamedSemaphore::new(1, Some(self.config.max_incoming_responses));
516
517 // While the receiver is open, we receive messages and handle them
518 loop {
519 // Try to receive a message
520 match receiver.receive_message().await {
521 Ok(message) => {
522 // Deserialize the message, warning if it fails
523 let message = match Message::from_bytes(&message) {
524 Ok(message) => message,
525 Err(e) => {
526 warn!("Received invalid message: {e:#}");
527 continue;
528 },
529 };
530
531 // Handle the message based on its type
532 match message {
533 Message::Request(request_message) => {
534 self.handle_request(request_message, &mut incoming_requests);
535 },
536 Message::Response(response_message) => {
537 self.handle_response(response_message, &mut incoming_responses);
538 },
539 }
540 },
541 // An error here means the receiver will _NEVER_ receive any more messages
542 Err(e) => {
543 error!("Request/response receive task exited: {e:#}");
544 return;
545 },
546 }
547 }
548 }
549
550 /// Handle a request sent to us
551 fn handle_request(
552 self: &Arc<Self>,
553 request_message: RequestMessage<Req, K>,
554 incoming_requests: &mut IncomingRequests<K>,
555 ) {
556 trace!("Handling request {:?}", request_message);
557
558 // Spawn a task to:
559 // - Validate the request
560 // - Derive the response data (check if we have it)
561 // - Send the response to the requester
562 let self_clone = Arc::clone(self);
563
564 // Attempt to acquire a permit for the request. Warn if there are too many requests currently being processed
565 // either globally or per-key
566 let permit = incoming_requests.try_acquire(request_message.public_key.clone());
567 match permit {
568 Ok(ref permit) => permit,
569 Err(NamedSemaphoreError::PerKeyLimitReached) => {
570 info!(
571 "Failed to process request from {}: too many requests from the same key are \
572 already being processed",
573 request_message.public_key
574 );
575 return;
576 },
577 Err(NamedSemaphoreError::GlobalLimitReached) => {
578 info!(
579 "Failed to process request from {}: too many requests are already being \
580 processed",
581 request_message.public_key
582 );
583 return;
584 },
585 };
586
587 tokio::spawn(async move {
588 let result = timeout(self_clone.config.incoming_request_timeout, async move {
589 // Validate the request message. This includes:
590 // - Checking the signature and making sure it's valid
591 // - Checking the timestamp and making sure it's not too old
592 // - Calling the request's application-specific validation function
593 request_message
594 .validate(self_clone.config.incoming_request_ttl)
595 .await
596 .with_context(|| "failed to validate request")?;
597
598 // Try to fetch the response data from the data source
599 let response = self_clone
600 .data_source
601 .derive_response_for(&request_message.request)
602 .await
603 .with_context(|| "failed to derive response for request")?;
604
605 // Create the response message and serialize it
606 let response = Bytes::from(
607 Message::Response::<Req, K>(ResponseMessage {
608 request_hash: blake3::hash(&request_message.request.to_bytes()?),
609 response,
610 })
611 .to_bytes()
612 .with_context(|| "failed to serialize response message")?,
613 );
614
615 // Send the response to the requester
616 self_clone
617 .sender
618 .send_direct_message(&response, request_message.public_key)
619 .await
620 .with_context(|| "failed to send response to requester")?;
621
622 // Drop the permit
623 _ = permit;
624 drop(permit);
625
626 Ok::<(), anyhow::Error>(())
627 })
628 .await
629 .map_err(|_| anyhow::anyhow!("timed out while sending response"))
630 .and_then(|result| result);
631
632 if let Err(e) = result {
633 debug!("Failed to send response to requester: {e:#}");
634 }
635 });
636 }
637
    /// Wait for a permit to process a response for `request_hash`.
    ///
    /// Polls [`NamedSemaphore::try_acquire`] every 100ms for as long as the
    /// request is still present in the outgoing-requests map (i.e. someone is
    /// still waiting on it) and another response for the same hash is being
    /// validated.
    ///
    /// # Errors
    /// - `GlobalLimitReached` if the global response-processing limit is hit,
    ///   or if the request is no longer active.
    ///   NOTE(review): the "no longer active" case reuses `GlobalLimitReached`
    ///   for lack of a dedicated variant — confirm callers do not need to
    ///   distinguish the two.
    async fn wait_for_permit(
        incoming_responses: &mut IncomingResponses,
        request_hash: RequestHash,
        outgoing_requests: &OutgoingRequestsMap<Req>,
    ) -> Result<NamedSemaphorePermit<RequestHash>, NamedSemaphoreError> {
        // Only keep trying while somebody is still waiting on this response
        while outgoing_requests.read().await.contains_key(&request_hash) {
            let permit = incoming_responses.try_acquire(request_hash);
            let permit = match permit {
                Ok(permit) => permit,
                // Another response for the same request is currently being
                // validated; back off briefly and re-check
                Err(NamedSemaphoreError::PerKeyLimitReached) => {
                    tokio::time::sleep(Duration::from_millis(100)).await;
                    continue;
                },
                Err(NamedSemaphoreError::GlobalLimitReached) => {
                    warn!(
                        "Failed to process response: too many responses are already being \
                         processed"
                    );
                    return Err(NamedSemaphoreError::GlobalLimitReached);
                },
            };
            return Ok(permit);
        }
        // The request has been answered or dropped; nobody needs this response
        Err(NamedSemaphoreError::GlobalLimitReached)
    }
663
    /// Handle a response sent to us
    ///
    /// Spawns a task (bounded by [`config.incoming_response_timeout`]) that
    /// acquires a validation permit, validates the response against the
    /// original request, broadcasts the validated result to every waiter, and
    /// removes the now-satisfied request from the outgoing map.
    fn handle_response(
        self: &Arc<Self>,
        response: ResponseMessage<Req>,
        incoming_responses: &mut IncomingResponses,
    ) {
        trace!("Handling response {response:?}");

        // Clone the shared state the spawned task needs to own
        let outgoing_requests_clone = self.outgoing_requests.clone();
        let mut incoming_responses_clone = incoming_responses.clone();

        // Spawn a task to validate the response and send it to the requester (us)
        let response_validate_timeout = self.config.incoming_response_timeout;
        tokio::spawn(async move {
            if timeout(response_validate_timeout, async move {
                // Attempt to acquire a permit for the request. Warn if there are too many
                // responses currently being processed
                let permit = Self::wait_for_permit(
                    &mut incoming_responses_clone,
                    response.request_hash,
                    &outgoing_requests_clone,
                )
                .await;

                let Ok(permit) = permit else {
                    warn!(
                        "Failed to process response: too many responses are already being \
                         processed"
                    );
                    return;
                };
                // Get the entry in the map, ignoring it if it doesn't exist
                // (nobody is waiting on this response anymore)
                let Some(outgoing_request) = outgoing_requests_clone
                    .read()
                    .await
                    .get(&response.request_hash)
                    .cloned()
                    .and_then(|r| r.upgrade())
                else {
                    return;
                };
                // Make sure the response is valid for the given request
                let validation_result = match (outgoing_request.response_validation_fn)(
                    &outgoing_request.request,
                    response.response,
                )
                .await
                {
                    Ok(validation_result) => validation_result,
                    Err(e) => {
                        debug!("Received invalid response: {e:#}");
                        return;
                    },
                };

                // Send the response to the requester (the user of [`RequestResponse::request`])
                let _ = outgoing_request.sender.try_broadcast(validation_result);

                // The request is satisfied; remove it so later responses for
                // the same hash are ignored
                outgoing_requests_clone
                    .write()
                    .await
                    .remove(&response.request_hash);

                // Drop the permit
                // NOTE(review): the `_ = permit;` assignment is a no-op;
                // `drop(permit)` alone releases the permit
                _ = permit;
                drop(permit);
            })
            .await
            .is_err()
            {
                warn!("Timed out while validating response");
            }
        });
    }
737}
738
/// An outgoing request. This is what we use to track a request and its corresponding response
/// in the protocol
#[derive(Clone, Deref)]
pub struct OutgoingRequest<R: Request>(Arc<OutgoingRequestInner<R>>);

/// The inner implementation of an outgoing request
pub struct OutgoingRequestInner<R: Request> {
    /// The broadcast sender used to fan the validated response out to every
    /// task waiting on this request
    sender: async_broadcast::Sender<ThreadSafeAny>,
    /// The receiver for the response. Presumably stored so the channel keeps
    /// at least one live receiver while waiters clone it on demand —
    /// TODO confirm against `async_broadcast` channel-close semantics
    receiver: async_broadcast::Receiver<ThreadSafeAny>,

    /// The request that we are waiting for a response to
    request: R,

    /// The function used to validate the response
    response_validation_fn: ResponseValidationFn<R>,
}
757
758#[cfg(test)]
759mod tests {
760 use std::{
761 collections::HashMap,
762 sync::{Mutex, atomic::AtomicBool},
763 };
764
765 use async_trait::async_trait;
766 use hotshot_types::signature_key::{BLSPrivKey, BLSPubKey};
767 use rand::Rng;
768 use tokio::{sync::mpsc, task::JoinSet};
769
770 use super::*;
771
    /// A test sender that has a list of all the participants in the network
    #[derive(Clone)]
    pub struct TestSender {
        /// Map from each participant's public key to the channel used to
        /// deliver raw message bytes to them
        network: Arc<HashMap<BLSPubKey, mpsc::Sender<Bytes>>>,
    }
777
778 /// An implementation of the [`Sender`] trait for the [`TestSender`] type
779 #[async_trait]
780 impl Sender<BLSPubKey> for TestSender {
781 async fn send_direct_message(&self, message: &Bytes, recipient: BLSPubKey) -> Result<()> {
782 self.network
783 .get(&recipient)
784 .ok_or(anyhow::anyhow!("recipient not found"))?
785 .send(Arc::clone(message))
786 .await
787 .map_err(|_| anyhow::anyhow!("failed to send message"))?;
788
789 Ok(())
790 }
791
792 async fn send_broadcast_message(&self, message: &Bytes) -> Result<()> {
793 for sender in self.network.values() {
794 sender
795 .send(Arc::clone(message))
796 .await
797 .map_err(|_| anyhow::anyhow!("failed to send message"))?;
798 }
799 Ok(())
800 }
801 }
802
    // Implement the [`RecipientSource`] trait for the [`TestSender`] type
    #[async_trait]
    impl RecipientSource<TestRequest, BLSPubKey> for TestSender {
        /// Every request is expected to be answerable by any participant, so
        /// return every public key in the network
        async fn get_expected_responders(&self, _request: &TestRequest) -> Result<Vec<BLSPubKey>> {
            // Get all the participants in the network
            Ok(self.network.keys().copied().collect())
        }
    }
811
    // Create a test request that is just some bytes
    #[derive(Clone, Debug)]
    struct TestRequest(Vec<u8>);
815
816 // Implement the [`Serializable`] trait for the [`TestRequest`] type
817 impl Serializable for TestRequest {
818 fn to_bytes(&self) -> Result<Vec<u8>> {
819 Ok(self.0.clone())
820 }
821
822 fn from_bytes(bytes: &[u8]) -> Result<Self> {
823 Ok(TestRequest(bytes.to_vec()))
824 }
825 }
826
    // Implement the [`Request`] trait for the [`TestRequest`] type
    #[async_trait]
    impl Request for TestRequest {
        /// Responses are raw byte vectors (the data source returns the Blake3
        /// hash of the request body)
        type Response = Vec<u8>;

        /// Every test request is considered valid
        async fn validate(&self) -> Result<()> {
            Ok(())
        }
    }
835
    // Create a test data source that pretends to have the data or not
    #[derive(Clone)]
    struct TestDataSource {
        /// Whether we have the data or not
        has_data: bool,
        /// The time at which the data will be available if we have it
        data_available_time: Instant,

        /// Whether or not the data will be taken once served (i.e. only the
        /// first request succeeds)
        take_data: bool,
        /// Whether or not the data has been taken
        taken: Arc<AtomicBool>,
    }
849
850 #[async_trait]
851 impl DataSource<TestRequest> for TestDataSource {
852 async fn derive_response_for(&self, request: &TestRequest) -> Result<Vec<u8>> {
853 // Return a response if we hit the hit rate
854 if self.has_data && Instant::now() >= self.data_available_time {
855 if self.take_data && !self.taken.swap(true, std::sync::atomic::Ordering::Relaxed) {
856 return Err(anyhow::anyhow!("data already taken"));
857 }
858 Ok(blake3::hash(&request.0).as_bytes().to_vec())
859 } else {
860 Err(anyhow::anyhow!("did not have the data"))
861 }
862 }
863 }
864
865 /// Create and return a default protocol configuration
866 fn default_protocol_config() -> RequestResponseConfig {
867 RequestResponseConfig {
868 incoming_request_ttl: Duration::from_secs(40),
869 incoming_request_timeout: Duration::from_secs(40),
870 request_batch_size: 10,
871 request_batch_interval: Duration::from_millis(100),
872 max_incoming_requests: 10,
873 max_incoming_requests_per_key: 1,
874 incoming_response_timeout: Duration::from_secs(1),
875 max_incoming_responses: 5,
876 }
877 }
878
879 /// Create fully connected test networks with `num_participants` participants
880 fn create_participants(
881 num: usize,
882 ) -> Vec<(TestSender, mpsc::Receiver<Bytes>, (BLSPubKey, BLSPrivKey))> {
883 // The entire network
884 let mut network = HashMap::new();
885
886 // All receivers in the network
887 let mut receivers = Vec::new();
888
889 // All keypairs in the network
890 let mut keypairs = Vec::new();
891
892 // For each participant,
893 for i in 0..num {
894 // Create a unique `BLSPubKey`
895 let (public_key, private_key) =
896 BLSPubKey::generated_from_seed_indexed([2; 32], i.try_into().unwrap());
897
898 // Add the keypair to the list
899 keypairs.push((public_key, private_key));
900
901 // Create a channel for sending and receiving messages
902 let (sender, receiver) = mpsc::channel::<Bytes>(100);
903
904 // Add the participant to the network
905 network.insert(public_key, sender);
906
907 // Add the receiver to the list of receivers
908 receivers.push(receiver);
909 }
910
911 // Create a test sender from the network
912 let sender = TestSender {
913 network: Arc::new(network),
914 };
915
916 // Return all senders and receivers
917 receivers
918 .into_iter()
919 .zip(keypairs)
920 .map(|(r, k)| (sender.clone(), r, k))
921 .collect()
922 }
923
    /// The configuration for an integration test
    #[derive(Clone)]
    struct IntegrationTestConfig {
        /// The request response protocol configuration
        request_response_config: RequestResponseConfig,
        /// The number of participants in the network
        num_participants: usize,
        /// The number of participants that have the data
        num_participants_with_data: usize,
        /// The timeout for the requests
        request_timeout: Duration,
        /// The delay before the nodes have the data available
        data_available_delay: Duration,
    }

    /// The result of an integration test
    struct IntegrationTestResult {
        /// The number of nodes that received a response
        num_succeeded: usize,
    }
944
945 /// Run an integration test with the given parameters
946 async fn run_integration_test(config: IntegrationTestConfig) -> IntegrationTestResult {
947 // Create a fully connected network with `num_participants` participants
948 let participants = create_participants(config.num_participants);
949
950 // Create a join set to wait for all the tasks to finish
951 let mut join_set = JoinSet::new();
952
953 // We need to keep these here so they don't get dropped
954 let handles = Arc::new(Mutex::new(Vec::new()));
955
956 // For each one, create a new [`RequestResponse`] protocol
957 for (i, (sender, receiver, (public_key, private_key))) in
958 participants.into_iter().enumerate()
959 {
960 let config_clone = config.request_response_config.clone();
961 let handles_clone = Arc::clone(&handles);
962 join_set.spawn(async move {
963 let protocol = RequestResponse::new(
964 config_clone,
965 sender.clone(),
966 receiver,
967 sender,
968 TestDataSource {
969 has_data: i < config.num_participants_with_data,
970 data_available_time: Instant::now() + config.data_available_delay,
971 take_data: false,
972 taken: Arc::new(AtomicBool::new(false)),
973 },
974 );
975
976 // Add the handle to the handles list so it doesn't get dropped and
977 // cancelled
978 #[allow(clippy::used_underscore_binding)]
979 handles_clone
980 .lock()
981 .unwrap()
982 .push(Arc::clone(&protocol._receiving_task_handle));
983
984 // Create a random request
985 let request = TestRequest(vec![rand::thread_rng().r#gen(); 100]);
986
987 // Get the hash of the request
988 let request_hash = blake3::hash(&request.0).as_bytes().to_vec();
989
990 // Create a new request message
991 let request = RequestMessage::new_signed(&public_key, &private_key, &request)
992 .expect("failed to create request message");
993
994 // Request the data from the protocol
995 let response = protocol
996 .request(
997 request,
998 RequestType::Batched,
999 config.request_timeout,
1000 |_request, response| async move { Ok(response) },
1001 )
1002 .await?;
1003
1004 // Make sure the response is the hash of the request
1005 assert_eq!(response, request_hash);
1006
1007 Ok::<(), anyhow::Error>(())
1008 });
1009 }
1010
1011 // Wait for all the tasks to finish
1012 let mut num_succeeded = config.num_participants;
1013 while let Some(result) = join_set.join_next().await {
1014 if result.is_err() || result.unwrap().is_err() {
1015 num_succeeded -= 1;
1016 }
1017 }
1018
1019 IntegrationTestResult { num_succeeded }
1020 }
1021
1022 /// Test the integration of the protocol with 50% of the participants having the data
1023 #[tokio::test(flavor = "multi_thread")]
1024 async fn test_integration_50_0s() {
1025 // Build a config
1026 let config = IntegrationTestConfig {
1027 request_response_config: default_protocol_config(),
1028 num_participants: 100,
1029 num_participants_with_data: 50,
1030 request_timeout: Duration::from_secs(40),
1031 data_available_delay: Duration::from_secs(0),
1032 };
1033
1034 // Run the test, making sure all the requests succeed
1035 let result = run_integration_test(config).await;
1036 assert_eq!(result.num_succeeded, 100);
1037 }
1038
1039 /// Test the integration of the protocol when nobody has the data. Make sure we don't
1040 /// get any responses
1041 #[tokio::test(flavor = "multi_thread")]
1042 async fn test_integration_0() {
1043 // Build a config
1044 let config = IntegrationTestConfig {
1045 request_response_config: default_protocol_config(),
1046 num_participants: 100,
1047 num_participants_with_data: 0,
1048 request_timeout: Duration::from_secs(40),
1049 data_available_delay: Duration::from_secs(0),
1050 };
1051
1052 // Run the test
1053 let result = run_integration_test(config).await;
1054
1055 // Make sure all the requests succeeded
1056 assert_eq!(result.num_succeeded, 0);
1057 }
1058
1059 /// Test the integration of the protocol when one node has the data after
1060 /// a delay of 1s
1061 #[tokio::test(flavor = "multi_thread")]
1062 async fn test_integration_1_1s() {
1063 // Build a config
1064 let config = IntegrationTestConfig {
1065 request_response_config: default_protocol_config(),
1066 num_participants: 100,
1067 num_participants_with_data: 1,
1068 request_timeout: Duration::from_secs(40),
1069 data_available_delay: Duration::from_secs(2),
1070 };
1071
1072 // Run the test
1073 let result = run_integration_test(config).await;
1074
1075 // Make sure all the requests succeeded
1076 assert_eq!(result.num_succeeded, 100);
1077 }
1078
1079 /// Test that we can join an existing request for the same data and get the same (single) response
1080 #[tokio::test(flavor = "multi_thread")]
1081 async fn test_join_existing_request() {
1082 // Build a config
1083 let config = default_protocol_config();
1084
1085 // Create two participants
1086 let mut participants = Vec::new();
1087
1088 for (sender, receiver, (public_key, private_key)) in create_participants(2) {
1089 // For each, create a new [`RequestResponse`] protocol
1090 let protocol = RequestResponse::new(
1091 config.clone(),
1092 sender.clone(),
1093 receiver,
1094 sender,
1095 TestDataSource {
1096 take_data: true,
1097 has_data: true,
1098 data_available_time: Instant::now() + Duration::from_secs(2),
1099 taken: Arc::new(AtomicBool::new(false)),
1100 },
1101 );
1102
1103 // Add the participants to the list
1104 participants.push((protocol, public_key, private_key));
1105 }
1106
1107 // Take the first participant
1108 let one = Arc::new(participants.remove(0));
1109
1110 // Create the request that they should all be able to join on
1111 let request = TestRequest(vec![rand::thread_rng().r#gen(); 100]);
1112
1113 // Create a join set to wait for all the tasks to finish
1114 let mut join_set = JoinSet::new();
1115
1116 // Make 10 requests with the same hash
1117 for _ in 0..10 {
1118 // Clone the first participant
1119 let one_clone = Arc::clone(&one);
1120
1121 // Clone the request
1122 let request_clone = request.clone();
1123
1124 // Spawn a task to request the data
1125 join_set.spawn(async move {
1126 // Create a new, signed request message
1127 let request_message =
1128 RequestMessage::new_signed(&one_clone.1, &one_clone.2, &request_clone)?;
1129
1130 // Start requesting it
1131 one_clone
1132 .0
1133 .request(
1134 request_message,
1135 RequestType::Batched,
1136 Duration::from_secs(20),
1137 |_request, response| async move { Ok(response) },
1138 )
1139 .await?;
1140
1141 Ok::<(), anyhow::Error>(())
1142 });
1143 }
1144
1145 // Wait for all the tasks to finish, making sure they all succeed
1146 while let Some(result) = join_set.join_next().await {
1147 result
1148 .expect("failed to join task")
1149 .expect("failed to request data");
1150 }
1151 }
1152}