Skip to main content

cliquenet/
connection.rs

1use std::{
2    cmp::min,
3    io,
4    iter::{once, repeat},
5    net::SocketAddr,
6    sync::Arc,
7    time::Duration,
8};
9
10use rand::RngExt;
11use snow::{Builder, HandshakeState, TransportState, params::NoiseParams};
12use tokio::{
13    io::{AsyncReadExt, AsyncWriteExt},
14    net::TcpStream,
15    time::sleep,
16    try_join,
17};
18use tracing::{debug, warn};
19
20use crate::{
21    Config, Version,
22    addr::NetAddr,
23    error::NetworkError,
24    msg::{Header, MAX_NOISE_MESSAGE_SIZE, hello::Hello},
25    util::until,
26    x25519::PublicKey,
27};
28
/// Capacity, in bytes, of the stack buffers that hold Noise handshake
/// frames in `handshake` and `on_handshake` (1 KiB).
const MAX_NOISE_HANDSHAKE_SIZE: usize = 1 << 10;
30
/// Result type used throughout this module, with [`NetworkError`] as the
/// error type.
type Result<T> = std::result::Result<T, NetworkError>;

/// An established, Noise-encrypted connection to a remote peer.
///
/// Created by [`Connection::accept`] (responder side) or
/// [`Connection::connect`] (initiator side) once version negotiation and the
/// Noise handshake have completed.
pub struct Connection {
    /// Static public key of the remote party, taken from the completed
    /// Noise handshake.
    pub key: PublicKey,
    /// Address of the remote end of `stream`.
    pub addr: SocketAddr,
    /// TCP stream carrying frames (a `Header::SIZE`-byte header followed by
    /// the payload).
    pub stream: TcpStream,
    /// Noise transport state used to encrypt and decrypt frame payloads.
    pub state: TransportState,
}

