1use std::{
2 cmp::{Ordering, min},
3 fmt::{self, Debug, Display, Formatter},
4 num::ParseIntError,
5 str::FromStr,
6 time::Duration,
7};
8
9use anyhow::Context;
10use bytesize::ByteSize;
11use clap::Parser;
12use committable::Committable;
13use derive_more::{From, Into};
14use espresso_utils::{impl_serde_from_string_or_integer, ser::FromStringOrInteger};
15use futures::future::BoxFuture;
16use hotshot_types::{
17 consensus::CommitmentMap,
18 data::{Leaf, Leaf2},
19 traits::node_implementation::NodeType,
20};
21use rand::Rng;
22use serde::{Deserialize, Serialize};
23use thiserror::Error;
24use time::{
25 Date, OffsetDateTime, format_description::well_known::Rfc3339 as TimestampFormat, macros::time,
26};
27use tokio::time::sleep;
28
29use crate::ChainConfig;
30
31pub fn upgrade_commitment_map<Types: NodeType>(
32 map: CommitmentMap<Leaf<Types>>,
33) -> CommitmentMap<Leaf2<Types>> {
34 map.into_values()
35 .map(|leaf| {
36 let leaf2: Leaf2<Types> = leaf.into();
37 (leaf2.commit(), leaf2)
38 })
39 .collect()
40}
41
/// Either an explicit value (`Set`) or a marker that no value is given (`Skip`).
///
/// `Skip` is the `Default` variant and is (de)serialized under the name `__skip`. `Set` is
/// marked `untagged`, so it is (de)serialized as the bare inner value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
pub enum Update<T> {
    /// No value is supplied.
    #[default]
    #[serde(rename = "__skip")]
    Skip,
    /// A value is supplied.
    #[serde(untagged)]
    Set(T),
}
50
51impl<T> Update<T> {
52 pub fn map<U>(self, f: impl FnOnce(T) -> U) -> Update<U> {
53 match self {
54 Update::Skip => Update::Skip,
55 Update::Set(v) => Update::Set(f(v)),
56 }
57 }
58}
59
/// Genesis header data, consisting of a [`Timestamp`] and a [`ChainConfig`].
#[derive(Clone, Debug, Default, Deserialize, Serialize, PartialEq, Eq)]
pub struct GenesisHeader {
    /// Timestamp of the genesis header; `Timestamp::default()` is built from integer `0`.
    pub timestamp: Timestamp,
    /// Chain configuration associated with the genesis header.
    pub chain_config: ChainConfig,
}
66
/// A point in time, stored as an [`OffsetDateTime`].
///
/// `Display` renders the value with the RFC 3339 format description and panics (through
/// `unwrap`) if formatting fails.
#[derive(Hash, Copy, Clone, Debug, derive_more::Display, PartialEq, Eq, From, Into)]
#[display("{}", _0.format(&TimestampFormat).unwrap())]
pub struct Timestamp(OffsetDateTime);

