light_client/
storage.rs

1use std::{future::Future, path::PathBuf, str::FromStr};
2
3use alloy::primitives::Address;
4use anyhow::{Context, Result};
5use derive_more::{Display, From};
6use espresso_types::{PubKey, SeqTypes, StakeTableState, v0_3::RegisteredValidator};
7use futures::TryStreamExt;
8use hotshot_query_service::{
9    availability::{BlockId, LeafId, LeafQueryData},
10    types::HeightIndexed,
11};
12use hotshot_types::{data::EpochNumber, light_client::StateVerKey};
13use serde_json::Value;
14use sqlx::{
15    QueryBuilder, SqlitePool, query, query_as,
16    sqlite::{SqliteConnectOptions, SqlitePoolOptions},
17};
18
19/// Different ways to ask the database for a leaf.
20#[derive(Clone, Copy, Debug, Display, From)]
21pub enum LeafRequest {
22    /// Ask for a leaf with a given ID.
23    #[display("leaf {_0}")]
24    Leaf(LeafId<SeqTypes>),
25
26    /// Ask for the leaf containing a header with a given ID.
27    #[display("header {_0}")]
28    Header(BlockId<SeqTypes>),
29}
30
31/// Client-side database for a [`LightClient`].
32pub trait Storage: Sized + Send + Sync + 'static {
33    /// Create a default, empty instance of the state.
34    ///
35    /// This is an async, fallible version of [`Default::default`]. If `Self: Default`, this is
36    /// equivalent to `ready(Ok(<Self as Default>::default()))`.
37    fn default() -> impl Send + Future<Output = Result<Self>>;
38
39    /// Get the number of blocks known to be in the chain.
40    ///
41    /// This is equivalent to one more than the block number of the latest known block.
42    ///
43    /// Because the database is not constantly being updated, this may be an underestimate of the
44    /// true number of blocks that exist.
45    fn block_height(&self) -> impl Send + Future<Output = Result<u64>>;
46
47    /// Get the earliest available leaf which is later than or equal to the requested leaf.
48    ///
49    /// This will either be the leaf requested, or can be used as the known-finalized endpoint in a
50    /// leaf chain proving that requested leaf is finalized (after the requested leaf is fetched
51    /// from elsewhere).
52    ///
53    /// If there is no known leaf later than the requested leaf, the result is [`None`].
54    fn leaf_upper_bound(
55        &self,
56        leaf: impl Into<LeafRequest> + Send,
57    ) -> impl Send + Future<Output = Result<Option<LeafQueryData<SeqTypes>>>>;
58
59    /// Get all leaves in the range [start, end)
60    fn get_leaves_in_range(
61        &self,
62        start: u32,
63        end: u32,
64    ) -> impl Send + Future<Output = Result<Vec<LeafQueryData<SeqTypes>>>>;
65
66    /// Add a leaf to the cache.
67    ///
68    /// This may result in an older leaf being removed.
69    fn insert_leaf(&self, leaf: LeafQueryData<SeqTypes>)
70    -> impl Send + Future<Output = Result<()>>;
71
72    /// Get the stake table for the latest epoch which is not later than `epoch`.
73    ///
74    /// If such a stake table is available in the database, returns the ordered entries and the
75    /// epoch number of the stake table that was loaded.
76    fn stake_table_lower_bound(
77        &self,
78        epoch: EpochNumber,
79    ) -> impl Send + Future<Output = Result<Option<(EpochNumber, StakeTableState)>>>;
80
81    /// Add a stake table to the cache.
82    ///
83    /// This may result in an older stake table being removed.
84    fn insert_stake_table(
85        &self,
86        epoch: EpochNumber,
87        stake_table: &StakeTableState,
88    ) -> impl Send + Future<Output = Result<()>>;
89}
90
/// Configuration for the SQLite-backed light client storage.
///
/// Turn these options into a live store with `LightClientSqliteOptions::connect`.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "clap", derive(clap::Parser))]
pub struct LightClientSqliteOptions {
    /// Maximum number of simultaneous DB connections to allow.
    #[cfg_attr(
        feature = "clap",
        clap(
            long = "light-client-db-num-connections",
            env = "LIGHT_CLIENT_DB_NUM_CONNECTIONS",
            default_value = "5",
        )
    )]
    pub num_connections: u32,

    /// Maximum number of leaves to cache in the local DB.
    ///
    /// Once this is exceeded, the least recently used leaves are evicted.
    #[cfg_attr(
        feature = "clap",
        clap(
            long = "light-client-db-num-leaves",
            env = "LIGHT_CLIENT_DB_NUM_LEAVES",
            default_value = "100",
        )
    )]
    pub num_leaves: u32,

    /// Maximum number of stake tables to cache in the local DB.
    ///
    /// Once this is exceeded, stake tables are evicted starting from the second oldest; the
    /// oldest cached table is always kept.
    #[cfg_attr(
        feature = "clap",
        clap(
            long = "light-client-db-num-stake-tables",
            env = "LIGHT_CLIENT_DB_NUM_STAKE_TABLES",
            default_value = "100",
        )
    )]
    pub num_stake_tables: u32,

    /// Create or open storage that is persisted on the file system at this path.
    ///
    /// If not present, the database will exist only in memory and will be destroyed when the
    /// resulting [`SqliteStorage`] (and its connection pool) is dropped.
    #[cfg_attr(
        feature = "clap",
        clap(long = "light-client-db-path", env = "LIGHT_CLIENT_DB_PATH")
    )]
    pub lc_path: Option<PathBuf>,
}
137
138impl Default for LightClientSqliteOptions {
139    fn default() -> Self {
140        Self {
141            num_connections: 5,
142            num_leaves: 100,
143            num_stake_tables: 100,
144            lc_path: None,
145        }
146    }
147}
148
149impl LightClientSqliteOptions {
150    /// Create or connect to a database with the given options.
151    pub async fn connect(self) -> Result<SqliteStorage> {
152        let path = match &self.lc_path {
153            Some(path) => path.to_str().context("invalid file path")?,
154            None => ":memory:",
155        };
156        let opt = SqliteConnectOptions::from_str(path)?.create_if_missing(true);
157        let pool = SqlitePoolOptions::default()
158            .max_connections(self.num_connections)
159            .connect_with(opt)
160            .await?;
161        sqlx::migrate!("./migrations").run(&pool).await?;
162
163        Ok(SqliteStorage {
164            pool,
165            num_leaves: self.num_leaves,
166            num_stake_tables: self.num_stake_tables,
167        })
168    }
169}
170
171/// [`Storage`] based on a SQLite database.
172#[derive(Clone, Debug)]
173pub struct SqliteStorage {
174    pool: SqlitePool,
175    num_leaves: u32,
176    num_stake_tables: u32,
177}
178
179impl Storage for SqliteStorage {
180    async fn default() -> Result<Self> {
181        LightClientSqliteOptions::default().connect().await
182    }
183
184    async fn block_height(&self) -> Result<u64> {
185        let mut tx = self.pool.begin().await?;
186        let (height,) = query_as("SELECT COALESCE(max(height) + 1, 0) FROM leaf")
187            .fetch_one(tx.as_mut())
188            .await?;
189        Ok(height)
190    }
191
192    async fn leaf_upper_bound(
193        &self,
194        id: impl Into<LeafRequest> + Send,
195    ) -> Result<Option<LeafQueryData<SeqTypes>>> {
196        let mut tx = self.pool.begin().await?;
197
198        let mut q = QueryBuilder::new("SELECT height, data FROM leaf WHERE ");
199        match id.into() {
200            LeafRequest::Leaf(LeafId::Number(n)) | LeafRequest::Header(BlockId::Number(n)) => {
201                q.push("height >= ")
202                    .push_bind(n as i64)
203                    .push("ORDER BY HEIGHT");
204            },
205            LeafRequest::Leaf(LeafId::Hash(h)) => {
206                q.push("hash = ").push_bind(h.to_string());
207            },
208            LeafRequest::Header(BlockId::Hash(h)) => {
209                q.push("block_hash = ").push_bind(h.to_string());
210            },
211            LeafRequest::Header(BlockId::PayloadHash(h)) => {
212                q.push("payload_hash = ")
213                    .push_bind(h.to_string())
214                    .push("ORDER BY height");
215            },
216        }
217        q.push(" LIMIT 1");
218
219        let Some((height, data)) = q
220            .build_query_as::<(i64, _)>()
221            .fetch_optional(tx.as_mut())
222            .await?
223        else {
224            return Ok(None);
225        };
226        let leaf = serde_json::from_value(data)?;
227
228        // Mark this leaf as recently used.
229        let (id,): (i32,) = query_as("SELECT max(id) + 1 FROM leaf")
230            .fetch_one(tx.as_mut())
231            .await?;
232        query("UPDATE leaf SET id = $1 WHERE height = $2")
233            .bind(id)
234            .bind(height)
235            .execute(tx.as_mut())
236            .await?;
237        tx.commit().await?;
238
239        Ok(Some(leaf))
240    }
241
242    async fn get_leaves_in_range(
243        &self,
244        start_height: u32,
245        end_height: u32,
246    ) -> Result<Vec<LeafQueryData<SeqTypes>>> {
247        let mut tx = self.pool.begin().await?;
248
249        let leaves = query_as::<_, (i64, serde_json::Value)>(
250            "SELECT height, data FROM leaf WHERE height >= $1 AND height < $2 ORDER BY height",
251        )
252        .bind(start_height as i64)
253        .bind(end_height as i64)
254        .fetch_all(tx.as_mut())
255        .await?
256        .into_iter()
257        .map(|(_height, data)| serde_json::from_value(data))
258        .collect::<Result<Vec<_>, _>>()?;
259
260        tx.commit().await?;
261
262        Ok(leaves)
263    }
264
265    async fn insert_leaf(&self, leaf: LeafQueryData<SeqTypes>) -> Result<()> {
266        let mut tx = self.pool.begin().await?;
267
268        let height = leaf.height() as i64;
269        let hash = leaf.hash().to_string();
270        let block_hash = leaf.block_hash().to_string();
271        let payload_hash = leaf.payload_hash().to_string();
272        let data = serde_json::to_value(leaf)?;
273
274        tracing::debug!(height, hash, "inserting leaf");
275        let (id,): (i32,) = query_as(
276            "INSERT INTO leaf (height, hash, block_hash, payload_hash, data) VALUES ($1, $2, $3, \
277             $4, $5)
278                    ON CONFLICT (height) DO UPDATE SET id = excluded.id
279                    RETURNING id",
280        )
281        .bind(height)
282        .bind(&hash)
283        .bind(&block_hash)
284        .bind(&payload_hash)
285        .bind(data)
286        .fetch_one(tx.as_mut())
287        .await
288        .context("inserting new leaf")?;
289        tracing::debug!(height, hash, id, "inserted leaf");
290
291        // Delete the oldest leaves as necessary until the number of leaves stored does not exceed
292        // `num_leaves`.
293        let (num_leaves,): (u32,) = query_as("SELECT count(*) FROM leaf")
294            .fetch_one(tx.as_mut())
295            .await
296            .context("counting leaves")?;
297        let to_delete = num_leaves.saturating_sub(self.num_leaves);
298        if to_delete > 0 {
299            let (id_to_delete,): (i64,) =
300                query_as("SELECT id FROM leaf ORDER BY id LIMIT 1 OFFSET $1")
301                    .bind(to_delete - 1)
302                    .fetch_one(tx.as_mut())
303                    .await
304                    .context("finding timestamp for GC")?;
305            tracing::info!(id_to_delete, "garbage collecting {to_delete} leaves");
306            let res = query("DELETE FROM leaf WHERE id <= $1")
307                .bind(id_to_delete)
308                .execute(tx.as_mut())
309                .await
310                .context("deleting old leaves")?;
311            tracing::info!("deleted {} leaves", res.rows_affected());
312        }
313
314        tx.commit().await?;
315        Ok(())
316    }
317
318    async fn stake_table_lower_bound(
319        &self,
320        epoch: EpochNumber,
321    ) -> Result<Option<(EpochNumber, StakeTableState)>> {
322        let mut tx = self.pool.begin().await?;
323
324        let Some((epoch,)) = query_as::<_, (i64,)>(
325            "SELECT epoch FROM stake_table_epoch WHERE epoch <= $1 ORDER BY epoch DESC LIMIT 1",
326        )
327        .bind(*epoch as i64)
328        .fetch_optional(tx.as_mut())
329        .await
330        .context("loading epoch lower bound")?
331        else {
332            return Ok(None);
333        };
334
335        let validators = query_as::<_, (Value,)>(
336            "SELECT data FROM stake_table_validator WHERE epoch = $1 ORDER BY idx",
337        )
338        .bind(epoch)
339        .fetch(tx.as_mut())
340        .map_err(anyhow::Error::new)
341        .and_then(|(json,)| async move {
342            let validator: RegisteredValidator<PubKey> = serde_json::from_value(json)?;
343            Ok((validator.account, validator))
344        })
345        .try_collect()
346        .await
347        .context(format!("loading stake table for epoch {epoch}"))?;
348
349        let validator_exits =
350            query_as::<_, (String,)>("SELECT address FROM stake_table_exit WHERE epoch <= $1")
351                .bind(epoch)
352                .fetch(tx.as_mut())
353                .map_err(anyhow::Error::new)
354                .and_then(|(s,)| async move { Ok(Address::from_str(&s)?) })
355                .try_collect()
356                .await
357                .context(format!("loading validator exits for epoch {epoch}"))?;
358
359        let used_bls_keys =
360            query_as::<_, (String,)>("SELECT key FROM stake_table_bls_key WHERE epoch <= $1")
361                .bind(epoch)
362                .fetch(tx.as_mut())
363                .map_err(anyhow::Error::new)
364                .and_then(|(s,)| async move { Ok(PubKey::from_str(&s)?) })
365                .try_collect()
366                .await
367                .context(format!("loading BLS keys for epoch {epoch}"))?;
368
369        let used_schnorr_keys =
370            query_as::<_, (String,)>("SELECT key FROM stake_table_schnorr_key WHERE epoch <= $1")
371                .bind(epoch)
372                .fetch(tx.as_mut())
373                .map_err(anyhow::Error::new)
374                .and_then(|(s,)| async move { Ok(StateVerKey::from_str(&s)?) })
375                .try_collect()
376                .await
377                .context(format!("loading Schnorr keys for epoch {epoch}"))?;
378
379        Ok(Some((
380            EpochNumber::new(epoch as u64),
381            StakeTableState::new(
382                validators,
383                validator_exits,
384                used_bls_keys,
385                used_schnorr_keys,
386            ),
387        )))
388    }
389
390    async fn insert_stake_table(
391        &self,
392        epoch: EpochNumber,
393        stake_table: &StakeTableState,
394    ) -> Result<()> {
395        let mut tx = self.pool.begin().await?;
396
397        // Record that the stake table for this epoch is available.
398        let epoch = i64::try_from(*epoch).context("epoch overflow")?;
399        query("INSERT INTO stake_table_epoch (epoch) VALUES ($1)")
400            .bind(epoch)
401            .execute(tx.as_mut())
402            .await
403            .context(format!(
404                "recording stake table availability for epoch {epoch}"
405            ))?;
406
407        // Insert validators for the new stake table.
408        let validators = stake_table
409            .validators()
410            .values()
411            .cloned()
412            .map(serde_json::to_value)
413            .collect::<Result<Vec<_>, _>>()?;
414        QueryBuilder::new("INSERT INTO stake_table_validator (epoch, idx, data) ")
415            .push_values(validators.into_iter().enumerate(), |mut q, (i, data)| {
416                q.push_bind(epoch).push_bind(i as i64).push_bind(data);
417            })
418            .build()
419            .execute(tx.as_mut())
420            .await
421            .context(format!("inserting validators for epoch {epoch}"))?;
422
423        // Insert only newly used BLS keys.
424        QueryBuilder::new("INSERT INTO stake_table_bls_key (epoch, key) ")
425            .push_values(stake_table.used_bls_keys(), |mut q, key| {
426                q.push_bind(epoch).push_bind(key.to_string());
427            })
428            // If we insert keys out of order, make sure `epoch` reflects the earliest time when
429            // this key was added to the state.
430            .push(" ON CONFLICT (key) DO UPDATE SET epoch = min(epoch, excluded.epoch)")
431            .build()
432            .execute(tx.as_mut())
433            .await
434            .context(format!("inserting newly used BLS keys for epoch {epoch}"))?;
435
436        // Insert only newly used Schnorr keys.
437        QueryBuilder::new("INSERT INTO stake_table_schnorr_key (epoch, key) ")
438            .push_values(stake_table.used_schnorr_keys(), |mut q, key| {
439                q.push_bind(epoch).push_bind(key.to_string());
440            })
441            // If we insert keys out of order, make sure `epoch` reflects the earliest time when
442            // this key was added to the state.
443            .push(" ON CONFLICT (key) DO UPDATE SET epoch = min(epoch, excluded.epoch)")
444            .build()
445            .execute(tx.as_mut())
446            .await
447            .context(format!(
448                "inserting newly used Schnorr keys for epoch {epoch}"
449            ))?;
450
451        // Insert only the new validator exits.
452        if !stake_table.validator_exits().is_empty() {
453            QueryBuilder::new("INSERT INTO stake_table_exit (epoch, address) ")
454                .push_values(stake_table.validator_exits(), |mut q, address| {
455                    q.push_bind(epoch).push_bind(address.to_string());
456                })
457                // If we insert exits out of order, make sure `epoch` reflects the earliest time
458                // when this exit was added to the state.
459                .push(" ON CONFLICT (address) DO UPDATE SET epoch = min(epoch, excluded.epoch)")
460                .build()
461                .execute(tx.as_mut())
462                .await
463                .context(format!("inserting new validator exits for epoch {epoch}"))?;
464        }
465
466        // Delete the second oldest stake table if necessary to ensure the number of stake tables
467        // stored does not exceed `num_stake_tables`.
468        let (num_stake_tables,): (u32,) = query_as("SELECT count(*) FROM stake_table_epoch")
469            .fetch_one(tx.as_mut())
470            .await
471            .context("counting stake tables")?;
472        if num_stake_tables > self.num_stake_tables {
473            // We always delete the _second oldest_ stake table. We want to keep the oldest around
474            // because it is the hardest to catch up for if we need it again (we would have to go
475            // all the way back to genesis). The second oldest is the least likely to be used again
476            // after the oldest, while still being easy to replay if we do need it (because we can
477            // just replay from the cached oldest).
478            let (epoch_to_delete,): (i64,) =
479                query_as("SELECT epoch FROM stake_table_epoch ORDER BY epoch LIMIT 1 OFFSET 1")
480                    .fetch_one(tx.as_mut())
481                    .await
482                    .context("find second oldest epoch")?;
483            tracing::info!(epoch_to_delete, "garbage collecting stake table");
484
485            // Delete from the main epoch table. The corresponding rows from `stake_table_validator`
486            // will be deleted automatically by cascading. The corresponding rows in the BLS keys,
487            // Schnorr keys, and validator exits tables cannot be deleted, because those tables are
488            // cumulative over later epochs.
489            query("DELETE FROM stake_table_epoch WHERE epoch = $1")
490                .bind(epoch_to_delete)
491                .execute(tx.as_mut())
492                .await
493                .context("garbage collecting stake table")?;
494        }
495
496        tx.commit().await?;
497        Ok(())
498    }
499}
500
#[cfg(test)]
mod test {
    use pretty_assertions::assert_eq;
    use versions::EPOCH_VERSION;

