1use std::{
2 cmp::min,
3 io,
4 iter::{once, repeat},
5 net::SocketAddr,
6 sync::Arc,
7 time::Duration,
8};
9
10use rand::RngExt;
11use snow::{Builder, HandshakeState, TransportState, params::NoiseParams};
12use tokio::{
13 io::{AsyncReadExt, AsyncWriteExt},
14 net::TcpStream,
15 time::sleep,
16 try_join,
17};
18use tracing::{debug, warn};
19
20use crate::{
21 Config, Version,
22 addr::NetAddr,
23 error::NetworkError,
24 msg::{Header, MAX_NOISE_MESSAGE_SIZE, hello::Hello},
25 util::until,
26 x25519::PublicKey,
27};
28
/// Size in bytes of the stack buffers used while exchanging Noise handshake
/// messages: an outgoing frame (header + handshake message) and an incoming
/// frame body must each fit in this many bytes.
const MAX_NOISE_HANDSHAKE_SIZE: usize = 1024;

/// Result type used throughout this module; all errors are `NetworkError`s.
type Result<T> = std::result::Result<T, NetworkError>;
32
/// A TCP connection to a peer on which the Noise handshake has completed.
pub struct Connection {
    /// The remote party's static public key, as reported by the completed
    /// Noise handshake.
    pub key: PublicKey,
    /// The remote socket address (obtained via `TcpStream::peer_addr`).
    pub addr: SocketAddr,
    /// The underlying TCP stream.
    pub stream: TcpStream,
    /// Noise transport state produced by the handshake; used to encrypt and
    /// decrypt the messages exchanged after it (e.g. hellos).
    pub state: TransportState,
}

