Skip to main content

hotshot/traits/networking/
memory_network.rs

1// Copyright (c) 2021-2024 Espresso Systems (espressosys.com)
2// This file is part of the HotShot repository.
3
4// You should have received a copy of the MIT License
5// along with the HotShot repository. If not, see <https://mit-license.org/>.
6
7//! In memory network simulator
8//!
9//! This module provides an in-memory only simulation of an actual network, useful for unit and
10//! integration tests.
11
12use core::time::Duration;
13use std::{
14    collections::HashMap,
15    fmt::Debug,
16    sync::{
17        Arc,
18        atomic::{AtomicUsize, Ordering},
19    },
20};
21
22use async_lock::{Mutex, RwLock};
23use async_trait::async_trait;
24use dashmap::DashMap;
25use hotshot_types::{
26    BoxSyncFuture, PeerConnectInfo, boxed_sync,
27    data::ViewNumber,
28    traits::{
29        network::{
30            AsyncGenerator, BroadcastDelay, ConnectedNetwork, TestableNetworkingImplementation,
31            Topic,
32        },
33        node_implementation::NodeType,
34        signature_key::SignatureKey,
35    },
36};
37use tokio::{
38    spawn,
39    sync::mpsc::{Receiver, Sender, channel, error::SendError},
40};
41use tracing::{Instrument, debug, error, info, info_span, instrument, trace, warn};
42
43use super::{NetworkError, NetworkReliability};
44
/// Shared state for in-memory mock networking.
///
/// This type is responsible for keeping track of the channels to each [`MemoryNetwork`], and is
/// used to group the [`MemoryNetwork`] instances. Every [`MemoryNetwork`] created against the
/// same `MasterMap` registers itself here in [`MemoryNetwork::new`], which is how instances
/// discover each other when sending.
#[derive(derive_more::Debug)]
pub struct MasterMap<K: SignatureKey> {
    /// Every registered `MemoryNetwork`, keyed by the public key it was created with.
    /// Used to look up the recipient of a direct message. Skipped in `Debug` output.
    #[debug(skip)]
    map: DashMap<K, MemoryNetwork<K>>,

    /// The registered `MemoryNetwork`s aggregated by the topic they subscribed to.
    /// Each entry pairs the subscriber's public key with its network handle; used to
    /// resolve the recipients of a broadcast.
    subscribed_map: DashMap<Topic, Vec<(K, MemoryNetwork<K>)>>,
}
58
59impl<K: SignatureKey> MasterMap<K> {
60    /// Create a new, empty, `MasterMap`
61    #[must_use]
62    pub fn new() -> Arc<MasterMap<K>> {
63        Arc::new(MasterMap {
64            map: DashMap::new(),
65            subscribed_map: DashMap::new(),
66        })
67    }
68}
69
/// Internal state for a `MemoryNetwork` instance
#[derive(Debug)]
struct MemoryNetworkInner<K: SignatureKey> {
    /// Sending half of the inbound channel: peers push raw message bytes here.
    /// Set to `None` when the network is shut down, after which sends fail.
    input: RwLock<Option<Sender<Vec<u8>>>>,
    /// Receiving half of the outbound queue: messages forwarded by the background
    /// task are read from here by `recv_message`.
    output: Mutex<Receiver<Vec<u8>>>,
    /// The master map shared by every network in the same group
    master_map: Arc<MasterMap<K>>,

    /// Count of messages that are in-flight (sent but not processed yet).
    /// Incremented when a message is pushed to `input`, decremented when a message
    /// is returned from `recv_message`.
    in_flight_message_count: AtomicUsize,