    use super::*;
    use crate::testing::{leaf_chain, random_validator};

    /// A fresh in-memory store with the default cache limits.
    async fn fresh_db() -> SqliteStorage {
        SqliteStorage::default().await.unwrap()
    }

    /// A fresh in-memory store that keeps at most `num_leaves` leaves.
    async fn db_with_leaf_limit(num_leaves: u32) -> SqliteStorage {
        LightClientSqliteOptions {
            num_leaves,
            ..Default::default()
        }
        .connect()
        .await
        .unwrap()
    }

    /// [`Storage::leaf_upper_bound`], panicking on database errors.
    async fn upper_bound(
        db: &SqliteStorage,
        request: impl Into<LeafRequest> + Send,
    ) -> Option<LeafQueryData<SeqTypes>> {
        db.leaf_upper_bound(request).await.unwrap()
    }

    /// [`Storage::stake_table_lower_bound`], panicking on database errors.
    async fn lower_bound(
        db: &SqliteStorage,
        epoch: EpochNumber,
    ) -> Option<(EpochNumber, StakeTableState)> {
        db.stake_table_lower_bound(epoch).await.unwrap()
    }

    #[tokio::test]
    #[test_log::test]
    async fn test_block_height() {
        let db = fresh_db().await;

        // An empty store reports height zero.
        assert_eq!(db.block_height().await.unwrap(), 0);

        // A lone leaf at height 100, with no predecessors cached, gives height 101.
        let mut chain = leaf_chain(100..101, EPOCH_VERSION).await;
        db.insert_leaf(chain.remove(0)).await.unwrap();
        assert_eq!(db.block_height().await.unwrap(), 101);
    }

