1use std::{future::Future, path::PathBuf, str::FromStr};
2
3use alloy::primitives::Address;
4use anyhow::{Context, Result};
5use derive_more::{Display, From};
6use espresso_types::{PubKey, SeqTypes, StakeTableState, v0_3::RegisteredValidator};
7use futures::TryStreamExt;
8use hotshot_query_service::{
9 availability::{BlockId, LeafId, LeafQueryData},
10 types::HeightIndexed,
11};
12use hotshot_types::{data::EpochNumber, light_client::StateVerKey};
13use serde_json::Value;
14use sqlx::{
15 QueryBuilder, SqlitePool, query, query_as,
16 sqlite::{SqliteConnectOptions, SqlitePoolOptions},
17};
18
/// A request for a leaf, identified either directly by a leaf ID or indirectly by the ID of the
/// block (header) it contains.
///
/// `From` is derived so that both a `LeafId` and a `BlockId` can be passed wherever an
/// `impl Into<LeafRequest>` is accepted (see `Storage::leaf_upper_bound`).
#[derive(Clone, Copy, Debug, Display, From)]
pub enum LeafRequest {
    /// Look up a leaf by height (`LeafId::Number`) or leaf hash (`LeafId::Hash`).
    #[display("leaf {_0}")]
    Leaf(LeafId<SeqTypes>),

