1pub mod peer;
2pub mod server;
3
4use std::{
5 collections::HashMap,
6 fmt,
7 ops::{Deref, DerefMut},
8 sync::Arc,
9};
10
11use bon::Builder;
12use bytes::Bytes;
13use tokio::{
14 net::TcpListener,
15 sync::{
16 OwnedSemaphorePermit,
17 mpsc::{self, UnboundedReceiver, UnboundedSender},
18 oneshot, watch,
19 },
20 task::JoinHandle,
21};
22use tracing::{debug, info, warn};
23
24use crate::{
25 Config, Metrics, Role, addr::NetAddr, error::NetworkError, metrics::NoMetrics, msg::Slot,
26 net::server::Server, x25519::PublicKey,
27};
28
29type PeerMessage = (PublicKey, Bytes, Option<OwnedSemaphorePermit>);
30
31#[derive(Debug)]
32pub struct Network {
33 recv: NetworkReceiver,
34 ctrl: NetworkController,
35}
36
37#[derive(Debug)]
38pub struct NetworkReceiver {
39 rx: UnboundedReceiver<PeerMessage>,
40}
41
42pub struct NetworkController {
43 conf: Arc<Config>,
44 node: PublicKey,
45 parties: HashMap<PublicKey, Role>,
46 tx: UnboundedSender<Command>,
47 next_slot: watch::Sender<Slot>,
48 lower_bound: Slot,
49 task: JoinHandle<()>,
50 metrics: Arc<dyn Metrics>,
51}
52
53impl fmt::Debug for NetworkController {
54 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
55 f.debug_struct("NetworkController")
56 .field("node", &self.node)
57 .field("lower_bound", &self.lower_bound)
58 .field("conf", &self.conf)
59 .finish()
60 }
61}
62
63#[derive(Debug)]
65enum Command {
66 Peer(PeerCommand),
67 Send(SendCommand),
68 Shutdown(oneshot::Sender<()>),
69}
70
71#[derive(Debug)]
73enum PeerCommand {
74 Add(Role, Vec<(PublicKey, NetAddr)>),
76 Remove(Vec<PublicKey>),
78 Assign(Role, Vec<PublicKey>),
80}
81
82#[derive(Clone, Debug, Builder)]
84pub struct SendCommand {
85 slot: Slot,
86 action: SendAction,
87 #[builder(default)]
88 retry: RetryPolicy,
89}
90
/// Whether sending a message should be retried.
///
/// The concrete retry behaviour is presumably implemented by the server task.
#[derive(Clone, Copy, Debug, Default)]
pub enum RetryPolicy {
    /// Retry (the default, also used by the `SendCommand` builder).
    #[default]
    Default,
    /// Do not retry.
    NoRetry,
}
98
99impl RetryPolicy {
100 pub fn is_retry(self) -> bool {
101 matches!(self, Self::Default)
102 }
103}
104
105#[derive(Clone, Debug)]
106pub enum SendAction {
107 Unicast(PublicKey, Vec<u8>),
109 Multicast(Vec<PublicKey>, Vec<u8>),
111 Broadcast(Vec<u8>),
113}
114
115impl Network {
116 pub async fn create(conf: Config) -> Result<Self, NetworkError> {
117 let listener = TcpListener::bind(conf.bind.to_string())
118 .await
119 .map_err(|e| NetworkError::Bind(conf.bind.clone(), e))?;
120
121 let _addr = listener.local_addr()?;
122 let node = conf.keypair.public_key();
123 let parties = HashMap::from_iter(conf.parties.iter().map(|(k, _)| (*k, Role::Active)));
124
125 let (otx, orx) = mpsc::unbounded_channel();
127
128 let (itx, irx) = mpsc::unbounded_channel();
130
131 let (etx, erx) = watch::channel(Slot::MIN);
132
133 let metr = conf.metrics.clone().unwrap_or_else(|| Arc::new(NoMetrics));
134 let conf = Arc::new(conf);
135 let serv = Server::spawn(
136 conf.clone(),
137 listener,
138 Role::Active,
139 itx,
140 orx,
141 erx,
142 metr.clone(),
143 );
144 let recv = NetworkReceiver { rx: irx };
145 let ctrl = NetworkController {
146 conf: conf.clone(),
147 node,
148 parties,
149 tx: otx,
150 task: serv,
151 next_slot: etx,
152 lower_bound: Slot::MIN,
153 metrics: metr,
154 };
155
156 info!(name = %conf.name, %node, addr = %_addr, "listening");
157
158 Ok(Self { recv, ctrl })
159 }
160
161 pub fn controller(&mut self) -> &mut NetworkController {
162 &mut self.ctrl
163 }
164
165 pub fn receiver(&mut self) -> &mut NetworkReceiver {
166 &mut self.recv
167 }
168
169 pub fn split(&self) -> (&NetworkController, &NetworkReceiver) {
170 (&self.ctrl, &self.recv)
171 }
172
173 pub fn split_mut(&mut self) -> (&mut NetworkController, &mut NetworkReceiver) {
174 (&mut self.ctrl, &mut self.recv)
175 }
176
177 pub fn split_into(self) -> (NetworkController, NetworkReceiver) {
178 (self.ctrl, self.recv)
179 }
180
181 pub async fn receive(&mut self) -> Option<(PublicKey, Bytes)> {
182 self.recv.receive().await
183 }
184}
185
186impl Deref for Network {
187 type Target = NetworkController;
188
189 fn deref(&self) -> &Self::Target {
190 &self.ctrl
191 }
192}
193
194impl DerefMut for Network {
195 fn deref_mut(&mut self) -> &mut Self::Target {
196 &mut self.ctrl
197 }
198}
199
200impl NetworkReceiver {
201 pub async fn receive(&mut self) -> Option<(PublicKey, Bytes)> {
205 let (k, b, _) = self.rx.recv().await?;
206 debug!(peer = %k, len = b.len(), "message received");
207 Some((k, b))
208 }
209}
210
211impl NetworkController {
212 pub fn config(&self) -> &Config {
213 &self.conf
214 }
215
216 pub fn parties(&self) -> impl Iterator<Item = (&PublicKey, &Role)> {
218 self.parties.iter()
219 }
220
221 pub fn unicast(&mut self, s: Slot, to: PublicKey, msg: Vec<u8>) -> Result<(), NetworkError> {
223 debug!(slot = %s, %to, len = msg.len(), "unicast");
224 self.length_check(&msg)?;
225 if self.lt_lower_bound(s) {
226 return Ok(());
227 }
228 let cmd = SendCommand::builder()
229 .slot(s)
230 .action(SendAction::Unicast(to, msg))
231 .build();
232 self.tx
233 .send(Command::Send(cmd))
234 .map_err(|_| NetworkError::ChannelClosed)
235 }
236
237 pub fn broadcast(&mut self, s: Slot, msg: Vec<u8>) -> Result<(), NetworkError> {
239 debug!(slot = %s, len = msg.len(), "broadcast");
240 self.length_check(&msg)?;
241 if self.lt_lower_bound(s) {
242 return Ok(());
243 }
244 let cmd = SendCommand::builder()
245 .slot(s)
246 .action(SendAction::Broadcast(msg))
247 .build();
248 self.tx
249 .send(Command::Send(cmd))
250 .map_err(|_| NetworkError::ChannelClosed)
251 }
252
253 pub fn multicast<P>(&mut self, s: Slot, to: P, msg: Vec<u8>) -> Result<(), NetworkError>
255 where
256 P: IntoIterator<Item = PublicKey>,
257 {
258 debug!(slot = %s, len = msg.len(), "multicast");
259 self.length_check(&msg)?;
260 if self.lt_lower_bound(s) {
261 return Ok(());
262 }
263 let cmd = SendCommand::builder()
264 .slot(s)
265 .action(SendAction::Multicast(to.into_iter().collect(), msg))
266 .build();
267 self.tx
268 .send(Command::Send(cmd))
269 .map_err(|_| NetworkError::ChannelClosed)
270 }
271
272 pub fn send(&mut self, cmd: SendCommand) -> Result<(), NetworkError> {
274 let bytes = msg_bytes(&cmd);
275 debug!(slot = %cmd.slot, len = %bytes.len(), "send");
276 self.length_check(bytes)?;
277 if self.lt_lower_bound(cmd.slot) {
278 return Ok(());
279 }
280 self.tx
281 .send(Command::Send(cmd))
282 .map_err(|_| NetworkError::ChannelClosed)
283 }
284
285 pub fn add_peers<P>(&mut self, r: Role, peers: P) -> Result<(), NetworkError>
287 where
288 P: IntoIterator<Item = (PublicKey, NetAddr)>,
289 {
290 debug!(role = %r, "add_peers");
291 let peers = peers.into_iter().collect::<Vec<_>>();
292 self.parties.extend(peers.iter().map(|(p, ..)| (*p, r)));
293 self.tx
294 .send(Command::Peer(PeerCommand::Add(r, peers)))
295 .map_err(|_| NetworkError::ChannelClosed)
296 }
297
298 pub fn remove_peers<P>(&mut self, peers: P) -> Result<(), NetworkError>
300 where
301 P: IntoIterator<Item = PublicKey>,
302 {
303 debug!("remove_peers");
304 let peers = peers.into_iter().collect::<Vec<_>>();
305 for p in &peers {
306 self.parties.remove(p);
307 self.metrics.del(p);
308 }
309 self.tx
310 .send(Command::Peer(PeerCommand::Remove(peers)))
311 .map_err(|_| NetworkError::ChannelClosed)
312 }
313
314 pub fn assign_peers<P>(&mut self, r: Role, peers: P) -> Result<(), NetworkError>
316 where
317 P: IntoIterator<Item = PublicKey>,
318 {
319 debug!(role = %r, "assign_peers");
320 let peers = peers.into_iter().collect::<Vec<_>>();
321 for p in &peers {
322 if let Some(role) = self.parties.get_mut(p) {
323 *role = r
324 }
325 }
326 self.tx
327 .send(Command::Peer(PeerCommand::Assign(r, peers)))
328 .map_err(|_| NetworkError::ChannelClosed)
329 }
330
331 pub fn gc(&mut self, s: Slot) -> Result<(), NetworkError> {
332 debug!(slot = %s, "gc");
333 if s <= self.lower_bound {
334 return Ok(());
335 }
336 self.next_slot
337 .send(s)
338 .map_err(|_| NetworkError::ChannelClosed)?;
339 self.lower_bound = s;
340 Ok(())
341 }
342
343 pub fn shutdown(&mut self) -> Result<impl Future<Output = ()> + use<>, NetworkError> {
347 debug!("shutdown");
348 let (tx, rx) = oneshot::channel();
349 self.tx
350 .send(Command::Shutdown(tx))
351 .map_err(|_| NetworkError::ChannelClosed)?;
352 Ok(async move {
353 let _ = rx.await;
354 })
355 }
356
357 fn length_check(&self, msg: &[u8]) -> Result<(), NetworkError> {
359 if msg.len() > self.conf.max_message_size.get() {
360 warn!(
361 name = %self.conf.name,
362 node = %self.node,
363 len = %msg.len(),
364 max = %self.conf.max_message_size,
365 "message too large to send"
366 );
367 return Err(NetworkError::MessageTooLarge);
368 }
369 Ok(())
370 }
371
372 fn lt_lower_bound(&self, s: Slot) -> bool {
374 if s < self.lower_bound {
375 warn!(
376 name = %self.conf.name,
377 node = %self.node,
378 slot = %s,
379 lower_bound = %self.lower_bound,
380 "slot below lower bound"
381 );
382 return true;
383 }
384 false
385 }
386}
387
388fn msg_bytes(cmd: &SendCommand) -> &[u8] {
389 match &cmd.action {
390 SendAction::Unicast(_, b) => b,
391 SendAction::Multicast(_, b) => b,
392 SendAction::Broadcast(b) => b,
393 }
394}
395
396impl Drop for NetworkController {
397 fn drop(&mut self) {
398 self.task.abort();
399 }
400}