Skip to main content

cliquenet/
net.rs

1pub mod peer;
2pub mod server;
3
4use std::{
5    collections::HashMap,
6    fmt,
7    ops::{Deref, DerefMut},
8    sync::Arc,
9};
10
11use bon::Builder;
12use bytes::Bytes;
13use tokio::{
14    net::TcpListener,
15    sync::{
16        OwnedSemaphorePermit,
17        mpsc::{self, UnboundedReceiver, UnboundedSender},
18        oneshot, watch,
19    },
20    task::JoinHandle,
21};
22use tracing::{debug, info, warn};
23
24use crate::{
25    Config, Metrics, Role, addr::NetAddr, error::NetworkError, metrics::NoMetrics, msg::Slot,
26    net::server::Server, x25519::PublicKey,
27};
28
/// A message travelling from the network to the application.
///
/// The tuple holds the public key of the peer the message came from, the
/// message bytes, and an optional semaphore permit.
// NOTE(review): the permit presumably bounds the number of buffered messages
// and is released when dropped — confirm against the producing side.
type PeerMessage = (PublicKey, Bytes, Option<OwnedSemaphorePermit>);
30
/// Handle to a running network.
///
/// Combines a receiver for incoming messages with a controller that sends
/// messages and manages peers.
#[derive(Debug)]
pub struct Network {
    /// Receiving half: yields messages that arrived from peers.
    recv: NetworkReceiver,
    /// Controlling half: issues commands to the server task.
    ctrl: NetworkController,
}
36
/// Receiving half of a [`Network`].
#[derive(Debug)]
pub struct NetworkReceiver {
    /// Channel of messages from peers to the application.
    rx: UnboundedReceiver<PeerMessage>,
}
41
/// Controlling half of a [`Network`].
///
/// Sends commands to the server task and aborts that task when dropped.
pub struct NetworkController {
    /// Network configuration, shared with the server task.
    conf: Arc<Config>,
    /// Public key of this node.
    node: PublicKey,
    /// Known parties and the role currently recorded for each.
    parties: HashMap<PublicKey, Role>,
    /// Command channel from application to network.
    tx: UnboundedSender<Command>,
    /// Publishes the slot passed to the latest effective `gc` call.
    next_slot: watch::Sender<Slot>,
    /// Sends for slots below this bound are dropped with a warning.
    lower_bound: Slot,
    /// Handle of the spawned server task.
    task: JoinHandle<()>,
    /// Metrics sink; `NoMetrics` when the configuration supplies none.
    metrics: Arc<dyn Metrics>,
}
52
53impl fmt::Debug for NetworkController {
54    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
55        f.debug_struct("NetworkController")
56            .field("node", &self.node)
57            .field("lower_bound", &self.lower_bound)
58            .field("conf", &self.conf)
59            .finish()
60    }
61}
62
/// Server task instructions.
#[derive(Debug)]
enum Command {
    /// Change the set of peers or their roles.
    Peer(PeerCommand),
    /// Send a message to one or more peers.
    Send(SendCommand),
    /// Shut the server down. The matching receiver is awaited by the future
    /// returned from `NetworkController::shutdown`.
    Shutdown(oneshot::Sender<()>),
}
70
/// Update network peers.
#[derive(Debug)]
enum PeerCommand {
    /// Add the given peers (public key and address) with the given `Role`.
    Add(Role, Vec<(PublicKey, NetAddr)>),
    /// Remove the given peers.
    Remove(Vec<PublicKey>),
    /// Assign a `Role` to the given peers.
    Assign(Role, Vec<PublicKey>),
}
81
/// Send to peer(s).
///
/// Built with the generated builder, e.g.
/// `SendCommand::builder().slot(s).action(a).build()`; `retry` is optional.
#[derive(Clone, Debug, Builder)]
pub struct SendCommand {
    /// Slot the message belongs to; compared against the controller's lower bound.
    slot: Slot,
    /// The recipients and the message bytes.
    action: SendAction,
    /// Retry behaviour; falls back to `RetryPolicy::default()` when not set.
    #[builder(default)]
    retry: RetryPolicy,
}
90
/// Controls whether a message is sent again when no ACK arrives.
#[derive(Clone, Copy, Debug, Default)]
pub enum RetryPolicy {
    /// Retry the message; this is the default policy.
    #[default]
    Default,
    /// Never retry the message.
    NoRetry,
}