    /// Look up a leaf by its block: height (`BlockId::Number`), block hash (`BlockId::Hash`), or
    /// payload hash (`BlockId::PayloadHash`).
    #[display("header {_0}")]
    Header(BlockId<SeqTypes>),
}
30
/// Local cache of leaves and stake tables used by the light client.
///
/// The trait requires `Send + Sync + 'static` and every method returns a `Send` future, so an
/// implementation can be shared between async tasks.
pub trait Storage: Sized + Send + Sync + 'static {
    /// Create a storage instance with default settings.
    fn default() -> impl Send + Future<Output = Result<Self>>;

    /// One more than the height of the highest stored leaf, or 0 if no leaves are stored.
    ///
    /// Leaves below the highest one are not necessarily present.
    fn block_height(&self) -> impl Send + Future<Output = Result<u64>>;

    /// Find the leaf identified by `leaf`, or for height-based requests the nearest stored leaf
    /// above it.
    ///
    /// Requests by height return the stored leaf with the lowest height greater than or equal to
    /// the requested height. Requests by (leaf, block or payload) hash only return an exact
    /// match. Returns `None` if no suitable leaf is stored.
    fn leaf_upper_bound(
        &self,
        leaf: impl Into<LeafRequest> + Send,
    ) -> impl Send + Future<Output = Result<Option<LeafQueryData<SeqTypes>>>>;

    /// All stored leaves with heights in the half-open range `start..end`, ordered by height.
    ///
    /// Heights in the range that are not stored are skipped rather than reported as errors.
    fn get_leaves_in_range(
        &self,
        start: u32,
        end: u32,
    ) -> impl Send + Future<Output = Result<Vec<LeafQueryData<SeqTypes>>>>;

    /// Add `leaf` to the cache, possibly evicting other leaves to stay within capacity.
    fn insert_leaf(&self, leaf: LeafQueryData<SeqTypes>)
    -> impl Send + Future<Output = Result<()>>;

    /// The stake table for the greatest stored epoch less than or equal to `epoch`, paired with
    /// that epoch, or `None` if no such stake table is stored.
    fn stake_table_lower_bound(
        &self,
        epoch: EpochNumber,
    ) -> impl Send + Future<Output = Result<Option<(EpochNumber, StakeTableState)>>>;

    /// Store `stake_table` as the stake table for `epoch`, possibly garbage collecting other
    /// stored stake tables to stay within capacity.
    fn insert_stake_table(
        &self,
        epoch: EpochNumber,
        stake_table: &StakeTableState,
    ) -> impl Send + Future<Output = Result<()>>;
}
90
// Configuration for the light client's SQLite database; turn it into a `SqliteStorage` with
// `connect`.
//
// These are deliberately plain `//` comments rather than `///` doc comments: with the `clap`
// feature enabled this struct derives `clap::Parser`, which would turn doc comments into CLI
// help text and change the program's `--help` output.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "clap", derive(clap::Parser))]
pub struct LightClientSqliteOptions {
    // Maximum number of connections in the SQLite connection pool.
    #[cfg_attr(
        feature = "clap",
        clap(
            long = "light-client-db-num-connections",
            env = "LIGHT_CLIENT_DB_NUM_CONNECTIONS",
            default_value = "5",
        )
    )]
    pub num_connections: u32,

    // Maximum number of leaves kept in the database. When exceeded, the least recently inserted
    // or looked-up leaves are deleted.
    #[cfg_attr(
        feature = "clap",
        clap(
            long = "light-client-db-num-leaves",
            env = "LIGHT_CLIENT_DB_NUM_LEAVES",
            default_value = "100",
        )
    )]
    pub num_leaves: u32,

    // Maximum number of stake tables kept in the database before older ones are garbage
    // collected.
    #[cfg_attr(
        feature = "clap",
        clap(
            long = "light-client-db-num-stake-tables",
            env = "LIGHT_CLIENT_DB_NUM_STAKE_TABLES",
            default_value = "100",
        )
    )]
    pub num_stake_tables: u32,

    // Path of the database file, created if it does not exist. When unset, an in-memory database
    // is used and nothing is persisted across restarts.
    #[cfg_attr(
        feature = "clap",
        clap(long = "light-client-db-path", env = "LIGHT_CLIENT_DB_PATH")
    )]
    pub lc_path: Option<PathBuf>,
}
137
138impl Default for LightClientSqliteOptions {
139 fn default() -> Self {
140 Self {
141 num_connections: 5,
142 num_leaves: 100,
143 num_stake_tables: 100,
144 lc_path: None,
145 }
146 }
147}
148
149impl LightClientSqliteOptions {
150 pub async fn connect(self) -> Result<SqliteStorage> {
152 let path = match &self.lc_path {
153 Some(path) => path.to_str().context("invalid file path")?,
154 None => ":memory:",
155 };
156 let opt = SqliteConnectOptions::from_str(path)?.create_if_missing(true);
157 let pool = SqlitePoolOptions::default()
158 .max_connections(self.num_connections)
159 .connect_with(opt)
160 .await?;
161 sqlx::migrate!("./migrations").run(&pool).await?;
162
163 Ok(SqliteStorage {
164 pool,
165 num_leaves: self.num_leaves,
166 num_stake_tables: self.num_stake_tables,
167 })
168 }
169}
170
/// [`Storage`] implementation backed by a SQLite connection pool.
///
/// Construct it with `LightClientSqliteOptions::connect`, which also runs the migrations.
#[derive(Clone, Debug)]
pub struct SqliteStorage {
    // Pool of connections to the (possibly in-memory) database.
    pool: SqlitePool,
    // Maximum number of leaves retained. Excess leaves are evicted in ascending `id` order, and
    // `leaf_upper_bound` bumps a leaf's `id` when it is read.
    num_leaves: u32,
    // Maximum number of stake tables retained before old ones are garbage collected.
    num_stake_tables: u32,
}
178
179impl Storage for SqliteStorage {
180 async fn default() -> Result<Self> {
181 LightClientSqliteOptions::default().connect().await
182 }
183
184 async fn block_height(&self) -> Result<u64> {
185 let mut tx = self.pool.begin().await?;
186 let (height,) = query_as("SELECT COALESCE(max(height) + 1, 0) FROM leaf")
187 .fetch_one(tx.as_mut())
188 .await?;
189 Ok(height)
190 }
191
192 async fn leaf_upper_bound(
193 &self,
194 id: impl Into<LeafRequest> + Send,
195 ) -> Result<Option<LeafQueryData<SeqTypes>>> {
196 let mut tx = self.pool.begin().await?;
197
198 let mut q = QueryBuilder::new("SELECT height, data FROM leaf WHERE ");
199 match id.into() {
200 LeafRequest::Leaf(LeafId::Number(n)) | LeafRequest::Header(BlockId::Number(n)) => {
201 q.push("height >= ")
202 .push_bind(n as i64)
203 .push("ORDER BY HEIGHT");
204 },
205 LeafRequest::Leaf(LeafId::Hash(h)) => {
206 q.push("hash = ").push_bind(h.to_string());
207 },
208 LeafRequest::Header(BlockId::Hash(h)) => {
209 q.push("block_hash = ").push_bind(h.to_string());
210 },
211 LeafRequest::Header(BlockId::PayloadHash(h)) => {
212 q.push("payload_hash = ")
213 .push_bind(h.to_string())
214 .push("ORDER BY height");
215 },
216 }
217 q.push(" LIMIT 1");
218
219 let Some((height, data)) = q
220 .build_query_as::<(i64, _)>()
221 .fetch_optional(tx.as_mut())
222 .await?
223 else {
224 return Ok(None);
225 };
226 let leaf = serde_json::from_value(data)?;
227
228 let (id,): (i32,) = query_as("SELECT max(id) + 1 FROM leaf")
230 .fetch_one(tx.as_mut())
231 .await?;
232 query("UPDATE leaf SET id = $1 WHERE height = $2")
233 .bind(id)
234 .bind(height)
235 .execute(tx.as_mut())
236 .await?;
237 tx.commit().await?;
238
239 Ok(Some(leaf))
240 }
241
242 async fn get_leaves_in_range(
243 &self,
244 start_height: u32,
245 end_height: u32,
246 ) -> Result<Vec<LeafQueryData<SeqTypes>>> {
247 let mut tx = self.pool.begin().await?;
248
249 let leaves = query_as::<_, (i64, serde_json::Value)>(
250 "SELECT height, data FROM leaf WHERE height >= $1 AND height < $2 ORDER BY height",
251 )
252 .bind(start_height as i64)
253 .bind(end_height as i64)
254 .fetch_all(tx.as_mut())
255 .await?
256 .into_iter()
257 .map(|(_height, data)| serde_json::from_value(data))
258 .collect::<Result<Vec<_>, _>>()?;
259
260 tx.commit().await?;
261
262 Ok(leaves)
263 }
264
265 async fn insert_leaf(&self, leaf: LeafQueryData<SeqTypes>) -> Result<()> {
266 let mut tx = self.pool.begin().await?;
267
268 let height = leaf.height() as i64;
269 let hash = leaf.hash().to_string();
270 let block_hash = leaf.block_hash().to_string();
271 let payload_hash = leaf.payload_hash().to_string();
272 let data = serde_json::to_value(leaf)?;
273
274 tracing::debug!(height, hash, "inserting leaf");
275 let (id,): (i32,) = query_as(
276 "INSERT INTO leaf (height, hash, block_hash, payload_hash, data) VALUES ($1, $2, $3, \
277 $4, $5)
278 ON CONFLICT (height) DO UPDATE SET id = excluded.id
279 RETURNING id",
280 )
281 .bind(height)
282 .bind(&hash)
283 .bind(&block_hash)
284 .bind(&payload_hash)
285 .bind(data)
286 .fetch_one(tx.as_mut())
287 .await
288 .context("inserting new leaf")?;
289 tracing::debug!(height, hash, id, "inserted leaf");
290
291 let (num_leaves,): (u32,) = query_as("SELECT count(*) FROM leaf")
294 .fetch_one(tx.as_mut())
295 .await
296 .context("counting leaves")?;
297 let to_delete = num_leaves.saturating_sub(self.num_leaves);
298 if to_delete > 0 {
299 let (id_to_delete,): (i64,) =
300 query_as("SELECT id FROM leaf ORDER BY id LIMIT 1 OFFSET $1")
301 .bind(to_delete - 1)
302 .fetch_one(tx.as_mut())
303 .await
304 .context("finding timestamp for GC")?;
305 tracing::info!(id_to_delete, "garbage collecting {to_delete} leaves");
306 let res = query("DELETE FROM leaf WHERE id <= $1")
307 .bind(id_to_delete)
308 .execute(tx.as_mut())
309 .await
310 .context("deleting old leaves")?;
311 tracing::info!("deleted {} leaves", res.rows_affected());
312 }
313
314 tx.commit().await?;
315 Ok(())
316 }
317
318 async fn stake_table_lower_bound(
319 &self,
320 epoch: EpochNumber,
321 ) -> Result<Option<(EpochNumber, StakeTableState)>> {
322 let mut tx = self.pool.begin().await?;
323
324 let Some((epoch,)) = query_as::<_, (i64,)>(
325 "SELECT epoch FROM stake_table_epoch WHERE epoch <= $1 ORDER BY epoch DESC LIMIT 1",
326 )
327 .bind(*epoch as i64)
328 .fetch_optional(tx.as_mut())
329 .await
330 .context("loading epoch lower bound")?
331 else {
332 return Ok(None);
333 };
334
335 let validators = query_as::<_, (Value,)>(
336 "SELECT data FROM stake_table_validator WHERE epoch = $1 ORDER BY idx",
337 )
338 .bind(epoch)
339 .fetch(tx.as_mut())
340 .map_err(anyhow::Error::new)
341 .and_then(|(json,)| async move {
342 let validator: RegisteredValidator<PubKey> = serde_json::from_value(json)?;
343 Ok((validator.account, validator))
344 })
345 .try_collect()
346 .await
347 .context(format!("loading stake table for epoch {epoch}"))?;
348
349 let validator_exits =
350 query_as::<_, (String,)>("SELECT address FROM stake_table_exit WHERE epoch <= $1")
351 .bind(epoch)
352 .fetch(tx.as_mut())
353 .map_err(anyhow::Error::new)
354 .and_then(|(s,)| async move { Ok(Address::from_str(&s)?) })
355 .try_collect()
356 .await
357 .context(format!("loading validator exits for epoch {epoch}"))?;
358
359 let used_bls_keys =
360 query_as::<_, (String,)>("SELECT key FROM stake_table_bls_key WHERE epoch <= $1")
361 .bind(epoch)
362 .fetch(tx.as_mut())
363 .map_err(anyhow::Error::new)
364 .and_then(|(s,)| async move { Ok(PubKey::from_str(&s)?) })
365 .try_collect()
366 .await
367 .context(format!("loading BLS keys for epoch {epoch}"))?;
368
369 let used_schnorr_keys =
370 query_as::<_, (String,)>("SELECT key FROM stake_table_schnorr_key WHERE epoch <= $1")
371 .bind(epoch)
372 .fetch(tx.as_mut())
373 .map_err(anyhow::Error::new)
374 .and_then(|(s,)| async move { Ok(StateVerKey::from_str(&s)?) })
375 .try_collect()
376 .await
377 .context(format!("loading Schnorr keys for epoch {epoch}"))?;
378
379 Ok(Some((
380 EpochNumber::new(epoch as u64),
381 StakeTableState::new(
382 validators,
383 validator_exits,
384 used_bls_keys,
385 used_schnorr_keys,
386 ),
387 )))
388 }
389
390 async fn insert_stake_table(
391 &self,
392 epoch: EpochNumber,
393 stake_table: &StakeTableState,
394 ) -> Result<()> {
395 let mut tx = self.pool.begin().await?;
396
397 let epoch = i64::try_from(*epoch).context("epoch overflow")?;
399 query("INSERT INTO stake_table_epoch (epoch) VALUES ($1)")
400 .bind(epoch)
401 .execute(tx.as_mut())
402 .await
403 .context(format!(
404 "recording stake table availability for epoch {epoch}"
405 ))?;
406
407 let validators = stake_table
409 .validators()
410 .values()
411 .cloned()
412 .map(serde_json::to_value)
413 .collect::<Result<Vec<_>, _>>()?;
414 QueryBuilder::new("INSERT INTO stake_table_validator (epoch, idx, data) ")
415 .push_values(validators.into_iter().enumerate(), |mut q, (i, data)| {
416 q.push_bind(epoch).push_bind(i as i64).push_bind(data);
417 })
418 .build()
419 .execute(tx.as_mut())
420 .await
421 .context(format!("inserting validators for epoch {epoch}"))?;
422
423 QueryBuilder::new("INSERT INTO stake_table_bls_key (epoch, key) ")
425 .push_values(stake_table.used_bls_keys(), |mut q, key| {
426 q.push_bind(epoch).push_bind(key.to_string());
427 })
428 .push(" ON CONFLICT (key) DO UPDATE SET epoch = min(epoch, excluded.epoch)")
431 .build()
432 .execute(tx.as_mut())
433 .await
434 .context(format!("inserting newly used BLS keys for epoch {epoch}"))?;
435
436 QueryBuilder::new("INSERT INTO stake_table_schnorr_key (epoch, key) ")
438 .push_values(stake_table.used_schnorr_keys(), |mut q, key| {
439 q.push_bind(epoch).push_bind(key.to_string());
440 })
441 .push(" ON CONFLICT (key) DO UPDATE SET epoch = min(epoch, excluded.epoch)")
444 .build()
445 .execute(tx.as_mut())
446 .await
447 .context(format!(
448 "inserting newly used Schnorr keys for epoch {epoch}"
449 ))?;
450
451 if !stake_table.validator_exits().is_empty() {
453 QueryBuilder::new("INSERT INTO stake_table_exit (epoch, address) ")
454 .push_values(stake_table.validator_exits(), |mut q, address| {
455 q.push_bind(epoch).push_bind(address.to_string());
456 })
457 .push(" ON CONFLICT (address) DO UPDATE SET epoch = min(epoch, excluded.epoch)")
460 .build()
461 .execute(tx.as_mut())
462 .await
463 .context(format!("inserting new validator exits for epoch {epoch}"))?;
464 }
465
466 let (num_stake_tables,): (u32,) = query_as("SELECT count(*) FROM stake_table_epoch")
469 .fetch_one(tx.as_mut())
470 .await
471 .context("counting stake tables")?;
472 if num_stake_tables > self.num_stake_tables {
473 let (epoch_to_delete,): (i64,) =
479 query_as("SELECT epoch FROM stake_table_epoch ORDER BY epoch LIMIT 1 OFFSET 1")
480 .fetch_one(tx.as_mut())
481 .await
482 .context("find second oldest epoch")?;
483 tracing::info!(epoch_to_delete, "garbage collecting stake table");
484
485 query("DELETE FROM stake_table_epoch WHERE epoch = $1")
490 .bind(epoch_to_delete)
491 .execute(tx.as_mut())
492 .await
493 .context("garbage collecting stake table")?;
494 }
495
496 tx.commit().await?;
497 Ok(())
498 }
499}
500
// Tests for `SqliteStorage`: leaf lookup semantics, LRU-style leaf eviction, and stake table
// lookup / garbage collection.
#[cfg(test)]
mod test {
    use pretty_assertions::assert_eq;
    use versions::EPOCH_VERSION;