    #[tokio::test]
    #[test_log::test]
    async fn test_leaf_upper_bound_exact() {
        let db = fresh_db().await;
        let leaf = leaf_chain(0..1, EPOCH_VERSION).await.remove(0);
        db.insert_leaf(leaf.clone()).await.unwrap();

        // Every way of naming the leaf resolves to exactly that leaf.
        assert_eq!(upper_bound(&db, LeafId::Number(0)).await.unwrap(), leaf);
        assert_eq!(upper_bound(&db, LeafId::Hash(leaf.hash())).await.unwrap(), leaf);
        assert_eq!(upper_bound(&db, BlockId::Number(0)).await.unwrap(), leaf);
        assert_eq!(
            upper_bound(&db, BlockId::Hash(leaf.block_hash()))
                .await
                .unwrap(),
            leaf
        );
        // Only the payload hash itself is compared for payload lookups.
        let by_payload = upper_bound(&db, BlockId::PayloadHash(leaf.payload_hash()))
            .await
            .unwrap();
        assert_eq!(by_payload.payload_hash(), leaf.payload_hash());
    }

    #[tokio::test]
    #[test_log::test]
    async fn test_leaf_upper_bound_loose() {
        let db = fresh_db().await;
        let leaves = leaf_chain(0..=1, EPOCH_VERSION).await;
        db.insert_leaf(leaves[1].clone()).await.unwrap();

        // A height query is answered by the next cached leaf.
        assert_eq!(upper_bound(&db, LeafId::Number(0)).await.unwrap(), leaves[1]);

        // Hash queries are exact or nothing: a hash cannot be "upper bounded".
        assert_eq!(upper_bound(&db, LeafId::Hash(leaves[0].hash())).await, None);
        assert_eq!(
            upper_bound(&db, BlockId::Hash(leaves[0].block_hash())).await,
            None
        );
    }