impl RetryPolicy {
    /// Returns `true` if this policy asks for retries.
    pub fn is_retry(self) -> bool {
        // Exhaustive on purpose: a new variant must decide explicitly.
        match self {
            Self::Default => true,
            Self::NoRetry => false,
        }
    }
}
104
/// The recipients of a message together with the message bytes.
#[derive(Clone, Debug)]
pub enum SendAction {
    /// Send a message to one peer.
    Unicast(PublicKey, Vec<u8>),
    /// Send a message to some peers.
    Multicast(Vec<PublicKey>, Vec<u8>),
    /// Send a message to all peers with `Role::Active`.
    Broadcast(Vec<u8>),
}
114
115impl Network {
116    pub async fn create(conf: Config) -> Result<Self, NetworkError> {
117        let listener = TcpListener::bind(conf.bind.to_string())
118            .await
119            .map_err(|e| NetworkError::Bind(conf.bind.clone(), e))?;
120
121        let _addr = listener.local_addr()?;
122        let node = conf.keypair.public_key();
123        let parties = HashMap::from_iter(conf.parties.iter().map(|(k, _)| (*k, Role::Active)));
124
125        // Command channel from application to network.
126        let (otx, orx) = mpsc::unbounded_channel();
127
128        // Channel of messages from peers to the application.
129        let (itx, irx) = mpsc::unbounded_channel();
130
131        let (etx, erx) = watch::channel(Slot::MIN);
132
133        let metr = conf.metrics.clone().unwrap_or_else(|| Arc::new(NoMetrics));
134        let conf = Arc::new(conf);
135        let serv = Server::spawn(
136            conf.clone(),
137            listener,
138            Role::Active,
139            itx,
140            orx,
141            erx,
142            metr.clone(),
143        );
144        let recv = NetworkReceiver { rx: irx };
145        let ctrl = NetworkController {
146            conf: conf.clone(),
147            node,
148            parties,
149            tx: otx,
150            task: serv,
151            next_slot: etx,
152            lower_bound: Slot::MIN,
153            metrics: metr,
154        };
155
156        info!(name = %conf.name, %node, addr = %_addr, "listening");
157
158        Ok(Self { recv, ctrl })
159    }
160
161    pub fn controller(&mut self) -> &mut NetworkController {
162        &mut self.ctrl
163    }
164
165    pub fn receiver(&mut self) -> &mut NetworkReceiver {
166        &mut self.recv
167    }
168
169    pub fn split(&self) -> (&NetworkController, &NetworkReceiver) {
170        (&self.ctrl, &self.recv)
171    }
172
173    pub fn split_mut(&mut self) -> (&mut NetworkController, &mut NetworkReceiver) {
174        (&mut self.ctrl, &mut self.recv)
175    }
176
177    pub fn split_into(self) -> (NetworkController, NetworkReceiver) {
178        (self.ctrl, self.recv)
179    }
180
181    pub async fn receive(&mut self) -> Option<(PublicKey, Bytes)> {
182        self.recv.receive().await
183    }
184}
185
/// Exposes the `NetworkController` methods directly on a [`Network`].
impl Deref for Network {
    type Target = NetworkController;