    use super::*;
    use crate::testing::{leaf_chain, random_validator};

    #[tokio::test]
    #[test_log::test]
    async fn test_block_height() {
        let db = SqliteStorage::default().await.unwrap();

        // An empty database has block height 0.
        assert_eq!(db.block_height().await.unwrap(), 0);

        // Block height is one past the highest stored leaf, even if lower leaves are missing.
        let leaf = leaf_chain(100..101, EPOCH_VERSION).await.remove(0);
        db.insert_leaf(leaf).await.unwrap();
        assert_eq!(db.block_height().await.unwrap(), 101);
    }

    #[tokio::test]
    #[test_log::test]
    async fn test_leaf_upper_bound_exact() {
        let db = SqliteStorage::default().await.unwrap();

        // Every supported kind of identifier finds the single stored leaf.
        let leaf = leaf_chain(0..1, EPOCH_VERSION).await.remove(0);
        db.insert_leaf(leaf.clone()).await.unwrap();
        assert_eq!(
            db.leaf_upper_bound(LeafId::Number(0))
                .await
                .unwrap()
                .unwrap(),
            leaf
        );
        assert_eq!(
            db.leaf_upper_bound(LeafId::Hash(leaf.hash()))
                .await
                .unwrap()
                .unwrap(),
            leaf
        );
        assert_eq!(
            db.leaf_upper_bound(BlockId::Number(0))
                .await
                .unwrap()
                .unwrap(),
            leaf
        );
        assert_eq!(
            db.leaf_upper_bound(BlockId::Hash(leaf.block_hash()))
                .await
                .unwrap()
                .unwrap(),
            leaf
        );
        // Only the payload hash of the result is checked for payload-hash lookups.
        assert_eq!(
            db.leaf_upper_bound(BlockId::PayloadHash(leaf.payload_hash()))
                .await
                .unwrap()
                .unwrap()
                .payload_hash(),
            leaf.payload_hash()
        );
    }