/// Bytes fed into the Noise handshake as its prologue; both sides must
/// derive identical bytes for the handshake to succeed.
type Prologue = Vec<u8>;
41
42impl Connection {
43    pub async fn accept(conf: Arc<Config>, mut stream: TcpStream) -> Result<Self> {
44        let node = conf.keypair.public_key();
45
46        if let Err(err) = stream.set_nodelay(true) {
47            warn!(name = %conf.name, %node, %err, "failed to enable NO_DELAY option")
48        }
49
50        until(conf.handshake_timeout, async move {
51            let (version, prologue) = select_version(&conf, &mut stream, false).await?;
52
53            debug!(name = %conf.name, %node, %version, "negotiated version");
54
55            let noise_proto = conf
56                .noise_protocols
57                .get(&version)
58                .expect("selected version has noise config");
59
60            let hs = Builder::new(noise_proto.noise_params())
61                .local_private_key(&conf.keypair.secret_key().as_bytes())
62                .expect("valid private key")
63                .prologue(&prologue)
64                .expect("1st time we set the prologue")
65                .build_responder()
66                .expect("valid noise params yield valid handshake state");
67
68            let addr = stream.peer_addr()?;
69            let state = on_handshake(&mut stream, hs).await?;
70
71            if let Some(key) = remote_static_key(&state) {
72                Ok(Self {
73                    key,
74                    addr,
75                    stream,
76                    state,
77                })
78            } else {
79                warn!(name = %conf.name, %node, %addr, "invalid static key");
80                Err(NetworkError::InvalidHandshakeMessage)
81            }
82        })
83        .await
84    }
85
86    pub async fn connect(conf: Arc<Config>, peer: PublicKey, addr: NetAddr) -> Self {
87        let mut delays = once({
88            if conf.random_connect_delay {
89                Duration::from_millis(rand::rng().random_range(0..1000))
90            } else {
91                Duration::ZERO
92            }
93        })
94        .chain(
95            conf.connect_retry_delays
96                .iter()
97                .map(|&d| Duration::from_secs(d.into())),
98        )
99        .chain(repeat({
100            let d = *conf.connect_retry_delays.last();
101            Duration::from_secs(d.into())
102        }));
103
104        let addr = addr.to_string();
105        let node = conf.keypair.public_key();
106
107        let mut backoff = None;
108
109        loop {
110            if let Some(d) = backoff.take() {
111                sleep(d).await;
112            } else {
113                sleep(delays.next().expect("delays iterator is infinite")).await;
114            }
115
116            debug!(name = %conf.name, %node, %peer, %addr, "connecting");
117
118            match try_connect(&conf, &peer, &addr).await {
119                Ok(mut conn) => {
120                    let hello_exchange = until(conf.handshake_timeout, async {
121                        conn.send_hello(Hello::Ok).await?;
122                        conn.recv_hello().await
123                    });
124                    match hello_exchange.await {
125                        Ok(h) if h.is_ok() => break conn,
126                        Ok(h) => {
127                            warn!(
128                                name = %conf.name,
129                                %node,
130                                %peer,
131                                remote = %conn.key,
132                                %addr,
133                                "hello response was not ok"
134                            );
135                            backoff = h.backoff_duration();
136                        },
137                        Err(err) => {
138                            warn!(
139                                name = %conf.name,
140                                %node,
141                                %peer,
142                                remote = %conn.key,
143                                %addr,
144                                %err,
145                                "failed to exchange hello"
146                            );
147                        },
148                    }
149                },
150                Err(err) => {
151                    warn!(name = %conf.name, %node, %peer, %addr, %err, "connect/handshake error")
152                },
153            }
154        }
155    }
156
157    /// Send a `Hello` frame.
158    pub async fn send_hello(&mut self, h: Hello) -> Result<()> {
159        let mut b = [0u8; 64];
160        let n = self
161            .state
162            .write_message(h.to_bytes().as_ref(), &mut b[Header::SIZE..])?;
163        let h = Header::data(n as u16);
164        send_frame(&mut self.stream, h, &mut b[..Header::SIZE + n]).await?;
165        Ok(())
166    }
167
168    /// Read a `Hello` frame.
169    pub async fn recv_hello(&mut self) -> Result<Hello> {
170        let mut a = [0u8; 64];
171        let h = recv_frame(&mut self.stream, &mut a).await?;
172        let mut b = [0u8; 64];
173        let n = self.state.read_message(&a[..h.len().into()], &mut b)?;
174        let h = Hello::from_bytes(&b[..n]).ok_or(NetworkError::InvalidHello)?;
175        Ok(h)
176    }
177}
178
179async fn try_connect(conf: &Config, peer: &PublicKey, addr: &str) -> Result<Connection> {
180    let new_handshake_state = |prologue: &Prologue, params: NoiseParams| {
181        Builder::new(params)
182            .local_private_key(conf.keypair.secret_key().as_slice())
183            .expect("valid private key")
184            .remote_public_key(peer.as_slice())
185            .expect("valid remote pub key")
186            .prologue(prologue)
187            .expect("1st time we set the prologue")
188            .build_initiator()
189            .expect("valid noise params yield valid handshake state")
190    };
191
192    let mut stream = until(conf.connect_timeout, TcpStream::connect(addr)).await?;
193
194    let node = conf.keypair.public_key();
195    let addr = stream.peer_addr()?;
196
197    debug!(name = %conf.name, %node, %peer, %addr, "tcp connection established");
198
199    if let Err(err) = stream.set_nodelay(true) {
200        warn!(name = %conf.name, %node, %err, "failed to enable NO_DELAY option");
201    }
202
203    until(conf.handshake_timeout, async move {
204        let (version, prologue) = select_version(conf, &mut stream, true).await?;
205
206        debug!(name = %conf.name, %node, %peer, %addr, %version, "negotiated version");
207
208        let noise_proto = conf
209            .noise_protocols
210            .get(&version)
211            .expect("selected version has noise config");
212
213        let hshake = new_handshake_state(&prologue, noise_proto.noise_params());
214        let state = handshake(&mut stream, hshake).await?;
215        match remote_static_key(&state) {
216            Some(key) if key == *peer => Ok(Connection {
217                key,
218                addr,
219                stream,
220                state,
221            }),
222            Some(key) => {
223                warn!(name = %conf.name, %node, %peer, remote = %key, %addr, "static key mismatch");
224                Err(NetworkError::InvalidHandshakeMessage)
225            },
226            None => {
227                warn!(name = %conf.name, %node, %peer, %addr, "invalid static key");
228                Err(NetworkError::InvalidHandshakeMessage)
229            },
230        }
231    })
232    .await
233}
234
235fn remote_static_key(state: &TransportState) -> Option<PublicKey> {
236    let k = state.get_remote_static()?;
237    PublicKey::try_from(k).ok()
238}
239
240/// Select a version from the range that both sides support.
241///
242/// This will be the minimum of the max. supported ones from both sides.
243async fn select_version(
244    conf: &Config,
245    stream: &mut TcpStream,
246    is_initiator: bool,
247) -> Result<(Version, Prologue)> {
248    const INIT_PAYLOAD_LEN: usize = 4;
249
250    let our_min = *conf.noise_protocols.first().0;
251    let our_max = *conf.noise_protocols.last().0;
252
253    let mut send_buf = [0u8; INIT_PAYLOAD_LEN];
254    let mut recv_buf = [0u8; INIT_PAYLOAD_LEN];
255
256    send_buf[0..2].copy_from_slice(&u16::from(our_min).to_be_bytes());
257    send_buf[2..4].copy_from_slice(&u16::from(our_max).to_be_bytes());
258
259    let (mut r, mut w) = stream.split();
260    try_join!(w.write_all(&send_buf), r.read_exact(&mut recv_buf))?;
261
262    let their_min = Version::from(u16::from_be_bytes([recv_buf[0], recv_buf[1]]));
263    let their_max = Version::from(u16::from_be_bytes([recv_buf[2], recv_buf[3]]));
264
265    let selected = min(our_max, their_max);
266
267    if selected < their_min || selected < our_min {
268        return Err(NetworkError::IncompatibleVersions {
269            ours: (our_min, our_max),
270            theirs: (their_min, their_max),
271        });
272    }
273
274    // Construct the prologue so that both sides end up with the same value.
275    // We include the sent and received version ranges to ensure no one has
276    // tampered with those values as they were sent in plain text.
277    let mut prologue = Vec::new();
278    prologue.extend_from_slice(conf.name.as_bytes());
279    if is_initiator {
280        prologue.extend_from_slice(&send_buf);
281        prologue.extend_from_slice(&recv_buf);
282    } else {
283        prologue.extend_from_slice(&recv_buf);
284        prologue.extend_from_slice(&send_buf);
285    }
286
287    Ok((selected, prologue))
288}
289
290/// Perform a noise handshake as initiator with the remote party.
291async fn handshake(stream: &mut TcpStream, mut hs: HandshakeState) -> Result<TransportState> {
292    let mut a = [0u8; MAX_NOISE_HANDSHAKE_SIZE];
293    let n = hs.write_message(&[], &mut a[Header::SIZE..])?;
294    let h = Header::data(n as u16);
295    send_frame(stream, h, &mut a[..Header::SIZE + n]).await?;
296    let mut b = [0u8; MAX_NOISE_HANDSHAKE_SIZE];
297    let h = recv_frame(stream, &mut b).await?;
298    if !h.is_data() || h.is_partial() {
299        return Err(NetworkError::InvalidHandshakeMessage);
300    }
301    hs.read_message(&b[..h.len().into()], &mut a)?;
302    Ok(hs.into_transport_mode()?)
303}
304
305/// Perform a noise handshake as responder with a remote party.
306async fn on_handshake(stream: &mut TcpStream, mut hs: HandshakeState) -> Result<TransportState> {
307    let mut a = [0u8; MAX_NOISE_HANDSHAKE_SIZE];
308    let h = recv_frame(stream, &mut a).await?;
309    if !h.is_data() || h.is_partial() {
310        return Err(NetworkError::InvalidHandshakeMessage);
311    }
312    let mut b = [0u8; MAX_NOISE_HANDSHAKE_SIZE];
313    hs.read_message(&a[..h.len().into()], &mut b)?;
314    let n = hs.write_message(&[], &mut b[Header::SIZE..])?;
315    let h = Header::data(n as u16);
316    send_frame(stream, h, &mut b[..Header::SIZE + n]).await?;
317    Ok(hs.into_transport_mode()?)
318}
319
320/// Read a single frame (header + payload) from the remote.
321async fn recv_frame<R, const N: usize>(stream: &mut R, buf: &mut [u8; N]) -> io::Result<Header>
322where
323    R: AsyncReadExt + Unpin,
324{
325    let h = {
326        let n = stream.read_u32().await?;
327        Header::unvalidated(n)
328    };
329    let n = h.len().into();
330    if n > N {
331        return Err(io::ErrorKind::InvalidInput.into());
332    }
333    stream.read_exact(&mut buf[..n]).await?;
334    Ok(h)
335}
336
337/// Write a single frame (header + payload) to the remote.
338///
339/// The header is serialised into the first 4 bytes of `msg`. It is the
340/// caller's responsibility to ensure there is room at the beginning.
341async fn send_frame<W>(stream: &mut W, hdr: Header, msg: &mut [u8]) -> io::Result<()>
342where
343    W: AsyncWriteExt + Unpin,
344{
345    debug_assert!(msg.len() <= MAX_NOISE_MESSAGE_SIZE);
346    msg[..Header::SIZE].copy_from_slice(&hdr.to_bytes());
347    stream.write_all(msg).await?;
348    Ok(())
349}
350
#[cfg(test)]
mod tests {
    use std::net::Ipv4Addr;