    /// Borrow the controlling half.
    fn deref(&self) -> &Self::Target {
        &self.ctrl
    }
}
193
/// Exposes the mutating `NetworkController` methods directly on a [`Network`].
impl DerefMut for Network {
    /// Mutably borrow the controlling half.
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.ctrl
    }
}
199
200impl NetworkReceiver {
201    /// Receive the next incoming message.
202    ///
203    /// The returned public key denotes the source where the message came from.
204    pub async fn receive(&mut self) -> Option<(PublicKey, Bytes)> {
205        let (k, b, _) = self.rx.recv().await?;
206        debug!(peer = %k, len = b.len(), "message received");
207        Some((k, b))
208    }
209}
210
211impl NetworkController {
212    pub fn config(&self) -> &Config {
213        &self.conf
214    }
215
216    /// Iterate over all parties.
217    pub fn parties(&self) -> impl Iterator<Item = (&PublicKey, &Role)> {
218        self.parties.iter()
219    }
220
221    /// Send a message to a party, identified by the given public key.
222    pub fn unicast(&mut self, s: Slot, to: PublicKey, msg: Vec<u8>) -> Result<(), NetworkError> {
223        debug!(slot = %s, %to, len = msg.len(), "unicast");
224        self.length_check(&msg)?;
225        if self.lt_lower_bound(s) {
226            return Ok(());
227        }
228        let cmd = SendCommand::builder()
229            .slot(s)
230            .action(SendAction::Unicast(to, msg))
231            .build();
232        self.tx
233            .send(Command::Send(cmd))
234            .map_err(|_| NetworkError::ChannelClosed)
235    }
236
237    /// Send a message to all parties.
238    pub fn broadcast(&mut self, s: Slot, msg: Vec<u8>) -> Result<(), NetworkError> {
239        debug!(slot = %s, len = msg.len(), "broadcast");
240        self.length_check(&msg)?;
241        if self.lt_lower_bound(s) {
242            return Ok(());
243        }
244        let cmd = SendCommand::builder()
245            .slot(s)
246            .action(SendAction::Broadcast(msg))
247            .build();
248        self.tx
249            .send(Command::Send(cmd))
250            .map_err(|_| NetworkError::ChannelClosed)
251    }
252
253    /// Send a message to several parties, identified by their public keys.
254    pub fn multicast<P>(&mut self, s: Slot, to: P, msg: Vec<u8>) -> Result<(), NetworkError>
255    where
256        P: IntoIterator<Item = PublicKey>,
257    {
258        debug!(slot = %s, len = msg.len(), "multicast");
259        self.length_check(&msg)?;
260        if self.lt_lower_bound(s) {
261            return Ok(());
262        }
263        let cmd = SendCommand::builder()
264            .slot(s)
265            .action(SendAction::Multicast(to.into_iter().collect(), msg))
266            .build();
267        self.tx
268            .send(Command::Send(cmd))
269            .map_err(|_| NetworkError::ChannelClosed)
270    }
271
272    /// General send operation, supporting custom retry policies.
273    pub fn send(&mut self, cmd: SendCommand) -> Result<(), NetworkError> {
274        let bytes = msg_bytes(&cmd);
275        debug!(slot = %cmd.slot, len = %bytes.len(), "send");
276        self.length_check(bytes)?;
277        if self.lt_lower_bound(cmd.slot) {
278            return Ok(());
279        }
280        self.tx
281            .send(Command::Send(cmd))
282            .map_err(|_| NetworkError::ChannelClosed)
283    }
284
285    /// Add the given peers to the network.
286    pub fn add_peers<P>(&mut self, r: Role, peers: P) -> Result<(), NetworkError>
287    where
288        P: IntoIterator<Item = (PublicKey, NetAddr)>,
289    {
290        debug!(role = %r, "add_peers");
291        let peers = peers.into_iter().collect::<Vec<_>>();
292        self.parties.extend(peers.iter().map(|(p, ..)| (*p, r)));
293        self.tx
294            .send(Command::Peer(PeerCommand::Add(r, peers)))
295            .map_err(|_| NetworkError::ChannelClosed)
296    }
297
298    /// Remove the given peers from the network.
299    pub fn remove_peers<P>(&mut self, peers: P) -> Result<(), NetworkError>
300    where
301        P: IntoIterator<Item = PublicKey>,
302    {
303        debug!("remove_peers");
304        let peers = peers.into_iter().collect::<Vec<_>>();
305        for p in &peers {
306            self.parties.remove(p);
307            self.metrics.del(p);
308        }
309        self.tx
310            .send(Command::Peer(PeerCommand::Remove(peers)))
311            .map_err(|_| NetworkError::ChannelClosed)
312    }
313
314    /// Assign the given role to the given peers.
315    pub fn assign_peers<P>(&mut self, r: Role, peers: P) -> Result<(), NetworkError>
316    where
317        P: IntoIterator<Item = PublicKey>,
318    {
319        debug!(role = %r, "assign_peers");
320        let peers = peers.into_iter().collect::<Vec<_>>();
321        for p in &peers {
322            if let Some(role) = self.parties.get_mut(p) {
323                *role = r
324            }
325        }
326        self.tx
327            .send(Command::Peer(PeerCommand::Assign(r, peers)))
328            .map_err(|_| NetworkError::ChannelClosed)
329    }
330
331    pub fn gc(&mut self, s: Slot) -> Result<(), NetworkError> {
332        debug!(slot = %s, "gc");
333        if s <= self.lower_bound {
334            return Ok(());
335        }
336        self.next_slot
337            .send(s)
338            .map_err(|_| NetworkError::ChannelClosed)?;
339        self.lower_bound = s;
340        Ok(())
341    }
342
343    /// Trigger network shutdown.
344    ///
345    /// The returned future will resolve once the server task finished.
346    pub fn shutdown(&mut self) -> Result<impl Future<Output = ()> + use<>, NetworkError> {
347        debug!("shutdown");
348        let (tx, rx) = oneshot::channel();
349        self.tx
350            .send(Command::Shutdown(tx))
351            .map_err(|_| NetworkError::ChannelClosed)?;
352        Ok(async move {
353            let _ = rx.await;
354        })
355    }
356
357    /// Check the number of message bytes does not exceed the configured maximum.
358    fn length_check(&self, msg: &[u8]) -> Result<(), NetworkError> {
359        if msg.len() > self.conf.max_message_size.get() {
360            warn!(
361                name = %self.conf.name,
362                node = %self.node,
363                len  = %msg.len(),
364                max  = %self.conf.max_message_size,
365                "message too large to send"
366            );
367            return Err(NetworkError::MessageTooLarge);
368        }
369        Ok(())
370    }
371
372    /// Check if the given slot is less than our lower bound.
373    fn lt_lower_bound(&self, s: Slot) -> bool {
374        if s < self.lower_bound {
375            warn!(
376                name = %self.conf.name,
377                node = %self.node,
378                slot = %s,
379                lower_bound = %self.lower_bound,
380                "slot below lower bound"
381            );
382            return true;
383        }
384        false
385    }
386}
387
388fn msg_bytes(cmd: &SendCommand) -> &[u8] {
389    match &cmd.action {
390        SendAction::Unicast(_, b) => b,
391        SendAction::Multicast(_, b) => b,
392        SendAction::Broadcast(b) => b,
393    }
394}
395
impl Drop for NetworkController {
    /// Abort the server task when the controller is dropped.
    fn drop(&mut self) {
        self.task.abort();
    }
}