    #[tokio::test]
    #[test_log::test]
    async fn test_leaf_upper_bound_loose() {
        let db = SqliteStorage::default().await.unwrap();

        // Only leaf 1 is stored.
        let leaves = leaf_chain(0..=1, EPOCH_VERSION).await;
        db.insert_leaf(leaves[1].clone()).await.unwrap();
        // A height request for 0 falls through to the next stored leaf...
        assert_eq!(
            db.leaf_upper_bound(LeafId::Number(0))
                .await
                .unwrap()
                .unwrap(),
            leaves[1]
        );
        // ...but hash requests for leaf 0 require an exact match.
        assert_eq!(
            db.leaf_upper_bound(LeafId::Hash(leaves[0].hash()))
                .await
                .unwrap(),
            None
        );
        assert_eq!(
            db.leaf_upper_bound(BlockId::Hash(leaves[0].block_hash()))
                .await
                .unwrap(),
            None
        );
    }

    #[tokio::test]
    #[test_log::test]
    async fn test_leaf_upper_bound_least_upper_bound() {
        let db = SqliteStorage::default().await.unwrap();

        // With leaves 1 and 2 stored (inserted out of order), a request for height 0 returns the
        // lowest leaf above it.
        let leaves = leaf_chain(0..=2, EPOCH_VERSION).await;
        db.insert_leaf(leaves[2].clone()).await.unwrap();
        db.insert_leaf(leaves[1].clone()).await.unwrap();
        assert_eq!(
            db.leaf_upper_bound(LeafId::Number(0))
                .await
                .unwrap()
                .unwrap(),
            leaves[1]
        );
    }

