Skip to main content

hotshot_libp2p_networking/network/node/
handle.rs

1// Copyright (c) 2021-2024 Espresso Systems (espressosys.com)
2// This file is part of the HotShot repository.
3
4// You should have received a copy of the MIT License
5// along with the HotShot repository. If not, see <https://mit-license.org/>.
6
7use std::{collections::HashSet, fmt::Debug, sync::Arc, time::Duration};
8
9use bimap::BiMap;
10use hotshot_types::traits::{
11    network::NetworkError, node_implementation::NodeType, signature_key::SignatureKey,
12};
13use libp2p::{Multiaddr, request_response::ResponseChannel};
14use libp2p_identity::PeerId;
15use parking_lot::Mutex;
16use tokio::{
17    sync::mpsc::{Receiver, UnboundedReceiver, UnboundedSender},
18    time::{sleep, timeout},
19};
20use tracing::{debug, info, instrument};
21
22use crate::network::{
23    ClientRequest, NetworkEvent, NetworkNode, NetworkNodeConfig,
24    behaviours::dht::{
25        record::{Namespace, RecordKey, RecordValue},
26        store::persistent::DhtPersistentStorage,
27    },
28    gen_multiaddr, log_summary,
29};
30
31/// A handle containing:
32/// - A reference to the state
33/// - Controls for the swarm
34#[derive(Debug, Clone)]
35pub struct NetworkNodeHandle<T: NodeType> {
36    /// network configuration
37    network_config: NetworkNodeConfig,
38
39    /// send an action to the networkbehaviour
40    send_network: UnboundedSender<ClientRequest>,
41
42    /// The map from consensus keys to peer IDs
43    consensus_key_to_pid_map: Arc<Mutex<BiMap<T::SignatureKey, PeerId>>>,
44
45    /// the local address we're listening on
46    listen_addr: Multiaddr,
47
48    /// the peer id of the networkbehaviour
49    peer_id: PeerId,
50
51    /// human readable id
52    id: usize,
53}
54
55/// internal network node receiver
56#[derive(Debug)]
57pub struct NetworkNodeReceiver {
58    /// the receiver
59    receiver: UnboundedReceiver<NetworkEvent>,
60
61    ///kill switch
62    recv_kill: Option<Receiver<()>>,
63}
64
65impl NetworkNodeReceiver {
66    /// recv a network event
67    /// # Errors
68    /// Errors if the receiver channel is closed
69    pub async fn recv(&mut self) -> Result<NetworkEvent, NetworkError> {
70        self.receiver
71            .recv()
72            .await
73            .ok_or(NetworkError::ChannelReceiveError(
74                "Receiver channel closed".to_string(),
75            ))
76    }
77    /// Add a kill switch to the receiver
78    pub fn set_kill_switch(&mut self, kill_switch: Receiver<()>) {
79        self.recv_kill = Some(kill_switch);
80    }
81
82    /// Take the kill switch to allow killing the receiver task
83    pub fn take_kill_switch(&mut self) -> Option<Receiver<()>> {
84        self.recv_kill.take()
85    }
86}
87
88/// Spawn a network node task task and return the handle and the receiver for it
89/// # Errors
90/// Errors if spawning the task fails
91pub async fn spawn_network_node<T: NodeType, D: DhtPersistentStorage>(
92    config: NetworkNodeConfig,
93    dht_persistent_storage: D,
94    consensus_key_to_pid_map: Arc<Mutex<BiMap<T::SignatureKey, PeerId>>>,
95    id: usize,
96) -> Result<(NetworkNodeReceiver, NetworkNodeHandle<T>), NetworkError> {
97    let mut network: NetworkNode<T, _> = NetworkNode::new(
98        config.clone(),
99        dht_persistent_storage,
100        Arc::clone(&consensus_key_to_pid_map),
101    )
102    .await
103    .map_err(|e| NetworkError::ConfigError(format!("failed to create network node: {e}")))?;
104    // randomly assigned port
105    let listen_addr = config
106        .bind_address
107        .clone()
108        .unwrap_or_else(|| gen_multiaddr(0));
109    let peer_id = network.peer_id();
110    let listen_addr = network.start_listen(listen_addr).await.map_err(|e| {
111        NetworkError::ListenError(format!("failed to start listening on Libp2p: {e}"))
112    })?;
113    // pin here to force the future onto the heap since it can be large
114    // in the case of flume
115    let (send_chan, recv_chan) = network.spawn_listeners().map_err(|err| {
116        NetworkError::ListenError(format!("failed to spawn listeners for Libp2p: {err}"))
117    })?;
118    log_summary::spawn_summary_task();
119    let receiver = NetworkNodeReceiver {
120        receiver: recv_chan,
121        recv_kill: None,
122    };
123
124    let handle = NetworkNodeHandle::<T> {
125        network_config: config,
126        send_network: send_chan,
127        consensus_key_to_pid_map,
128        listen_addr,
129        peer_id,
130        id,
131    };
132    Ok((receiver, handle))
133}
134
135impl<T: NodeType> NetworkNodeHandle<T> {
136    /// Cleanly shuts down a swarm node
137    /// This is done by sending a message to
138    /// the swarm itself to spin down
139    #[instrument]
140    pub async fn shutdown(&self) -> Result<(), NetworkError> {
141        self.send_request(ClientRequest::Shutdown)?;
142        Ok(())
143    }
144    /// Notify the network to begin the bootstrap process
145    /// # Errors
146    /// If unable to send via `send_network`. This should only happen
147    /// if the network is shut down.
148    pub fn begin_bootstrap(&self) -> Result<(), NetworkError> {
149        let req = ClientRequest::BeginBootstrap;
150        self.send_request(req)
151    }
152
153    /// Get a reference to the network node handle's listen addr.
154    #[must_use]
155    pub fn listen_addr(&self) -> Multiaddr {
156        self.listen_addr.clone()
157    }
158
159    /// Print out the routing table used by kademlia
160    /// NOTE: only for debugging purposes currently
161    /// # Errors
162    /// if the client has stopped listening for a response
163    pub async fn print_routing_table(&self) -> Result<(), NetworkError> {
164        let (s, r) = futures::channel::oneshot::channel();
165        let req = ClientRequest::GetRoutingTable(s);
166        self.send_request(req)?;
167        r.await
168            .map_err(|e| NetworkError::ChannelReceiveError(e.to_string()))
169    }
170    /// Wait until at least `num_peers` have connected
171    ///
172    /// # Errors
173    /// If the channel closes before the result can be sent back
174    pub async fn wait_to_connect(
175        &self,
176        num_required_peers: usize,
177        node_id: usize,
178    ) -> Result<(), NetworkError> {
179        // Wait for the required number of peers to connect
180        loop {
181            // Get the number of currently connected peers
182            let num_connected = self.num_connected().await?;
183            if num_connected >= num_required_peers {
184                break;
185            }
186
187            // Log the number of connected peers
188            info!(
189                "Node {} connected to {}/{} peers",
190                node_id, num_connected, num_required_peers
191            );
192
193            // Sleep for a second before checking again
194            sleep(Duration::from_secs(1)).await;
195        }
196
197        Ok(())
198    }
199
200    /// Look up a peer's addresses in kademlia
201    /// NOTE: this should always be called before any `request_response` is initiated
202    /// # Errors
203    /// if the client has stopped listening for a response
204    pub async fn lookup_pid(&self, peer_id: PeerId) -> Result<(), NetworkError> {
205        let (s, r) = futures::channel::oneshot::channel();
206        let req = ClientRequest::LookupPeer(peer_id, s);
207        self.send_request(req)?;
208        r.await
209            .map_err(|err| NetworkError::ChannelReceiveError(err.to_string()))
210    }
211
212    /// Looks up a node's `PeerId` by its consensus key.
213    ///
214    /// # Errors
215    /// If the DHT lookup fails
216    pub async fn lookup_node(
217        &self,
218        consensus_key: &T::SignatureKey,
219        dht_timeout: Duration,
220    ) -> Result<PeerId, NetworkError> {
221        // First check if we already have an open connection to the peer
222        if let Some(pid) = self
223            .consensus_key_to_pid_map
224            .lock()
225            .get_by_left(consensus_key)
226        {
227            return Ok(*pid);
228        }
229
230        // Create the record key
231        let key = RecordKey::new(Namespace::Lookup, consensus_key.to_bytes());
232
233        // Get the record from the DHT
234        let pid = self.get_record_timeout(key, dht_timeout).await?;
235
236        PeerId::from_bytes(&pid).map_err(|err| NetworkError::FailedToDeserialize(err.to_string()))
237    }
238
239    /// Insert a record into the kademlia DHT
240    /// # Errors
241    /// - Will return [`NetworkError::FailedToSerialize`] when unable to serialize the key or value
242    pub async fn put_record(
243        &self,
244        key: RecordKey,
245        value: RecordValue<T::SignatureKey>,
246    ) -> Result<(), NetworkError> {
247        // Serialize the key
248        let key = key.to_bytes();
249
250        // Serialize the record
251        let value = bincode::serialize(&value)
252            .map_err(|e| NetworkError::FailedToSerialize(e.to_string()))?;
253
254        let (s, r) = futures::channel::oneshot::channel();
255        let req = ClientRequest::PutDHT {
256            key: key.clone(),
257            value,
258            notify: s,
259        };
260
261        self.send_request(req)?;
262
263        r.await.map_err(|_| NetworkError::RequestCancelled)
264    }
265
266    /// Receive a record from the kademlia DHT if it exists.
267    /// Must be replicated on at least 2 nodes
268    /// # Errors
269    /// - Will return [`NetworkError::FailedToSerialize`] when unable to serialize the key
270    /// - Will return [`NetworkError::FailedToDeserialize`] when unable to deserialize the returned value
271    pub async fn get_record(
272        &self,
273        key: RecordKey,
274        retry_count: u8,
275    ) -> Result<Vec<u8>, NetworkError> {
276        // Serialize the key
277        let serialized_key = key.to_bytes();
278
279        let (s, r) = futures::channel::oneshot::channel();
280        let req = ClientRequest::GetDHT {
281            key: serialized_key.clone(),
282            notify: vec![s],
283            retry_count,
284        };
285        self.send_request(req)?;
286
287        // Map the error
288        let result = r.await.map_err(|_| NetworkError::RequestCancelled)?;
289
290        // Deserialize the record's value
291        let record: RecordValue<T::SignatureKey> = bincode::deserialize(&result)
292            .map_err(|e| NetworkError::FailedToDeserialize(e.to_string()))?;
293
294        Ok(record.value().to_vec())
295    }
296
297    /// Get a record from the kademlia DHT with a timeout
298    /// # Errors
299    /// - Will return [`NetworkError::Timeout`] when times out
300    /// - Will return [`NetworkError::FailedToSerialize`] when unable to serialize the key or value
301    /// - Will return [`NetworkError::ChannelSendError`] when underlying `NetworkNode` has been killed
302    pub async fn get_record_timeout(
303        &self,
304        key: RecordKey,
305        timeout_duration: Duration,
306    ) -> Result<Vec<u8>, NetworkError> {
307        timeout(timeout_duration, self.get_record(key, 3))
308            .await
309            .map_err(|err| NetworkError::Timeout(err.to_string()))?
310    }
311
312    /// Insert a record into the kademlia DHT with a timeout
313    /// # Errors
314    /// - Will return [`NetworkError::Timeout`] when times out
315    /// - Will return [`NetworkError::FailedToSerialize`] when unable to serialize the key or value
316    /// - Will return [`NetworkError::ChannelSendError`] when underlying `NetworkNode` has been killed
317    pub async fn put_record_timeout(
318        &self,
319        key: RecordKey,
320        value: RecordValue<T::SignatureKey>,
321        timeout_duration: Duration,
322    ) -> Result<(), NetworkError> {
323        timeout(timeout_duration, self.put_record(key, value))
324            .await
325            .map_err(|err| NetworkError::Timeout(err.to_string()))?
326    }
327
328    /// Subscribe to a topic
329    /// # Errors
330    /// - Will return [`NetworkError::ChannelSendError`] when underlying `NetworkNode` has been killed
331    pub async fn subscribe(&self, topic: String) -> Result<(), NetworkError> {
332        let (s, r) = futures::channel::oneshot::channel();
333        let req = ClientRequest::Subscribe(topic, Some(s));
334        self.send_request(req)?;
335        r.await
336            .map_err(|err| NetworkError::ChannelReceiveError(err.to_string()))
337    }
338
339    /// Unsubscribe from a topic
340    /// # Errors
341    /// - Will return [`NetworkError::ChannelSendError`] when underlying `NetworkNode` has been killed
342    pub async fn unsubscribe(&self, topic: String) -> Result<(), NetworkError> {
343        let (s, r) = futures::channel::oneshot::channel();
344        let req = ClientRequest::Unsubscribe(topic, Some(s));
345        self.send_request(req)?;
346        r.await
347            .map_err(|err| NetworkError::ChannelReceiveError(err.to_string()))
348    }
349
350    /// Ignore `peers` when pruning
351    /// e.g. maintain their connection
352    /// # Errors
353    /// - Will return [`NetworkError::ChannelSendError`] when underlying `NetworkNode` has been killed
354    pub fn ignore_peers(&self, peers: Vec<PeerId>) -> Result<(), NetworkError> {
355        let req = ClientRequest::IgnorePeers(peers);
356        self.send_request(req)
357    }
358
359    /// Make a direct request to `peer_id` containing `msg`
360    /// # Errors
361    /// - Will return [`NetworkError::ChannelSendError`] when underlying `NetworkNode` has been killed
362    /// - Will return [`NetworkError::FailedToSerialize`] when unable to serialize `msg`
363    pub fn direct_request(&self, pid: PeerId, msg: &[u8]) -> Result<(), NetworkError> {
364        self.direct_request_no_serialize(pid, msg.to_vec())
365    }
366
367    /// Make a direct request to `peer_id` containing `msg` without serializing
368    /// # Errors
369    /// - Will return [`NetworkError::ChannelSendError`] when underlying `NetworkNode` has been killed
370    /// - Will return [`NetworkError::FailedToSerialize`] when unable to serialize `msg`
371    pub fn direct_request_no_serialize(
372        &self,
373        pid: PeerId,
374        contents: Vec<u8>,
375    ) -> Result<(), NetworkError> {
376        let req = ClientRequest::DirectRequest {
377            pid,
378            contents,
379            retry_count: 1,
380        };
381        self.send_request(req)
382    }
383
384    /// Reply with `msg` to a request over `chan`
385    /// # Errors
386    /// - Will return [`NetworkError::ChannelSendError`] when underlying `NetworkNode` has been killed
387    /// - Will return [`NetworkError::FailedToSerialize`] when unable to serialize `msg`
388    pub fn direct_response(
389        &self,
390        chan: ResponseChannel<Vec<u8>>,
391        msg: &[u8],
392    ) -> Result<(), NetworkError> {
393        let req = ClientRequest::DirectResponse(chan, msg.to_vec());
394        self.send_request(req)
395    }
396
397    /// Forcefully disconnect from a peer
398    /// # Errors
399    /// If the channel is closed somehow
400    /// Shouldnt' happen.
401    /// # Panics
402    /// If channel errors out
403    /// shouldn't happen.
404    pub fn prune_peer(&self, pid: PeerId) -> Result<(), NetworkError> {
405        let req = ClientRequest::Prune(pid);
406        self.send_request(req)
407    }
408
409    /// Gossip a message to peers
410    /// # Errors
411    /// - Will return [`NetworkError::ChannelSendError`] when underlying `NetworkNode` has been killed
412    /// - Will return [`NetworkError::FailedToSerialize`] when unable to serialize `msg`
413    pub fn gossip(&self, topic: String, msg: &[u8]) -> Result<(), NetworkError> {
414        self.gossip_no_serialize(topic, msg.to_vec())
415    }
416
417    /// Gossip a message to peers without serializing
418    /// # Errors
419    /// - Will return [`NetworkError::ChannelSendError`] when underlying `NetworkNode` has been killed
420    /// - Will return [`NetworkError::FailedToSerialize`] when unable to serialize `msg`
421    pub fn gossip_no_serialize(&self, topic: String, msg: Vec<u8>) -> Result<(), NetworkError> {
422        let req = ClientRequest::GossipMsg(topic, msg);
423        self.send_request(req)
424    }
425
426    /// Tell libp2p about known network nodes
427    /// # Errors
428    /// - Will return [`NetworkError::ChannelSendError`] when underlying `NetworkNode` has been killed
429    pub fn add_known_peers(
430        &self,
431        known_peers: Vec<(PeerId, Multiaddr)>,
432    ) -> Result<(), NetworkError> {
433        debug!("Adding {} known peers", known_peers.len());
434        let req = ClientRequest::AddKnownPeers(known_peers);
435        self.send_request(req)
436    }
437
438    /// Send a client request to the network
439    ///
440    /// # Errors
441    /// - Will return [`NetworkError::ChannelSendError`] when underlying `NetworkNode` has been killed
442    fn send_request(&self, req: ClientRequest) -> Result<(), NetworkError> {
443        self.send_network
444            .send(req)
445            .map_err(|err| NetworkError::ChannelSendError(err.to_string()))
446    }
447
448    /// Returns number of peers this node is connected to
449    /// # Errors
450    /// If the channel is closed somehow
451    /// Shouldnt' happen.
452    /// # Panics
453    /// If channel errors out
454    /// shouldn't happen.
455    pub async fn num_connected(&self) -> Result<usize, NetworkError> {
456        let (s, r) = futures::channel::oneshot::channel();
457        let req = ClientRequest::GetConnectedPeerNum(s);
458        self.send_request(req)?;
459        Ok(r.await.unwrap())
460    }
461
462    /// return hashset of PIDs this node is connected to
463    /// # Errors
464    /// If the channel is closed somehow
465    /// Shouldnt' happen.
466    /// # Panics
467    /// If channel errors out
468    /// shouldn't happen.
469    pub async fn connected_pids(&self) -> Result<HashSet<PeerId>, NetworkError> {
470        let (s, r) = futures::channel::oneshot::channel();
471        let req = ClientRequest::GetConnectedPeers(s);
472        self.send_request(req)?;
473        Ok(r.await.unwrap())
474    }
475
476    /// Get a reference to the network node handle's id.
477    #[must_use]
478    pub fn id(&self) -> usize {
479        self.id
480    }
481
482    /// Get a reference to the network node handle's peer id.
483    #[must_use]
484    pub fn peer_id(&self) -> PeerId {
485        self.peer_id
486    }
487
488    /// Return a reference to the network config
489    #[must_use]
490    pub fn config(&self) -> &NetworkNodeConfig {
491        &self.network_config
492    }
493}