    #[tokio::test]
    #[test_log::test]
    async fn test_leaf_upper_bound_least_upper_bound() {
        let db = fresh_db().await;
        let leaves = leaf_chain(0..=2, EPOCH_VERSION).await;
        for i in [2, 1] {
            db.insert_leaf(leaves[i].clone()).await.unwrap();
        }

        // Of the two cached leaves above height 0, the lower one wins.
        assert_eq!(upper_bound(&db, LeafId::Number(0)).await.unwrap(), leaves[1]);
    }

    #[tokio::test]
    #[test_log::test]
    async fn test_leaf_upper_bound_not_found() {
        let db = fresh_db().await;
        let leaves = leaf_chain(0..=1, EPOCH_VERSION).await;
        db.insert_leaf(leaves[0].clone()).await.unwrap();

        assert_eq!(upper_bound(&db, LeafId::Number(1)).await, None);
        assert_eq!(upper_bound(&db, LeafId::Hash(leaves[1].hash())).await, None);
    }

    #[tokio::test]
    #[test_log::test]
    async fn test_gc_last_inserted() {
        let db = db_with_leaf_limit(1).await;
        let leaves = leaf_chain(0..=1, EPOCH_VERSION).await;
        db.insert_leaf(leaves[1].clone()).await.unwrap();
        db.insert_leaf(leaves[0].clone()).await.unwrap();

        // Only the most recently inserted leaf (height 0) survives.
        assert_eq!(upper_bound(&db, LeafId::Number(0)).await.unwrap(), leaves[0]);
        assert_eq!(upper_bound(&db, LeafId::Number(1)).await, None);
    }