    #[tokio::test]
    #[test_log::test]
    async fn test_leaf_upper_bound_not_found() {
        let db = SqliteStorage::default().await.unwrap();

        // Only leaf 0 is stored, so nothing at or above height 1 can be found.
        let leaves = leaf_chain(0..=1, EPOCH_VERSION).await;
        db.insert_leaf(leaves[0].clone()).await.unwrap();
        assert_eq!(db.leaf_upper_bound(LeafId::Number(1)).await.unwrap(), None);
        assert_eq!(
            db.leaf_upper_bound(LeafId::Hash(leaves[1].hash()))
                .await
                .unwrap(),
            None
        );
    }

    #[tokio::test]
    #[test_log::test]
    async fn test_gc_last_inserted() {
        let db = LightClientSqliteOptions {
            num_leaves: 1,
            ..Default::default()
        }
        .connect()
        .await
        .unwrap();

        // With room for one leaf, the earlier-inserted leaf 1 is evicted when leaf 0 is inserted,
        // even though it has the higher height.
        let leaves = leaf_chain(0..=1, EPOCH_VERSION).await;
        db.insert_leaf(leaves[1].clone()).await.unwrap();
        db.insert_leaf(leaves[0].clone()).await.unwrap();

        assert_eq!(
            db.leaf_upper_bound(LeafId::Number(0))
                .await
                .unwrap()
                .unwrap(),
            leaves[0]
        );
        assert_eq!(db.leaf_upper_bound(LeafId::Number(1)).await.unwrap(), None);
    }

