1use std::{
2 io::{Cursor, Read, Write},
3 time::{Duration, SystemTime, UNIX_EPOCH},
4};
5
6use anyhow::{Context, Result};
7use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
8use hotshot_types::traits::signature_key::SignatureKey;
9
10use super::{RequestHash, Serializable, request::Request};
11
12#[derive(Clone, Debug)]
14#[cfg_attr(test, derive(PartialEq, Eq))]
15pub enum Message<R: Request, K: SignatureKey> {
16 Request(RequestMessage<R, K>),
18 Response(ResponseMessage<R>),
20}
21
22#[derive(Clone, Debug)]
24#[cfg_attr(test, derive(PartialEq, Eq))]
25pub struct RequestMessage<R: Request, K: SignatureKey> {
26 pub public_key: K,
28 pub signature: K::PureAssembledSignatureType,
30 pub timestamp_unix_seconds: u64,
33 pub request: R,
35}
36
37#[derive(Clone, Debug)]
39#[cfg_attr(test, derive(PartialEq, Eq))]
40pub struct ResponseMessage<R: Request> {
41 pub request_hash: RequestHash,
44 pub response: R::Response,
46}
47
48impl<R: Request, K: SignatureKey> RequestMessage<R, K> {
49 pub fn new_signed(public_key: &K, private_key: &K::PrivateKey, request: &R) -> Result<Self>
58 where
59 <K as SignatureKey>::SignError: 'static,
60 {
61 let timestamp_unix_seconds = SystemTime::now()
63 .duration_since(UNIX_EPOCH)
64 .expect("time went backwards")
65 .as_secs();
66
67 let content_to_sign = [
69 request
70 .to_bytes()
71 .with_context(|| "failed to serialize request content")?
72 .as_slice(),
73 timestamp_unix_seconds.to_le_bytes().as_slice(),
74 b"espresso-request-response",
75 ]
76 .concat();
77
78 let signature =
80 K::sign(private_key, &content_to_sign).with_context(|| "failed to sign message")?;
81
82 Ok(RequestMessage {
84 public_key: public_key.clone(),
85 signature,
86 timestamp_unix_seconds,
87 request: request.clone(),
88 })
89 }
90
91 pub fn validate(&self, incoming_request_ttl: Duration) -> Result<()> {
101 if self
103 .timestamp_unix_seconds
104 .saturating_add(incoming_request_ttl.as_secs())
105 < SystemTime::now()
106 .duration_since(UNIX_EPOCH)
107 .expect("time went backwards")
108 .as_secs()
109 {
110 return Err(anyhow::anyhow!("request is too old"));
111 }
112 if !self.public_key.validate(
114 &self.signature,
115 &[
116 self.request.to_bytes()?,
117 self.timestamp_unix_seconds.to_le_bytes().to_vec(),
118 b"espresso-request-response".to_vec(),
119 ]
120 .concat(),
121 ) {
122 return Err(anyhow::anyhow!("invalid request signature"));
123 }
124
125 self.request.validate()
127 }
128}
129
130impl<R: Request, K: SignatureKey> Serializable for Message<R, K> {
132 fn to_bytes(&self) -> Result<Vec<u8>> {
134 let mut bytes = Vec::new();
136
137 match self {
139 Message::Request(request_message) => {
140 bytes.push(0);
142
143 bytes.extend_from_slice(request_message.to_bytes()?.as_slice());
145 },
146 Message::Response(response_message) => {
147 bytes.push(1);
149
150 bytes.extend_from_slice(response_message.to_bytes()?.as_slice());
152 },
153 };
154
155 Ok(bytes)
156 }
157
158 fn from_bytes(bytes: &[u8]) -> Result<Self> {
160 let mut bytes = Cursor::new(bytes);
162
163 let type_byte = bytes.read_u8()?;
165
166 match type_byte {
168 0 => {
169 Ok(Message::Request(RequestMessage::from_bytes(&read_to_end(
171 &mut bytes,
172 )?)?))
173 },
174 1 => {
175 Ok(Message::Response(ResponseMessage::from_bytes(
177 &read_to_end(&mut bytes)?,
178 )?))
179 },
180 _ => Err(anyhow::anyhow!("invalid message type")),
181 }
182 }
183}
184
185impl<R: Request, K: SignatureKey> Serializable for RequestMessage<R, K> {
186 fn to_bytes(&self) -> Result<Vec<u8>> {
187 let mut bytes = Vec::new();
189
190 write_length_prefixed(&mut bytes, &self.public_key.to_bytes())?;
192
193 write_length_prefixed(&mut bytes, &bincode::serialize(&self.signature)?)?;
195
196 bytes.write_all(&self.timestamp_unix_seconds.to_le_bytes())?;
198
199 bytes.write_all(self.request.to_bytes()?.as_slice())?;
201
202 Ok(bytes)
203 }
204
205 fn from_bytes(bytes: &[u8]) -> Result<Self> {
206 let mut bytes = Cursor::new(bytes);
208
209 let public_key = K::from_bytes(&read_length_prefixed(&mut bytes)?)?;
211
212 let signature = bincode::deserialize(&read_length_prefixed(&mut bytes)?)?;
214
215 let timestamp = bytes.read_u64::<LittleEndian>()?;
217
218 let request = R::from_bytes(&read_to_end(&mut bytes)?)?;
220
221 Ok(Self {
222 public_key,
223 signature,
224 timestamp_unix_seconds: timestamp,
225 request,
226 })
227 }
228}
229
230impl<R: Request> Serializable for ResponseMessage<R> {
231 fn to_bytes(&self) -> Result<Vec<u8>> {
232 let mut bytes = Vec::new();
234
235 bytes.write_all(self.request_hash.as_bytes())?;
237
238 bytes.write_all(self.response.to_bytes()?.as_slice())?;
240
241 Ok(bytes)
242 }
243
244 fn from_bytes(bytes: &[u8]) -> Result<Self> {
245 let mut bytes = Cursor::new(bytes);
247
248 let mut request_hash_bytes = [0; 32];
250 bytes.read_exact(&mut request_hash_bytes)?;
251 let request_hash = RequestHash::from(request_hash_bytes);
252
253 let response = R::Response::from_bytes(&read_to_end(&mut bytes)?)?;
255
256 Ok(Self {
257 request_hash,
258 response,
259 })
260 }
261}
262
263fn write_length_prefixed<W: Write>(writer: &mut W, value: &[u8]) -> Result<()> {
265 writer.write_u32::<LittleEndian>(
267 u32::try_from(value.len()).with_context(|| "value was too large")?,
268 )?;
269
270 writer.write_all(value)?;
272 Ok(())
273}
274
275fn read_length_prefixed<R: Read>(reader: &mut R) -> Result<Vec<u8>> {
277 let length = reader.read_u32::<LittleEndian>()?;
279
280 let mut value = vec![0; length as usize];
282 reader.read_exact(&mut value)?;
283 Ok(value)
284}
285
286fn read_to_end<R: Read>(reader: &mut R) -> Result<Vec<u8>> {
288 let mut value = Vec::new();
289 reader.read_to_end(&mut value)?;
290 Ok(value)
291}
292
#[cfg(test)]
mod tests {
    use hotshot_types::signature_key::BLSPubKey;
    use rand::Rng;