/// Bytes fed into the Noise handshake as its prologue (built by `select_version`).
type Prologue = Vec<u8>;
41
42impl Connection {
43 pub async fn accept(conf: Arc<Config>, mut stream: TcpStream) -> Result<Self> {
44 let node = conf.keypair.public_key();
45
46 if let Err(err) = stream.set_nodelay(true) {
47 warn!(name = %conf.name, %node, %err, "failed to enable NO_DELAY option")
48 }
49
50 until(conf.handshake_timeout, async move {
51 let (version, prologue) = select_version(&conf, &mut stream, false).await?;
52
53 debug!(name = %conf.name, %node, %version, "negotiated version");
54
55 let noise_proto = conf
56 .noise_protocols
57 .get(&version)
58 .expect("selected version has noise config");
59
60 let hs = Builder::new(noise_proto.noise_params())
61 .local_private_key(&conf.keypair.secret_key().as_bytes())
62 .expect("valid private key")
63 .prologue(&prologue)
64 .expect("1st time we set the prologue")
65 .build_responder()
66 .expect("valid noise params yield valid handshake state");
67
68 let addr = stream.peer_addr()?;
69 let state = on_handshake(&mut stream, hs).await?;
70
71 if let Some(key) = remote_static_key(&state) {
72 Ok(Self {
73 key,
74 addr,
75 stream,
76 state,
77 })
78 } else {
79 warn!(name = %conf.name, %node, %addr, "invalid static key");
80 Err(NetworkError::InvalidHandshakeMessage)
81 }
82 })
83 .await
84 }
85
86 pub async fn connect(conf: Arc<Config>, peer: PublicKey, addr: NetAddr) -> Self {
87 let mut delays = once({
88 if conf.random_connect_delay {
89 Duration::from_millis(rand::rng().random_range(0..1000))
90 } else {
91 Duration::ZERO
92 }
93 })
94 .chain(
95 conf.connect_retry_delays
96 .iter()
97 .map(|&d| Duration::from_secs(d.into())),
98 )
99 .chain(repeat({
100 let d = *conf.connect_retry_delays.last();
101 Duration::from_secs(d.into())
102 }));
103
104 let addr = addr.to_string();
105 let node = conf.keypair.public_key();
106
107 let mut backoff = None;
108
109 loop {
110 if let Some(d) = backoff.take() {
111 sleep(d).await;
112 } else {
113 sleep(delays.next().expect("delays iterator is infinite")).await;
114 }
115
116 debug!(name = %conf.name, %node, %peer, %addr, "connecting");
117
118 match try_connect(&conf, &peer, &addr).await {
119 Ok(mut conn) => {
120 let hello_exchange = until(conf.handshake_timeout, async {
121 conn.send_hello(Hello::Ok).await?;
122 conn.recv_hello().await
123 });
124 match hello_exchange.await {
125 Ok(h) if h.is_ok() => break conn,
126 Ok(h) => {
127 warn!(
128 name = %conf.name,
129 %node,
130 %peer,
131 remote = %conn.key,
132 %addr,
133 "hello response was not ok"
134 );
135 backoff = h.backoff_duration();
136 },
137 Err(err) => {
138 warn!(
139 name = %conf.name,
140 %node,
141 %peer,
142 remote = %conn.key,
143 %addr,
144 %err,
145 "failed to exchange hello"
146 );
147 },
148 }
149 },
150 Err(err) => {
151 warn!(name = %conf.name, %node, %peer, %addr, %err, "connect/handshake error")
152 },
153 }
154 }
155 }
156
157 pub async fn send_hello(&mut self, h: Hello) -> Result<()> {
159 let mut b = [0u8; 64];
160 let n = self
161 .state
162 .write_message(h.to_bytes().as_ref(), &mut b[Header::SIZE..])?;
163 let h = Header::data(n as u16);
164 send_frame(&mut self.stream, h, &mut b[..Header::SIZE + n]).await?;
165 Ok(())
166 }
167
168 pub async fn recv_hello(&mut self) -> Result<Hello> {
170 let mut a = [0u8; 64];
171 let h = recv_frame(&mut self.stream, &mut a).await?;
172 let mut b = [0u8; 64];
173 let n = self.state.read_message(&a[..h.len().into()], &mut b)?;
174 let h = Hello::from_bytes(&b[..n]).ok_or(NetworkError::InvalidHello)?;
175 Ok(h)
176 }
177}
178
179async fn try_connect(conf: &Config, peer: &PublicKey, addr: &str) -> Result<Connection> {
180 let new_handshake_state = |prologue: &Prologue, params: NoiseParams| {
181 Builder::new(params)
182 .local_private_key(conf.keypair.secret_key().as_slice())
183 .expect("valid private key")
184 .remote_public_key(peer.as_slice())
185 .expect("valid remote pub key")
186 .prologue(prologue)
187 .expect("1st time we set the prologue")
188 .build_initiator()
189 .expect("valid noise params yield valid handshake state")
190 };
191
192 let mut stream = until(conf.connect_timeout, TcpStream::connect(addr)).await?;
193
194 let node = conf.keypair.public_key();
195 let addr = stream.peer_addr()?;
196
197 debug!(name = %conf.name, %node, %peer, %addr, "tcp connection established");
198
199 if let Err(err) = stream.set_nodelay(true) {
200 warn!(name = %conf.name, %node, %err, "failed to enable NO_DELAY option");
201 }
202
203 until(conf.handshake_timeout, async move {
204 let (version, prologue) = select_version(conf, &mut stream, true).await?;
205
206 debug!(name = %conf.name, %node, %peer, %addr, %version, "negotiated version");
207
208 let noise_proto = conf
209 .noise_protocols
210 .get(&version)
211 .expect("selected version has noise config");
212
213 let hshake = new_handshake_state(&prologue, noise_proto.noise_params());
214 let state = handshake(&mut stream, hshake).await?;
215 match remote_static_key(&state) {
216 Some(key) if key == *peer => Ok(Connection {
217 key,
218 addr,
219 stream,
220 state,
221 }),
222 Some(key) => {
223 warn!(name = %conf.name, %node, %peer, remote = %key, %addr, "static key mismatch");
224 Err(NetworkError::InvalidHandshakeMessage)
225 },
226 None => {
227 warn!(name = %conf.name, %node, %peer, %addr, "invalid static key");
228 Err(NetworkError::InvalidHandshakeMessage)
229 },
230 }
231 })
232 .await
233}
234
235fn remote_static_key(state: &TransportState) -> Option<PublicKey> {
236 let k = state.get_remote_static()?;
237 PublicKey::try_from(k).ok()
238}
239
240async fn select_version(
244 conf: &Config,
245 stream: &mut TcpStream,
246 is_initiator: bool,
247) -> Result<(Version, Prologue)> {
248 const INIT_PAYLOAD_LEN: usize = 4;
249
250 let our_min = *conf.noise_protocols.first().0;
251 let our_max = *conf.noise_protocols.last().0;
252
253 let mut send_buf = [0u8; INIT_PAYLOAD_LEN];
254 let mut recv_buf = [0u8; INIT_PAYLOAD_LEN];
255
256 send_buf[0..2].copy_from_slice(&u16::from(our_min).to_be_bytes());
257 send_buf[2..4].copy_from_slice(&u16::from(our_max).to_be_bytes());
258
259 let (mut r, mut w) = stream.split();
260 try_join!(w.write_all(&send_buf), r.read_exact(&mut recv_buf))?;
261
262 let their_min = Version::from(u16::from_be_bytes([recv_buf[0], recv_buf[1]]));
263 let their_max = Version::from(u16::from_be_bytes([recv_buf[2], recv_buf[3]]));
264
265 let selected = min(our_max, their_max);
266
267 if selected < their_min || selected < our_min {
268 return Err(NetworkError::IncompatibleVersions {
269 ours: (our_min, our_max),
270 theirs: (their_min, their_max),
271 });
272 }
273
274 let mut prologue = Vec::new();
278 prologue.extend_from_slice(conf.name.as_bytes());
279 if is_initiator {
280 prologue.extend_from_slice(&send_buf);
281 prologue.extend_from_slice(&recv_buf);
282 } else {
283 prologue.extend_from_slice(&recv_buf);
284 prologue.extend_from_slice(&send_buf);
285 }
286
287 Ok((selected, prologue))
288}
289
290async fn handshake(stream: &mut TcpStream, mut hs: HandshakeState) -> Result<TransportState> {
292 let mut a = [0u8; MAX_NOISE_HANDSHAKE_SIZE];
293 let n = hs.write_message(&[], &mut a[Header::SIZE..])?;
294 let h = Header::data(n as u16);
295 send_frame(stream, h, &mut a[..Header::SIZE + n]).await?;
296 let mut b = [0u8; MAX_NOISE_HANDSHAKE_SIZE];
297 let h = recv_frame(stream, &mut b).await?;
298 if !h.is_data() || h.is_partial() {
299 return Err(NetworkError::InvalidHandshakeMessage);
300 }
301 hs.read_message(&b[..h.len().into()], &mut a)?;
302 Ok(hs.into_transport_mode()?)
303}
304
305async fn on_handshake(stream: &mut TcpStream, mut hs: HandshakeState) -> Result<TransportState> {
307 let mut a = [0u8; MAX_NOISE_HANDSHAKE_SIZE];
308 let h = recv_frame(stream, &mut a).await?;
309 if !h.is_data() || h.is_partial() {
310 return Err(NetworkError::InvalidHandshakeMessage);
311 }
312 let mut b = [0u8; MAX_NOISE_HANDSHAKE_SIZE];
313 hs.read_message(&a[..h.len().into()], &mut b)?;
314 let n = hs.write_message(&[], &mut b[Header::SIZE..])?;
315 let h = Header::data(n as u16);
316 send_frame(stream, h, &mut b[..Header::SIZE + n]).await?;
317 Ok(hs.into_transport_mode()?)
318}
319
320async fn recv_frame<R, const N: usize>(stream: &mut R, buf: &mut [u8; N]) -> io::Result<Header>
322where
323 R: AsyncReadExt + Unpin,
324{
325 let h = {
326 let n = stream.read_u32().await?;
327 Header::unvalidated(n)
328 };
329 let n = h.len().into();
330 if n > N {
331 return Err(io::ErrorKind::InvalidInput.into());
332 }
333 stream.read_exact(&mut buf[..n]).await?;
334 Ok(h)
335}
336
337async fn send_frame<W>(stream: &mut W, hdr: Header, msg: &mut [u8]) -> io::Result<()>
342where
343 W: AsyncWriteExt + Unpin,
344{
345 debug_assert!(msg.len() <= MAX_NOISE_MESSAGE_SIZE);
346 msg[..Header::SIZE].copy_from_slice(&hdr.to_bytes());
347 stream.write_all(msg).await?;
348 Ok(())
349}
350
#[cfg(test)]
mod tests {
    use std::net::Ipv4Addr;