    #[tokio::test]
    #[test_log::test]
    async fn test_gc_last_selected() {
        let db = db_with_leaf_limit(2).await;
        let leaves = leaf_chain(0..=2, EPOCH_VERSION).await;
        for leaf in &leaves[..2] {
            db.insert_leaf(leaf.clone()).await.unwrap();
        }

        // Reading leaf 0 makes it more recently used than leaf 1.
        assert_eq!(upper_bound(&db, LeafId::Number(0)).await.unwrap(), leaves[0]);

        // A third insertion therefore evicts leaf 1, the least recently used.
        db.insert_leaf(leaves[2].clone()).await.unwrap();

        assert_eq!(upper_bound(&db, LeafId::Number(0)).await.unwrap(), leaves[0]);
        assert_eq!(upper_bound(&db, LeafId::Number(1)).await.unwrap(), leaves[2]);
        assert_eq!(upper_bound(&db, LeafId::Number(2)).await.unwrap(), leaves[2]);
    }

    #[tokio::test]
    #[test_log::test]
    async fn test_get_leaves_in_range() {
        let db = fresh_db().await;
        let leaves = leaf_chain(0..5, EPOCH_VERSION).await;
        for leaf in leaves.iter().cloned() {
            db.insert_leaf(leaf).await.unwrap();
        }

        let fetched = db.get_leaves_in_range(1, 4).await.unwrap();
        assert_eq!(fetched, leaves[1..4]);
    }

