hotshot_libp2p_networking/network/node/
handle.rs1use std::{collections::HashSet, fmt::Debug, sync::Arc, time::Duration};
8
9use bimap::BiMap;
10use hotshot_types::traits::{
11 network::NetworkError, node_implementation::NodeType, signature_key::SignatureKey,
12};
13use libp2p::{Multiaddr, request_response::ResponseChannel};
14use libp2p_identity::PeerId;
15use parking_lot::Mutex;
16use tokio::{
17 sync::mpsc::{Receiver, UnboundedReceiver, UnboundedSender},
18 time::{sleep, timeout},
19};
20use tracing::{debug, info, instrument};
21
22use crate::network::{
23 ClientRequest, NetworkEvent, NetworkNode, NetworkNodeConfig,
24 behaviours::dht::{
25 record::{Namespace, RecordKey, RecordValue},
26 store::persistent::DhtPersistentStorage,
27 },
28 gen_multiaddr, log_summary,
29};
30
31#[derive(Debug, Clone)]
35pub struct NetworkNodeHandle<T: NodeType> {
36 network_config: NetworkNodeConfig,
38
39 send_network: UnboundedSender<ClientRequest>,
41
42 consensus_key_to_pid_map: Arc<Mutex<BiMap<T::SignatureKey, PeerId>>>,
44
45 listen_addr: Multiaddr,
47
48 peer_id: PeerId,
50
51 id: usize,
53}
54
55#[derive(Debug)]
57pub struct NetworkNodeReceiver {
58 receiver: UnboundedReceiver<NetworkEvent>,
60
61 recv_kill: Option<Receiver<()>>,
63}
64
65impl NetworkNodeReceiver {
66 pub async fn recv(&mut self) -> Result<NetworkEvent, NetworkError> {
70 self.receiver
71 .recv()
72 .await
73 .ok_or(NetworkError::ChannelReceiveError(
74 "Receiver channel closed".to_string(),
75 ))
76 }
77 pub fn set_kill_switch(&mut self, kill_switch: Receiver<()>) {
79 self.recv_kill = Some(kill_switch);
80 }
81
82 pub fn take_kill_switch(&mut self) -> Option<Receiver<()>> {
84 self.recv_kill.take()
85 }
86}
87
88pub async fn spawn_network_node<T: NodeType, D: DhtPersistentStorage>(
92 config: NetworkNodeConfig,
93 dht_persistent_storage: D,
94 consensus_key_to_pid_map: Arc<Mutex<BiMap<T::SignatureKey, PeerId>>>,
95 id: usize,
96) -> Result<(NetworkNodeReceiver, NetworkNodeHandle<T>), NetworkError> {
97 let mut network: NetworkNode<T, _> = NetworkNode::new(
98 config.clone(),
99 dht_persistent_storage,
100 Arc::clone(&consensus_key_to_pid_map),
101 )
102 .await
103 .map_err(|e| NetworkError::ConfigError(format!("failed to create network node: {e}")))?;
104 let listen_addr = config
106 .bind_address
107 .clone()
108 .unwrap_or_else(|| gen_multiaddr(0));
109 let peer_id = network.peer_id();
110 let listen_addr = network.start_listen(listen_addr).await.map_err(|e| {
111 NetworkError::ListenError(format!("failed to start listening on Libp2p: {e}"))
112 })?;
113 let (send_chan, recv_chan) = network.spawn_listeners().map_err(|err| {
116 NetworkError::ListenError(format!("failed to spawn listeners for Libp2p: {err}"))
117 })?;
118 log_summary::spawn_summary_task();
119 let receiver = NetworkNodeReceiver {
120 receiver: recv_chan,
121 recv_kill: None,
122 };
123
124 let handle = NetworkNodeHandle::<T> {
125 network_config: config,
126 send_network: send_chan,
127 consensus_key_to_pid_map,
128 listen_addr,
129 peer_id,
130 id,
131 };
132 Ok((receiver, handle))
133}
134
135impl<T: NodeType> NetworkNodeHandle<T> {
136 #[instrument]
140 pub async fn shutdown(&self) -> Result<(), NetworkError> {
141 self.send_request(ClientRequest::Shutdown)?;
142 Ok(())
143 }
144 pub fn begin_bootstrap(&self) -> Result<(), NetworkError> {
149 let req = ClientRequest::BeginBootstrap;
150 self.send_request(req)
151 }
152
153 #[must_use]
155 pub fn listen_addr(&self) -> Multiaddr {
156 self.listen_addr.clone()
157 }
158
159 pub async fn print_routing_table(&self) -> Result<(), NetworkError> {
164 let (s, r) = futures::channel::oneshot::channel();
165 let req = ClientRequest::GetRoutingTable(s);
166 self.send_request(req)?;
167 r.await
168 .map_err(|e| NetworkError::ChannelReceiveError(e.to_string()))
169 }
170 pub async fn wait_to_connect(
175 &self,
176 num_required_peers: usize,
177 node_id: usize,
178 ) -> Result<(), NetworkError> {
179 loop {
181 let num_connected = self.num_connected().await?;
183 if num_connected >= num_required_peers {
184 break;
185 }
186
187 info!(
189 "Node {} connected to {}/{} peers",
190 node_id, num_connected, num_required_peers
191 );
192
193 sleep(Duration::from_secs(1)).await;
195 }
196
197 Ok(())
198 }
199
200 pub async fn lookup_pid(&self, peer_id: PeerId) -> Result<(), NetworkError> {
205 let (s, r) = futures::channel::oneshot::channel();
206 let req = ClientRequest::LookupPeer(peer_id, s);
207 self.send_request(req)?;
208 r.await
209 .map_err(|err| NetworkError::ChannelReceiveError(err.to_string()))
210 }
211
212 pub async fn lookup_node(
217 &self,
218 consensus_key: &T::SignatureKey,
219 dht_timeout: Duration,
220 ) -> Result<PeerId, NetworkError> {
221 if let Some(pid) = self
223 .consensus_key_to_pid_map
224 .lock()
225 .get_by_left(consensus_key)
226 {
227 return Ok(*pid);
228 }
229
230 let key = RecordKey::new(Namespace::Lookup, consensus_key.to_bytes());
232
233 let pid = self.get_record_timeout(key, dht_timeout).await?;
235
236 PeerId::from_bytes(&pid).map_err(|err| NetworkError::FailedToDeserialize(err.to_string()))
237 }
238
239 pub async fn put_record(
243 &self,
244 key: RecordKey,
245 value: RecordValue<T::SignatureKey>,
246 ) -> Result<(), NetworkError> {
247 let key = key.to_bytes();
249
250 let value = bincode::serialize(&value)
252 .map_err(|e| NetworkError::FailedToSerialize(e.to_string()))?;
253
254 let (s, r) = futures::channel::oneshot::channel();
255 let req = ClientRequest::PutDHT {
256 key: key.clone(),
257 value,
258 notify: s,
259 };
260
261 self.send_request(req)?;
262
263 r.await.map_err(|_| NetworkError::RequestCancelled)
264 }
265
266 pub async fn get_record(
272 &self,
273 key: RecordKey,
274 retry_count: u8,
275 ) -> Result<Vec<u8>, NetworkError> {
276 let serialized_key = key.to_bytes();
278
279 let (s, r) = futures::channel::oneshot::channel();
280 let req = ClientRequest::GetDHT {
281 key: serialized_key.clone(),
282 notify: vec![s],
283 retry_count,
284 };
285 self.send_request(req)?;
286
287 let result = r.await.map_err(|_| NetworkError::RequestCancelled)?;
289
290 let record: RecordValue<T::SignatureKey> = bincode::deserialize(&result)
292 .map_err(|e| NetworkError::FailedToDeserialize(e.to_string()))?;
293
294 Ok(record.value().to_vec())
295 }
296
297 pub async fn get_record_timeout(
303 &self,
304 key: RecordKey,
305 timeout_duration: Duration,
306 ) -> Result<Vec<u8>, NetworkError> {
307 timeout(timeout_duration, self.get_record(key, 3))
308 .await
309 .map_err(|err| NetworkError::Timeout(err.to_string()))?
310 }
311
312 pub async fn put_record_timeout(
318 &self,
319 key: RecordKey,
320 value: RecordValue<T::SignatureKey>,
321 timeout_duration: Duration,
322 ) -> Result<(), NetworkError> {
323 timeout(timeout_duration, self.put_record(key, value))
324 .await
325 .map_err(|err| NetworkError::Timeout(err.to_string()))?
326 }
327
328 pub async fn subscribe(&self, topic: String) -> Result<(), NetworkError> {
332 let (s, r) = futures::channel::oneshot::channel();
333 let req = ClientRequest::Subscribe(topic, Some(s));
334 self.send_request(req)?;
335 r.await
336 .map_err(|err| NetworkError::ChannelReceiveError(err.to_string()))
337 }
338
339 pub async fn unsubscribe(&self, topic: String) -> Result<(), NetworkError> {
343 let (s, r) = futures::channel::oneshot::channel();
344 let req = ClientRequest::Unsubscribe(topic, Some(s));
345 self.send_request(req)?;
346 r.await
347 .map_err(|err| NetworkError::ChannelReceiveError(err.to_string()))
348 }
349
350 pub fn ignore_peers(&self, peers: Vec<PeerId>) -> Result<(), NetworkError> {
355 let req = ClientRequest::IgnorePeers(peers);
356 self.send_request(req)
357 }
358
359 pub fn direct_request(&self, pid: PeerId, msg: &[u8]) -> Result<(), NetworkError> {
364 self.direct_request_no_serialize(pid, msg.to_vec())
365 }
366
367 pub fn direct_request_no_serialize(
372 &self,
373 pid: PeerId,
374 contents: Vec<u8>,
375 ) -> Result<(), NetworkError> {
376 let req = ClientRequest::DirectRequest {
377 pid,
378 contents,
379 retry_count: 1,
380 };
381 self.send_request(req)
382 }
383
384 pub fn direct_response(
389 &self,
390 chan: ResponseChannel<Vec<u8>>,
391 msg: &[u8],
392 ) -> Result<(), NetworkError> {
393 let req = ClientRequest::DirectResponse(chan, msg.to_vec());
394 self.send_request(req)
395 }
396
397 pub fn prune_peer(&self, pid: PeerId) -> Result<(), NetworkError> {
405 let req = ClientRequest::Prune(pid);
406 self.send_request(req)
407 }
408
409 pub fn gossip(&self, topic: String, msg: &[u8]) -> Result<(), NetworkError> {
414 self.gossip_no_serialize(topic, msg.to_vec())
415 }
416
417 pub fn gossip_no_serialize(&self, topic: String, msg: Vec<u8>) -> Result<(), NetworkError> {
422 let req = ClientRequest::GossipMsg(topic, msg);
423 self.send_request(req)
424 }
425
426 pub fn add_known_peers(
430 &self,
431 known_peers: Vec<(PeerId, Multiaddr)>,
432 ) -> Result<(), NetworkError> {
433 debug!("Adding {} known peers", known_peers.len());
434 let req = ClientRequest::AddKnownPeers(known_peers);
435 self.send_request(req)
436 }
437
438 fn send_request(&self, req: ClientRequest) -> Result<(), NetworkError> {
443 self.send_network
444 .send(req)
445 .map_err(|err| NetworkError::ChannelSendError(err.to_string()))
446 }
447
448 pub async fn num_connected(&self) -> Result<usize, NetworkError> {
456 let (s, r) = futures::channel::oneshot::channel();
457 let req = ClientRequest::GetConnectedPeerNum(s);
458 self.send_request(req)?;
459 Ok(r.await.unwrap())
460 }
461
462 pub async fn connected_pids(&self) -> Result<HashSet<PeerId>, NetworkError> {
470 let (s, r) = futures::channel::oneshot::channel();
471 let req = ClientRequest::GetConnectedPeers(s);
472 self.send_request(req)?;
473 Ok(r.await.unwrap())
474 }
475
476 #[must_use]
478 pub fn id(&self) -> usize {
479 self.id
480 }
481
482 #[must_use]
484 pub fn peer_id(&self) -> PeerId {
485 self.peer_id
486 }
487
488 #[must_use]
490 pub fn config(&self) -> &NetworkNodeConfig {
491 &self.network_config
492 }
493}