    use tokio::net::{TcpListener, TcpStream};

    use super::{Prologue, Result, select_version};
    use crate::{Config, NetAddr, NetworkError, Version, noise::Protocol, x25519::Keypair};

    /// What `select_version` returned for one side of a negotiation.
    type Outcome = Result<(Version, Prologue)>;

    /// Builds a test `Config` whose noise protocols cover exactly `versions`.
    fn config<I, V>(versions: I) -> Config
    where
        I: IntoIterator<Item = V>,
        V: Into<Version>,
    {
        Config::builder()
            .name("test")
            .keypair(Keypair::generate().unwrap())
            .bind(NetAddr::from((Ipv4Addr::LOCALHOST, 0u16)))
            .parties([])
            .noise_protocols(
                versions
                    .into_iter()
                    .map(|version| (version.into(), Protocol::IK_25519_AesGcm_Blake2s)),
            )
            .build()
    }

    /// Runs `select_version` over a loopback TCP connection and returns both
    /// outcomes. `initiator` dials; `responder` accepts.
    async fn negotiate(initiator: &Config, responder: &Config) -> (Outcome, Outcome) {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let port = listener.local_addr().unwrap().port();
        let dial = async {
            let mut stream = TcpStream::connect(("127.0.0.1", port)).await.unwrap();
            select_version(initiator, &mut stream, true).await
        };
        let answer = async {
            let (mut stream, _) = listener.accept().await.unwrap();
            select_version(responder, &mut stream, false).await
        };
        tokio::join!(dial, answer)
    }