    #[tokio::test]
    #[test_log::test]
    async fn test_get_leaves_in_range_not_found() {
        let db = fresh_db().await;
        for leaf in leaf_chain(0..3, EPOCH_VERSION).await {
            db.insert_leaf(leaf).await.unwrap();
        }

        // The range lies entirely above every cached height.
        assert!(db.get_leaves_in_range(3, 5).await.unwrap().is_empty());
    }

    #[tokio::test]
    #[test_log::test]
    async fn test_stake_table_lower_bound_exact() {
        let db = fresh_db().await;
        let epoch = EpochNumber::new(1);
        let state = random_stake_table();
        db.insert_stake_table(epoch, &state).await.unwrap();

        assert_eq!(lower_bound(&db, epoch).await.unwrap(), (epoch, state));
    }

    #[tokio::test]
    #[test_log::test]
    async fn test_stake_table_lower_bound_loose() {
        let db = fresh_db().await;
        let epoch = EpochNumber::new(1);
        let state = random_stake_table();
        db.insert_stake_table(epoch, &state).await.unwrap();

        // Asking for a later epoch falls back to the stored one.
        assert_eq!(lower_bound(&db, epoch + 1).await.unwrap(), (epoch, state));
    }

    #[tokio::test]
    #[test_log::test]
    async fn test_stake_table_lower_bound_greatest_lower_bound() {
        let db = fresh_db().await;
        let first = random_stake_table();
        let second = chain_stake_table(&first);
        db.insert_stake_table(EpochNumber::new(1), &first)
            .await
            .unwrap();
        db.insert_stake_table(EpochNumber::new(2), &second)
            .await
            .unwrap();

        assert_eq!(
            lower_bound(&db, EpochNumber::new(3)).await.unwrap(),
            (EpochNumber::new(2), second)
        );
    }