// Generates the serde implementations for `Timestamp`; presumably these are expressed in terms
// of its `FromStringOrInteger` implementation (confirm against the macro definition).
impl_serde_from_string_or_integer!(Timestamp);
72
73impl Default for Timestamp {
74 fn default() -> Self {
75 Self::from_integer(0).unwrap()
76 }
77}
78
79impl Timestamp {
80 pub fn unix_timestamp(&self) -> u64 {
81 self.0.unix_timestamp() as u64
82 }
83
84 pub fn unix_timestamp_millis(&self) -> u64 {
85 (self.0.unix_timestamp_nanos() / 1_000_000) as u64
86 }
87
88 pub fn max() -> Self {
89 Self(OffsetDateTime::new_utc(Date::MAX, time!(23:59)))
90 }
91}
92
93impl FromStringOrInteger for Timestamp {
94 type Binary = u64;
95 type Integer = u64;
96
97 fn from_binary(b: Self::Binary) -> anyhow::Result<Self> {
98 Self::from_integer(b)
99 }
100
101 fn from_integer(i: Self::Integer) -> anyhow::Result<Self> {
102 let unix = i.try_into().context("timestamp out of range")?;
103 Ok(Self(
104 OffsetDateTime::from_unix_timestamp(unix).context("invalid timestamp")?,
105 ))
106 }
107
108 fn from_string(s: String) -> anyhow::Result<Self> {
109 Ok(Self(
110 OffsetDateTime::parse(&s, &TimestampFormat).context("invalid timestamp")?,
111 ))
112 }
113
114 fn to_binary(&self) -> anyhow::Result<Self::Binary> {
115 Ok(self.unix_timestamp())
116 }
117
118 fn to_string(&self) -> anyhow::Result<String> {
119 Ok(format!("{self}"))
120 }
121}
122
/// A `u64` newtype for timestamps; judging by its constructors it holds milliseconds since the
/// Unix epoch.
///
/// `Display` prints the bare integer, and ordering/equality are those of the inner `u64`.
#[derive(
    Hash,
    Copy,
    Clone,
    Debug,
    derive_more::Display,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Serialize,
    Deserialize,
)]
#[display("{}", _0)]
pub struct TimestampMillis(u64);
138
139impl TimestampMillis {
140 pub fn from_time(time: &OffsetDateTime) -> Self {
141 let timestamp = (time.unix_timestamp_nanos() / 1_000_000) as u64;
142
143 Self(timestamp)
144 }
145
146 pub fn from_millis(millis: u64) -> Self {
147 Self(millis)
148 }
149
150 pub fn u64(&self) -> u64 {
151 self.0
152 }
153}
154
/// A ratio of two unsigned integers, written and parsed as `numerator:denominator`.
///
/// Equality is derived field-by-field, so `1:2` and `2:4` are not equal.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct Ratio {
    /// The part before the `:`.
    pub numerator: u64,
    /// The part after the `:`.
    pub denominator: u64,
}
160
161impl From<Ratio> for (u64, u64) {
162 fn from(r: Ratio) -> Self {
163 (r.numerator, r.denominator)
164 }
165}
166
167impl Display for Ratio {
168 fn fmt(&self, f: &mut Formatter) -> fmt::Result {
169 write!(f, "{}:{}", self.numerator, self.denominator)
170 }
171}
172
173impl PartialOrd for Ratio {
174 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
175 Some(self.cmp(other))
176 }
177}
178
179impl Ord for Ratio {
180 fn cmp(&self, other: &Self) -> Ordering {
181 (self.numerator * other.denominator).cmp(&(other.numerator * self.denominator))
182 }
183}
184
/// Errors produced when parsing a [`Ratio`] from a `numerator:denominator` string.
#[derive(Debug, Error)]
pub enum ParseRatioError {
    /// The input contains no `:` separator.
    #[error("numerator and denominator must be separated by :")]
    MissingDelimiter,
    /// The text before the first `:` is not a valid integer.
    #[error("Invalid numerator {err:?}")]
    InvalidNumerator { err: ParseIntError },
    /// The text after the first `:` is not a valid integer.
    #[error("Invalid denominator {err:?}")]
    InvalidDenominator { err: ParseIntError },
}
194
195impl FromStr for Ratio {
196 type Err = ParseRatioError;
197
198 fn from_str(s: &str) -> Result<Self, Self::Err> {
199 let (num, den) = s.split_once(':').ok_or(ParseRatioError::MissingDelimiter)?;
200 Ok(Self {
201 numerator: num
202 .parse()
203 .map_err(|err| ParseRatioError::InvalidNumerator { err })?,
204 denominator: den
205 .parse()
206 .map_err(|err| ParseRatioError::InvalidDenominator { err })?,
207 })
208 }
209}
210
/// Error returned by `parse_duration` when the input is not a valid duration.
#[derive(Clone, Debug, Error)]
#[error("Failed to parse duration {reason}")]
pub struct ParseDurationError {
    /// The underlying parser's error, rendered as a string.
    reason: String,
}
216
217pub fn parse_duration(s: &str) -> Result<Duration, ParseDurationError> {
218 cld::ClDuration::from_str(s)
219 .map(Duration::from)
220 .map_err(|err| ParseDurationError {
221 reason: err.to_string(),
222 })
223}
224
/// Error returned by `parse_size` when the input is not a valid byte size.
///
/// `From<String>` is derived, so a `String` error converts into this type via `?`.
#[derive(Clone, Debug, From, Error)]
#[error("failed to parse ByteSize. {msg}")]
pub struct ParseSizeError {
    /// The message reported by the underlying parser.
    msg: String,
}
230
231pub fn parse_size(s: &str) -> Result<u64, ParseSizeError> {
232 Ok(s.parse::<ByteSize>()?.0)
233}
234
// NOTE(review): the clap defaults declared on `BackoffParams` use a factor of 4 and a base
// delay of 200ms, which differ from `BACKOFF_FACTOR` and `MIN_RETRY_DELAY` below; confirm
// which values are authoritative for which callers.