    /// Asserts that both sides succeeded, chose `expected`, and built the same prologue.
    fn assert_agreed((initiator, responder): (Outcome, Outcome), expected: Version) {
        let (vi, pi) = initiator.unwrap();
        let (vr, pr) = responder.unwrap();
        assert_eq!(vi, expected);
        assert_eq!(vr, expected);
        assert_eq!(pi, pr);
    }

    #[tokio::test]
    async fn picks_min_of_maxes() {
        let outcome = negotiate(&config([1, 2, 3]), &config([1, 2])).await;
        assert_agreed(outcome, 2.into());
    }

    #[tokio::test]
    async fn higher_overlap_takes_higher() {
        let outcome = negotiate(&config([2, 3, 4]), &config([1, 2, 3])).await;
        assert_agreed(outcome, 3.into());
    }

    #[tokio::test]
    async fn single_version_match() {
        let outcome = negotiate(&config([1]), &config([1])).await;
        assert_agreed(outcome, 1.into());
    }

    #[tokio::test]
    async fn disjoint_ranges_fail() {
        let (initiator, responder) = negotiate(&config([1]), &config([2])).await;
        for outcome in [initiator, responder] {
            assert!(matches!(outcome, Err(NetworkError::IncompatibleVersions { .. })));
        }
    }
}