    #[tokio::test]
    #[test_log::test]
    async fn test_stake_table_lower_bound_not_found() {
        let db = fresh_db().await;
        db.insert_stake_table(EpochNumber::new(2), &random_stake_table())
            .await
            .unwrap();

        assert_eq!(lower_bound(&db, EpochNumber::new(1)).await, None);
    }

    #[tokio::test]
    #[test_log::test]
    async fn test_stake_table_gc() {
        let db = LightClientSqliteOptions {
            num_stake_tables: 2,
            ..Default::default()
        }
        .connect()
        .await
        .unwrap();

        let first = random_stake_table();
        let second = chain_stake_table(&first);
        let third = chain_stake_table(&second);
        for (n, state) in [(1, &first), (2, &second), (3, &third)] {
            db.insert_stake_table(EpochNumber::new(n), state)
                .await
                .unwrap();
        }

        // The oldest table is retained; the second oldest was evicted to make room.
        assert_eq!(
            lower_bound(&db, EpochNumber::new(1)).await.unwrap(),
            (EpochNumber::new(1), first.clone())
        );
        assert_eq!(
            lower_bound(&db, EpochNumber::new(2)).await.unwrap(),
            (EpochNumber::new(1), first)
        );
        assert_eq!(
            lower_bound(&db, EpochNumber::new(3)).await.unwrap(),
            (EpochNumber::new(3), third)
        );
    }