    use tokio::net::{TcpListener, TcpStream};

    use super::{Prologue, Result, select_version};
    use crate::{Config, NetAddr, NetworkError, Version, noise::Protocol, x25519::Keypair};

    /// Outcome of one side's `select_version` call.
    type Outcome = Result<(Version, Prologue)>;

    /// Build a test `Config` offering exactly `versions`, all mapped to the
    /// same Noise protocol.
    fn config<I, V>(versions: I) -> Config
    where
        I: IntoIterator<Item = V>,
        V: Into<Version>,
    {
        Config::builder()
            .name("test")
            .keypair(Keypair::generate().unwrap())
            .bind(NetAddr::from((Ipv4Addr::LOCALHOST, 0u16)))
            .parties([])
            .noise_protocols(
                versions
                    .into_iter()
                    .map(|v| (v.into(), Protocol::IK_25519_AesGcm_Blake2s)),
            )
            .build()
    }

    /// Run `select_version` over a loopback TCP connection with `a` as the
    /// initiator and `b` as the responder, returning both outcomes.
    async fn negotiate(a: &Config, b: &Config) -> (Outcome, Outcome) {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let port = listener.local_addr().unwrap().port();

        let initiator = async {
            let mut s = TcpStream::connect(("127.0.0.1", port)).await.unwrap();
            select_version(a, &mut s, true).await
        };
        let responder = async {
            let (mut s, _) = listener.accept().await.unwrap();
            select_version(b, &mut s, false).await
        };

        tokio::join!(initiator, responder)
    }