    #[tokio::test]
    #[test_log::test]
    async fn test_gc_last_selected() {
        let db = LightClientSqliteOptions {
            num_leaves: 2,
            ..Default::default()
        }
        .connect()
        .await
        .unwrap();

        let leaves = leaf_chain(0..=2, EPOCH_VERSION).await;
        db.insert_leaf(leaves[0].clone()).await.unwrap();
        db.insert_leaf(leaves[1].clone()).await.unwrap();

        // Looking up leaf 0 marks it as recently used...
        assert_eq!(
            db.leaf_upper_bound(LeafId::Number(0))
                .await
                .unwrap()
                .unwrap(),
            leaves[0]
        );

        // ...so inserting leaf 2 evicts leaf 1 instead of leaf 0.
        db.insert_leaf(leaves[2].clone()).await.unwrap();

        assert_eq!(
            db.leaf_upper_bound(LeafId::Number(0))
                .await
                .unwrap()
                .unwrap(),
            leaves[0]
        );
        // Height 1 is gone, so its upper bound is now leaf 2.
        assert_eq!(
            db.leaf_upper_bound(LeafId::Number(1))
                .await
                .unwrap()
                .unwrap(),
            leaves[2]
        );
        assert_eq!(
            db.leaf_upper_bound(LeafId::Number(2))
                .await
                .unwrap()
                .unwrap(),
            leaves[2]
        );
    }