    #[tokio::test]
    #[test_log::test]
    async fn test_stake_table_insert_out_of_order() {
        let db = fresh_db().await;
        let first = random_stake_table();
        let second = chain_stake_table(&first);
        db.insert_stake_table(EpochNumber::new(2), &second)
            .await
            .unwrap();
        db.insert_stake_table(EpochNumber::new(1), &first)
            .await
            .unwrap();

        assert_eq!(
            lower_bound(&db, EpochNumber::new(1)).await.unwrap(),
            (EpochNumber::new(1), first)
        );
        assert_eq!(
            lower_bound(&db, EpochNumber::new(2)).await.unwrap(),
            (EpochNumber::new(2), second)
        );
    }

    /// A stake table state with a single random validator and every field populated.
    fn random_stake_table() -> StakeTableState {
        let candidate: RegisteredValidator<PubKey> = random_validator().into();
        let validator_exits = std::iter::once(Address::random()).collect();
        let used_bls_keys = std::iter::once(candidate.stake_table_key.clone()).collect();
        let used_schnorr_keys = std::iter::once(candidate.state_ver_key.clone()).collect();
        let validators = std::iter::once((candidate.account, candidate)).collect();
        StakeTableState::new(validators, validator_exits, used_bls_keys, used_schnorr_keys)
    }

    /// A plausible successor of `state`: one extra validator, one extra exit, and the newcomer's
    /// keys added to the used-key sets.
    fn chain_stake_table(state: &StakeTableState) -> StakeTableState {
        let newcomer: RegisteredValidator<PubKey> = random_validator().into();
        let exit = Address::random();

        let used_bls_keys = state
            .used_bls_keys()
            .iter()
            .cloned()
            .chain([newcomer.stake_table_key.clone()])
            .collect();
        let used_schnorr_keys = state
            .used_schnorr_keys()
            .iter()
            .cloned()
            .chain([newcomer.state_ver_key.clone()])
            .collect();
        let validator_exits = state
            .validator_exits()
            .iter()
            .cloned()
            .chain([exit])
            .collect();
        let validators = state
            .validators()
            .values()
            .cloned()
            .chain([newcomer])
            .map(|v| (v.account, v))
            .collect();

        StakeTableState::new(validators, validator_exits, used_bls_keys, used_schnorr_keys)
    }
}