    /// Optional config to introduce unreliability to the network. When present,
    /// outgoing messages are routed through its `chaos_send_msg` instead of being
    /// delivered directly.
    reliability_config: Option<Box<dyn NetworkReliability>>,
}
86
/// In memory only network simulator.
///
/// This provides an in memory simulation of a networking implementation, allowing nodes running on
/// the same machine to mock networking while testing other functionality.
///
/// Under the hood, each instance owns a pair of bounded `tokio` mpsc channels (an inbound one
/// that peers write to and an outbound queue that `recv_message` reads from) and is registered
/// in the group's shared [`MasterMap`], through which the other `MemoryNetwork` instances of the
/// same group reach it.
///
/// Cloning is cheap: clones share the same inner state through an [`Arc`].
#[derive(Clone)]
pub struct MemoryNetwork<K: SignatureKey> {
    /// The actual internal state
    inner: Arc<MemoryNetworkInner<K>>,
}
99
100impl<K: SignatureKey> Debug for MemoryNetwork<K> {
101    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
102        f.debug_struct("MemoryNetwork")
103            .field("inner", &"inner")
104            .finish()
105    }
106}
107
108impl<K: SignatureKey> MemoryNetwork<K> {
109    /// Creates a new `MemoryNetwork` and hooks it up to the group through the provided `MasterMap`
110    pub fn new(
111        pub_key: &K,
112        master_map: &Arc<MasterMap<K>>,
113        subscribed_topics: &[Topic],
114        reliability_config: Option<Box<dyn NetworkReliability>>,
115    ) -> MemoryNetwork<K> {
116        info!("Attaching new MemoryNetwork");
117        let (input, mut task_recv) = channel(128);
118        let (task_send, output) = channel(128);
119        let in_flight_message_count = AtomicUsize::new(0);
120        trace!("Channels open, spawning background task");
121
122        spawn(
123            async move {
124                debug!("Starting background task");
125                trace!("Entering processing loop");
126                while let Some(vec) = task_recv.recv().await {
127                    trace!(?vec, "Incoming message");
128                    // Attempt to decode message
129                    let ts = task_send.clone();
130                    let res = ts.send(vec).await;
131                    if res.is_ok() {
132                        trace!("Passed message to output queue");
133                    } else {
134                        error!("Output queue receivers are shutdown");
135                    }
136                }
137            }
138            .instrument(info_span!("MemoryNetwork Background task", map = ?master_map)),
139        );
140        trace!("Notifying other networks of the new connected peer");
141        trace!("Task spawned, creating MemoryNetwork");
142        let mn = MemoryNetwork {
143            inner: Arc::new(MemoryNetworkInner {
144                input: RwLock::new(Some(input)),
145                output: Mutex::new(output),
146                master_map: Arc::clone(master_map),
147                in_flight_message_count,
148                reliability_config,
149            }),
150        };
151        // Insert our public key into the master map
152        master_map.map.insert(pub_key.clone(), mn.clone());
153        // Insert our subscribed topics into the master map
154        for topic in subscribed_topics {
155            master_map
156                .subscribed_map
157                .entry(*topic)
158                .or_default()
159                .push((pub_key.clone(), mn.clone()));
160        }
161
162        mn
163    }
164
165    /// Send a [`Vec<u8>`] message to the inner `input`
166    async fn input(&self, message: Vec<u8>) -> Result<(), SendError<Vec<u8>>> {
167        self.inner
168            .in_flight_message_count
169            .fetch_add(1, Ordering::Relaxed);
170        let input = self.inner.input.read().await;
171        if let Some(input) = &*input {
172            input.send(message).await
173        } else {
174            Err(SendError(message))
175        }
176    }
177}
178
179impl<TYPES: NodeType> TestableNetworkingImplementation<TYPES>
180    for MemoryNetwork<TYPES::SignatureKey>
181{
182    fn generator(
183        _expected_node_count: usize,
184        _num_bootstrap: usize,
185        _network_id: usize,
186        da_committee_size: usize,
187        reliability_config: Option<Box<dyn NetworkReliability>>,
188        _secondary_network_delay: Duration,
189        _connect_infos: &mut HashMap<TYPES::SignatureKey, PeerConnectInfo>,
190    ) -> AsyncGenerator<Arc<Self>> {
191        let master: Arc<_> = MasterMap::new();
192        // We assign known_nodes' public key and stake value rather than read from config file since it's a test
193        Box::pin(move |node_id| {
194            let privkey = TYPES::SignatureKey::generated_from_seed_indexed([0u8; 32], node_id).1;
195            let pubkey = TYPES::SignatureKey::from_private(&privkey);
196
197            // Subscribe to topics based on our index
198            let subscribed_topics = if node_id < da_committee_size as u64 {
199                // DA node
200                vec![Topic::Da, Topic::Global]
201            } else {
202                // Non-DA node
203                vec![Topic::Global]
204            };
205
206            let net = MemoryNetwork::new(
207                &pubkey,
208                &master,
209                &subscribed_topics,
210                reliability_config.clone(),
211            );
212            Box::pin(async move { net.into() })
213        })
214    }
215
216    fn in_flight_message_count(&self) -> Option<usize> {
217        Some(self.inner.in_flight_message_count.load(Ordering::Relaxed))
218    }
219}
220
221// TODO instrument these functions
222#[async_trait]
223impl<K: SignatureKey + 'static> ConnectedNetwork<K> for MemoryNetwork<K> {
224    #[instrument(name = "MemoryNetwork::ready_blocking")]
225    async fn wait_for_ready(&self) {}
226
227    fn pause(&self) {
228        unimplemented!("Pausing not implemented for the Memory network");
229    }
230
231    fn resume(&self) {
232        unimplemented!("Resuming not implemented for the Memory network");
233    }
234
235    #[instrument(name = "MemoryNetwork::shut_down")]
236    fn shut_down<'a, 'b>(&'a self) -> BoxSyncFuture<'b, ()>
237    where
238        'a: 'b,
239        Self: 'b,
240    {
241        let closure = async move {
242            *self.inner.input.write().await = None;
243        };
244        boxed_sync(closure)
245    }
246
247    #[instrument(name = "MemoryNetwork::broadcast_message")]
248    async fn broadcast_message(
249        &self,
250        _: ViewNumber,
251        message: Vec<u8>,
252        topic: Topic,
253        _broadcast_delay: BroadcastDelay,
254    ) -> Result<(), NetworkError> {
255        trace!(?message, "Broadcasting message");
256        // Snapshot the recipient list and release the DashMap shard write
257        // lock before awaiting. Holding the lock across `.await` serializes
258        // every broadcast on this topic and, under a small tokio runtime
259        // (e.g. CI with a 4-core CPU quota), can deadlock all worker threads
260        // while waiting on a full recipient channel.
261        let nodes: Vec<(K, MemoryNetwork<K>)> = self
262            .inner
263            .master_map
264            .subscribed_map
265            .get(&topic)
266            .map(|entry| entry.value().clone())
267            .unwrap_or_default();
268        // Spawn per-recipient sends so a slow/full recipient channel does
269        // not backpressure the rest of the broadcast. When a node restarts
270        // its new coordinator starts draining its input channel from
271        // scratch; an awaited sequential send could block the sender on a
272        // full 128-slot channel, stalling consensus for other peers.
273        for (key, node) in nodes {
274            trace!(?key, "Sending message to node");
275            if let Some(config) = &self.inner.reliability_config {
276                let node2 = node.clone();
277                let fut = config.chaos_send_msg(
278                    message.clone(),
279                    Arc::new(move |msg: Vec<u8>| {
280                        let node3 = (node2).clone();
281                        boxed_sync(async move {
282                            let _res = node3.input(msg).await;
283                            // NOTE we're dropping metrics here but this is only for testing
284                            // purposes. I think that should be okay
285                        })
286                    }),
287                );
288                spawn(fut);
289            } else {
290                let msg = message.clone();
291                spawn(async move {
292                    if let Err(e) = node.input(msg).await {
293                        warn!(?e, ?key, "Error sending broadcast message to node");
294                    }
295                });
296            }
297        }
298        Ok(())
299    }
300
301    #[instrument(name = "MemoryNetwork::da_broadcast_message")]
302    async fn da_broadcast_message(
303        &self,
304        _: ViewNumber,
305        message: Vec<u8>,
306        recipients: Vec<K>,
307        _broadcast_delay: BroadcastDelay,
308    ) -> Result<(), NetworkError> {
309        trace!(?message, "Broadcasting message to DA");
310        // See `broadcast_message`: clone the recipient list out of DashMap
311        // so the shard lock is not held across the awaits below.
312        let nodes: Vec<(K, MemoryNetwork<K>)> = self
313            .inner
314            .master_map
315            .subscribed_map
316            .get(&Topic::Da)
317            .map(|entry| entry.value().clone())
318            .unwrap_or_default();
319        // Spawn per-recipient sends so a slow/full recipient channel does
320        // not backpressure the rest of the broadcast (see `broadcast_message`).
321        for (key, node) in nodes {
322            if !recipients.contains(&key) {
323                tracing::trace!("Skipping node because not in recipient list: {:?}", key);
324                continue;
325            }
326            trace!(?key, "Sending message to node");
327            if let Some(config) = &self.inner.reliability_config {
328                let node2 = node.clone();
329                let fut = config.chaos_send_msg(
330                    message.clone(),
331                    Arc::new(move |msg: Vec<u8>| {
332                        let node3 = (node2).clone();
333                        boxed_sync(async move {
334                            let _res = node3.input(msg).await;
335                            // NOTE we're dropping metrics here but this is only for testing
336                            // purposes. I think that should be okay
337                        })
338                    }),
339                );
340                spawn(fut);
341            } else {
342                let msg = message.clone();
343                spawn(async move {
344                    if let Err(e) = node.input(msg).await {
345                        warn!(?e, ?key, "Error sending broadcast message to node");
346                    }
347                });
348            }
349        }
350        Ok(())
351    }
352
353    #[instrument(name = "MemoryNetwork::direct_message")]
354    async fn direct_message(
355        &self,
356        _: ViewNumber,
357        message: Vec<u8>,
358        recipient: K,
359    ) -> Result<(), NetworkError> {
360        // debug!(?message, ?recipient, "Sending direct message");
361        // Bincode the message
362        trace!("Message bincoded, finding recipient");
363        // Clone the target network and drop the DashMap read guard before
364        // awaiting, matching the rationale in `broadcast_message`.
365        let node = self
366            .inner
367            .master_map
368            .map
369            .get(&recipient)
370            .map(|entry| entry.value().clone());
371        if let Some(node) = node {
372            if let Some(config) = &self.inner.reliability_config {
373                {
374                    let fut = config.chaos_send_msg(
375                        message.clone(),
376                        Arc::new(move |msg: Vec<u8>| {
377                            let node2 = node.clone();
378                            boxed_sync(async move {
379                                let _res = node2.input(msg).await;
380                                // NOTE we're dropping metrics here but this is only for testing
381                                // purposes. I think that should be okay
382                            })
383                        }),
384                    );
385                    spawn(fut);
386                }
387                Ok(())
388            } else {
389                let res = node.input(message).await;
390                match res {
391                    Ok(()) => {
392                        trace!(?recipient, "Delivered message to remote");
393                        Ok(())
394                    },
395                    Err(e) => Err(NetworkError::MessageSendError(format!(
396                        "error sending direct message to node: {e}",
397                    ))),
398                }
399            }
400        } else {
401            Err(NetworkError::MessageSendError(
402                "node does not exist".to_string(),
403            ))
404        }
405    }
406
407    /// Receive one or many messages from the underlying network.
408    ///
409    /// # Errors
410    /// If the other side of the channel is closed
411    #[instrument(name = "MemoryNetwork::recv_messages", skip_all)]
412    async fn recv_message(&self) -> Result<Vec<u8>, NetworkError> {
413        let ret = self
414            .inner
415            .output
416            .lock()
417            .await
418            .recv()
419            .await
420            .ok_or(NetworkError::ShutDown)?;
421        self.inner
422            .in_flight_message_count
423            .fetch_sub(1, Ordering::Relaxed);
424        Ok(ret)
425    }
426}