    #[tokio::test]
    #[test_log::test]
    async fn test_get_leaves_in_range() {
        let db = SqliteStorage::default().await.unwrap();

        let leaves = leaf_chain(0..5, EPOCH_VERSION).await;
        for leaf in &leaves {
            db.insert_leaf(leaf.clone()).await.unwrap();
        }

        // The range is half-open: heights 1, 2 and 3.
        let fetched = db.get_leaves_in_range(1, 4).await.unwrap();
        assert_eq!(fetched, leaves[1..4]);
    }

    #[tokio::test]
    #[test_log::test]
    async fn test_get_leaves_in_range_not_found() {
        let db = SqliteStorage::default().await.unwrap();

        let leaves = leaf_chain(0..3, EPOCH_VERSION).await;
        for leaf in &leaves {
            db.insert_leaf(leaf.clone()).await.unwrap();
        }

        // A range entirely above the stored leaves yields an empty result, not an error.
        let fetched = db.get_leaves_in_range(3, 5).await.unwrap();
        assert!(fetched.is_empty());
    }

    #[tokio::test]
    #[test_log::test]
    async fn test_stake_table_lower_bound_exact() {
        let db = SqliteStorage::default().await.unwrap();

        // A stake table round-trips unchanged when requested for its own epoch.
        let epoch = EpochNumber::new(1);
        let state = random_stake_table();
        db.insert_stake_table(epoch, &state).await.unwrap();
        assert_eq!(
            db.stake_table_lower_bound(epoch).await.unwrap().unwrap(),
            (epoch, state)
        );
    }

    #[tokio::test]
    #[test_log::test]
    async fn test_stake_table_lower_bound_loose() {
        let db = SqliteStorage::default().await.unwrap();

        // A request for a later epoch falls back to the latest stored epoch before it.
        let epoch = EpochNumber::new(1);
        let state = random_stake_table();
        db.insert_stake_table(epoch, &state).await.unwrap();
        assert_eq!(
            db.stake_table_lower_bound(epoch + 1)
                .await
                .unwrap()
                .unwrap(),
            (epoch, state)
        );
    }

    #[tokio::test]
    #[test_log::test]
    async fn test_stake_table_lower_bound_greatest_lower_bound() {
        let db = SqliteStorage::default().await.unwrap();

        // With epochs 1 and 2 stored, a request for epoch 3 returns epoch 2, the greatest lower
        // bound.
        let state1 = random_stake_table();
        let state2 = chain_stake_table(&state1);
        db.insert_stake_table(EpochNumber::new(1), &state1)
            .await
            .unwrap();
        db.insert_stake_table(EpochNumber::new(2), &state2)
            .await
            .unwrap();

        assert_eq!(
            db.stake_table_lower_bound(EpochNumber::new(3))
                .await
                .unwrap()
                .unwrap(),
            (EpochNumber::new(2), state2)
        );
    }

    #[tokio::test]
    #[test_log::test]
    async fn test_stake_table_lower_bound_not_found() {
        let db = SqliteStorage::default().await.unwrap();
        // Nothing is stored at or before epoch 1.
        db.insert_stake_table(EpochNumber::new(2), &random_stake_table())
            .await
            .unwrap();
        assert_eq!(
            db.stake_table_lower_bound(EpochNumber::new(1))
                .await
                .unwrap(),
            None
        );
    }