    /// Assert that both sides selected `expected` and derived equal prologues.
    fn assert_agreed((ra, rb): (Outcome, Outcome), expected: Version) {
        let (va, pa) = ra.unwrap();
        let (vb, pb) = rb.unwrap();
        assert_eq!(va, expected);
        assert_eq!(vb, expected);
        assert_eq!(pa, pb);
    }

    #[tokio::test]
    async fn picks_min_of_maxes() {
        let outcomes = negotiate(&config([1, 2, 3]), &config([1, 2])).await;
        assert_agreed(outcomes, 2.into());
    }

    #[tokio::test]
    async fn higher_overlap_takes_higher() {
        let outcomes = negotiate(&config([2, 3, 4]), &config([1, 2, 3])).await;
        assert_agreed(outcomes, 3.into());
    }

    #[tokio::test]
    async fn single_version_match() {
        let outcomes = negotiate(&config([1]), &config([1])).await;
        assert_agreed(outcomes, 1.into());
    }

    #[tokio::test]
    async fn disjoint_ranges_fail() {
        let (ra, rb) = negotiate(&config([1]), &config([2])).await;
        for outcome in [ra, rb] {
            assert!(matches!(outcome, Err(NetworkError::IncompatibleVersions { .. })));
        }
    }
}
428        let (ra, rb) = negotiate(&config([1]), &config([2])).await;
429        assert!(matches!(ra, Err(NetworkError::IncompatibleVersions { .. })));
430        assert!(matches!(rb, Err(NetworkError::IncompatibleVersions { .. })));
431    }
432}