    use super::*;

    /// Test-only encoding: a byte vector serializes to itself.
    impl Serializable for Vec<u8> {
        fn to_bytes(&self) -> Result<Vec<u8>> {
            Ok(self.clone())
        }
        fn from_bytes(bytes: &[u8]) -> Result<Self> {
            Ok(bytes.to_vec())
        }
    }

    /// Test-only request type: raw bytes that always pass validation and are
    /// answered with raw bytes.
    impl Request for Vec<u8> {
        type Response = Vec<u8>;

        fn validate(&self) -> Result<()> {
            Ok(())
        }
    }

    /// `len` independently random bytes (unlike `vec![x; len]`, which repeats
    /// a single byte).
    fn random_bytes(rng: &mut impl Rng, len: usize) -> Vec<u8> {
        (0..len).map(|_| rng.r#gen::<u8>()).collect()
    }

    /// Generous TTL so the "should be valid" case cannot flake when a second
    /// boundary passes between signing and validating.
    const TEST_TTL: Duration = Duration::from_secs(60);

    #[test]
    fn test_request_validation() {
        let mut rng = rand::thread_rng();

        for _ in 0..100 {
            let (public_key, private_key) =
                BLSPubKey::generated_from_seed_indexed([1; 32], rng.r#gen::<u64>());

            let len = rng.gen_range(1..10000);
            let payload = random_bytes(&mut rng, len);
            let mut request = RequestMessage::new_signed(&public_key, &private_key, &payload)
                .expect("Failed to create signed request");

            let should_be_valid = match rng.gen_range(0..4) {
                // Untouched, freshly signed request.
                0 => true,

                // Tampered payload: the signature no longer matches.
                1 => {
                    request.request[0] = !request.request[0];
                    false
                },

                // Tampered timestamp: the signature no longer matches.
                2 => {
                    request.timestamp_unix_seconds += 1000;
                    false
                },

                // Correctly re-signed but far older than the TTL, so only the
                // freshness check can reject it.
                3 => {
                    request.timestamp_unix_seconds -= 1000;
                    let request_bytes =
                        request.request.to_bytes().expect("Failed to serialize request");
                    let content = [
                        request_bytes.as_slice(),
                        request.timestamp_unix_seconds.to_le_bytes().as_slice(),
                        b"espresso-request-response",
                    ]
                    .concat();
                    request.signature = BLSPubKey::sign(&private_key, &content)
                        .expect("Failed to re-sign request");
                    false
                },

                _ => unreachable!(),
            };

            assert_eq!(request.validate(TEST_TTL).is_ok(), should_be_valid);
        }
    }

    #[test]
    fn test_message_parity() {
        let mut rng = rand::thread_rng();

        for _ in 0..100 {
            let request_len = rng.gen_range(0..10000);
            let request = random_bytes(&mut rng, request_len);

            let message = if rng.r#gen::<bool>() {
                let (public_key, private_key) =
                    BLSPubKey::generated_from_seed_indexed([1; 32], rng.r#gen::<u64>());

                let signed = RequestMessage::new_signed(&public_key, &private_key, &request)
                    .expect("Failed to create signed request");

                Message::Request(signed)
            } else {
                let response_len = rng.gen_range(0..10000);
                Message::Response(ResponseMessage {
                    request_hash: blake3::hash(&request),
                    response: random_bytes(&mut rng, response_len),
                })
            };

            let serialized = message.to_bytes().expect("Failed to serialize message");

            let deserialized =
                Message::from_bytes(&serialized).expect("Failed to deserialize message");

            assert_eq!(message, deserialized);
        }
    }

    #[test]
    fn test_length_prefix_parity() {
        let mut rng = rand::thread_rng();

        for _ in 0..100 {
            let len = rng.gen_range(0..10000);
            let value = random_bytes(&mut rng, len);

            let mut encoded = Vec::new();
            write_length_prefixed(&mut encoded, &value).unwrap();
            assert_eq!(encoded.len(), 4 + value.len());

            let mut reader = Cursor::new(encoded);
            let decoded = read_length_prefixed(&mut reader).unwrap();

            // Compare the decoded bytes against the input, not against themselves.
            assert_eq!(decoded, value);
        }
    }

    #[test]
    fn test_length_prefix_rejects_truncated_value() {
        let mut encoded = Vec::new();
        write_length_prefixed(&mut encoded, b"hello").unwrap();
        encoded.truncate(encoded.len() - 1);

        assert!(read_length_prefixed(&mut Cursor::new(encoded)).is_err());
    }
}