    #[tokio::test]
    #[test_log::test]
    async fn test_stake_table_gc() {
        let db = LightClientSqliteOptions {
            num_stake_tables: 2,
            ..Default::default()
        }
        .connect()
        .await
        .unwrap();

        let state1 = random_stake_table();
        let state2 = chain_stake_table(&state1);
        let state3 = chain_stake_table(&state2);
        db.insert_stake_table(EpochNumber::new(1), &state1)
            .await
            .unwrap();
        db.insert_stake_table(EpochNumber::new(2), &state2)
            .await
            .unwrap();
        // Inserting a third table exceeds the limit of 2. The oldest (epoch 1) is kept and
        // epoch 2 is garbage collected.
        db.insert_stake_table(EpochNumber::new(3), &state3)
            .await
            .unwrap();

        assert_eq!(
            db.stake_table_lower_bound(EpochNumber::new(1))
                .await
                .unwrap()
                .unwrap(),
            (EpochNumber::new(1), state1.clone())
        );
        // Epoch 2 is gone, so it falls back to epoch 1. Keys and exits first recorded at epoch 2
        // must not leak into the epoch 1 result.
        assert_eq!(
            db.stake_table_lower_bound(EpochNumber::new(2))
                .await
                .unwrap()
                .unwrap(),
            (EpochNumber::new(1), state1)
        );
        assert_eq!(
            db.stake_table_lower_bound(EpochNumber::new(3))
                .await
                .unwrap()
                .unwrap(),
            (EpochNumber::new(3), state3)
        );
    }

    #[tokio::test]
    #[test_log::test]
    async fn test_stake_table_insert_out_of_order() {
        let db = SqliteStorage::default().await.unwrap();

        // Inserting epoch 2 before epoch 1 must still yield the correct cumulative state for
        // each epoch (keys and exits keep the earliest epoch they appeared in).
        let state1 = random_stake_table();
        let state2 = chain_stake_table(&state1);
        db.insert_stake_table(EpochNumber::new(2), &state2)
            .await
            .unwrap();
        db.insert_stake_table(EpochNumber::new(1), &state1)
            .await
            .unwrap();

        assert_eq!(
            db.stake_table_lower_bound(EpochNumber::new(1))
                .await
                .unwrap()
                .unwrap(),
            (EpochNumber::new(1), state1)
        );
        assert_eq!(
            db.stake_table_lower_bound(EpochNumber::new(2))
                .await
                .unwrap()
                .unwrap(),
            (EpochNumber::new(2), state2)
        );
    }

    // A stake table with one random validator, one random exit address, and that validator's
    // BLS and Schnorr keys marked as used.
    fn random_stake_table() -> StakeTableState {
        let validator = random_validator();
        let candidate: RegisteredValidator<PubKey> = validator.clone().into();
        StakeTableState::new(
            [(candidate.account, candidate.clone())]
                .into_iter()
                .collect(),
            [Address::random()].into_iter().collect(),
            [candidate.stake_table_key].into_iter().collect(),
            [candidate.state_ver_key].into_iter().collect(),
        )
    }

    // The successor of `state`: every validator, exit and used key from `state`, plus one new
    // random validator (with its keys) and one new exit address.
    fn chain_stake_table(state: &StakeTableState) -> StakeTableState {
        let new_validator = random_validator();
        let new_candidate: RegisteredValidator<PubKey> = new_validator.clone().into();
        let new_exit = Address::random();
        StakeTableState::new(
            state
                .validators()
                .values()
                .chain([&new_candidate])
                .map(|v| (v.account, v.clone()))
                .collect(),
            state
                .validator_exits()
                .iter()
                .chain([&new_exit])
                .cloned()
                .collect(),
            state
                .used_bls_keys()
                .iter()
                .chain([&new_candidate.stake_table_key])
                .cloned()
                .collect(),
            state
                .used_schnorr_keys()
                .iter()
                .chain([&new_candidate.state_ver_key])
                .cloned()
                .collect(),
        )
    }
}