/// Minimum retry delay: 500 milliseconds.
pub const MIN_RETRY_DELAY: Duration = Duration::from_millis(500);
/// Maximum retry delay: 5 seconds.
pub const MAX_RETRY_DELAY: Duration = Duration::from_secs(5);
/// Multiplicative backoff factor: 2.
pub const BACKOFF_FACTOR: u32 = 2;
/// Backoff jitter as a `(numerator, denominator)` pair: 1/10.
pub const BACKOFF_JITTER: (u64, u64) = (1, 10);
240
// Parameters controlling the exponential backoff applied by `BackoffParams::retry`,
// `BackoffParams::retry_if` and `BackoffParams::backoff`.
//
// Plain `//` comments are used throughout this struct on purpose: clap's derive turns `///`
// doc comments into command-line help text, so adding them would change `--help` output.
#[derive(Clone, Copy, Debug, Parser, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub struct BackoffParams {
    // Multiplier applied to the current delay on each call to `backoff`.
    #[clap(
        long = "catchup-backoff-factor",
        env = "ESPRESSO_NODE_CATCHUP_BACKOFF_FACTOR",
        default_value = "4"
    )]
    factor: u32,

    // Delay used before the first retry; parsed with `parse_duration`.
    #[clap(
        long = "catchup-base-retry-delay",
        env = "ESPRESSO_NODE_CATCHUP_BASE_RETRY_DELAY",
        default_value = "200ms",
        value_parser = parse_duration
    )]
    base: Duration,

    // Upper bound on the delay returned by `backoff`; parsed with `parse_duration`.
    #[clap(
        long = "catchup-max-retry-delay",
        env = "ESPRESSO_NODE_CATCHUP_MAX_RETRY_DELAY",
        default_value = "5s",
        value_parser = parse_duration
    )]
    max: Duration,

    // Jitter ratio, written `numerator:denominator` (parsed through `Ratio`'s `FromStr`).
    #[clap(
        long = "catchup-backoff-jitter",
        env = "ESPRESSO_NODE_CATCHUP_BACKOFF_JITTER",
        default_value = "1:10"
    )]
    jitter: Ratio,

    // When set, `retry` and `retry_if` return the first error instead of retrying.
    #[clap(short, long, env = "ESPRESSO_NODE_CATCHUP_BACKOFF_DISABLE")]
    disable: bool,
}
281
282impl Default for BackoffParams {
283 fn default() -> Self {
284 Self::parse_from(std::iter::empty::<String>())
285 }
286}
287
288impl BackoffParams {
289 pub const fn new(base: Duration, max: Duration, factor: u32, jitter: Ratio) -> Self {
290 Self {
291 base,
292 max,
293 factor,
294 jitter,
295 disable: false,
296 }
297 }
298
299 pub fn disabled() -> Self {
300 Self {
301 disable: true,
302 ..Default::default()
303 }
304 }
305
306 pub async fn retry<S, T>(
307 &self,
308 mut state: S,
309 f: impl for<'a> Fn(&'a mut S, usize) -> BoxFuture<'a, anyhow::Result<T>>,
310 ) -> anyhow::Result<T> {
311 let mut delay = self.base;
312 for i in 0.. {
313 match f(&mut state, i).await {
314 Ok(res) => return Ok(res),
315 Err(err) if self.disable => {
316 return Err(err.context("Retryable operation failed; retries disabled"));
317 },
318 Err(err) => {
319 tracing::warn!(
320 "Retryable operation failed, will retry after {delay:?}: {err:#}"
321 );
322 sleep(delay).await;
323 delay = self.backoff(delay);
324 },
325 }
326 }
327 unreachable!()
328 }
329
330 pub async fn retry_if<Fut, T>(
334 &self,
335 max_retries: u32,
336 should_retry: impl Fn(&anyhow::Error) -> bool,
337 f: impl Fn() -> Fut,
338 ) -> anyhow::Result<T>
339 where
340 Fut: Future<Output = anyhow::Result<T>>,
341 {
342 let mut delay = self.base;
343 for i in 0usize.. {
344 match f().await {
345 Ok(res) => return Ok(res),
346 Err(err) if self.disable => {
347 return Err(err.context("Retryable operation failed; retries disabled"));
348 },
349 Err(err) if (i as u32) < max_retries && should_retry(&err) => {
350 tracing::warn!(
351 attempt = i + 1,
352 max_retries,
353 delay_ms = delay.as_millis(),
354 "Retryable operation failed, will retry after {delay:?}: {err:#}"
355 );
356 sleep(delay).await;
357 delay = self.backoff(delay);
358 },
359 Err(err) => return Err(err),
360 }
361 }
362 unreachable!()
363 }
364
365 #[must_use]
366 pub fn backoff(&self, delay: Duration) -> Duration {
367 if delay >= self.max {
368 return self.max;
369 }
370
371 let mut rng = rand::thread_rng();
372
373 let ms = (delay * self.factor).as_millis() as u64;
375
376 let jitter_num = rng.gen_range(0..self.jitter.numerator);
378 let jitter_den = self.jitter.denominator;
379
380 let jitter = ms * jitter_num / jitter_den;
382 let delay = Duration::from_millis(ms + jitter);
383
384 min(delay, self.max)
386 }
387}