1#![cfg(feature = "sql-data-source")]
14use std::{cmp::min, fmt::Debug, str::FromStr, time::Duration};
15
16use async_trait::async_trait;
17use chrono::Utc;
18#[cfg(not(feature = "embedded-db"))]
19use futures::future::FutureExt;
20use hotshot_types::{
21 data::VidShare,
22 traits::{metrics::Metrics, node_implementation::NodeType},
23};
24use itertools::Itertools;
25use log::LevelFilter;
26#[cfg(not(feature = "embedded-db"))]
27use sqlx::postgres::{PgConnectOptions, PgSslMode};
28#[cfg(feature = "embedded-db")]
29use sqlx::sqlite::SqliteConnectOptions;
30use sqlx::{
31 ConnectOptions, Row,
32 pool::{Pool, PoolOptions},
33};
34
35use crate::{
36 Header, QueryError, QueryResult,
37 availability::{QueryableHeader, QueryablePayload, VidCommonMetadata, VidCommonQueryData},
38 data_source::{
39 VersionedDataSource,
40 storage::pruning::{PruneStorage, PrunerCfg, PrunerConfig},
41 update::Transaction as _,
42 },
43 metrics::PrometheusMetrics,
44 node::BlockId,
45 status::HasMetrics,
46};
47pub extern crate sqlx;
48pub use sqlx::{Database, Sqlite};
49
50mod db;
51mod migrate;
52mod queries;
53mod transaction;
54
55pub use anyhow::Error;
56pub use db::*;
57pub use include_dir::include_dir;
58pub use queries::QueryBuilder;
59pub use refinery::Migration;
60pub use transaction::*;
61
62use self::{migrate::Migrator, transaction::PoolMetrics};
63use super::{AvailabilityStorage, NodeStorage};
64pub use crate::include_migrations;
68
/// Embed a directory of SQL migration files at compile time.
///
/// Expands to an iterator of [`Migration`](refinery::Migration)s, one for each file yielded by
/// `include_dir!($dir).files()`. Each file's name is used as the migration name and its
/// contents as the SQL, e.g.
/// `include_migrations!("$CARGO_MANIFEST_DIR/migrations/postgres")`.
///
/// # Panics
///
/// The produced iterator panics if a file name is missing or not UTF-8, if a file's contents
/// are not UTF-8, or if refinery rejects the file as a migration ("invalid migration").
#[macro_export]
macro_rules! include_migrations {
    ($dir:tt) => {
        $crate::data_source::storage::sql::include_dir!($dir)
            .files()
            .map(|migration_file| {
                let file_path = migration_file.path();
                let Some(name) = file_path.file_name().and_then(std::ffi::OsStr::to_str) else {
                    panic!(
                        "migration file {} must have a non-empty UTF-8 name",
                        file_path.display()
                    )
                };
                let Some(sql) = migration_file.contents_utf8() else {
                    panic!("migration file {name} must use UTF-8 encoding")
                };
                $crate::data_source::storage::sql::Migration::unapplied(name, sql)
                    .expect("invalid migration")
            })
    };
}
135
136pub fn default_migrations() -> Vec<Migration> {
138 #[cfg(not(feature = "embedded-db"))]
139 let mut migrations =
140 include_migrations!("$CARGO_MANIFEST_DIR/migrations/postgres").collect::<Vec<_>>();
141
142 #[cfg(feature = "embedded-db")]
143 let mut migrations =
144 include_migrations!("$CARGO_MANIFEST_DIR/migrations/sqlite").collect::<Vec<_>>();
145
146 validate_migrations(&mut migrations).expect("default migrations are invalid");
148
149 for m in &migrations {
152 if m.version() <= 30 {
153 assert!(
157 m.version() > 0 && m.version() % 10 == 0,
158 "legacy default migration version {} is not a positive multiple of 10",
159 m.version()
160 );
161 } else {
162 assert!(
163 m.version() % 100 == 0,
164 "default migration version {} is not a multiple of 100",
165 m.version()
166 );
167 }
168 }
169
170 migrations
171}
172
173fn validate_migrations(migrations: &mut [Migration]) -> Result<(), Error> {
178 migrations.sort_by_key(|m| m.version());
179
180 for (prev, next) in migrations.iter().zip(migrations.iter().skip(1)) {
182 if next <= prev {
183 return Err(Error::msg(format!(
184 "migration versions are not strictly increasing ({prev}->{next})"
185 )));
186 }
187 }
188
189 Ok(())
190}
191
192fn add_custom_migrations(
199 default: impl IntoIterator<Item = Migration>,
200 custom: impl IntoIterator<Item = Migration>,
201) -> impl Iterator<Item = Migration> {
202 default
203 .into_iter()
204 .merge_join_by(custom, |l, r| l.version().cmp(&r.version()))
206 .map(|pair| pair.reduce(|_, custom| custom))
209}
210
211#[derive(Clone)]
212pub struct Config {
213 #[cfg(feature = "embedded-db")]
214 db_opt: SqliteConnectOptions,
215
216 #[cfg(not(feature = "embedded-db"))]
217 db_opt: PgConnectOptions,
218
219 pool_opt: PoolOptions<Db>,
220
221 #[cfg(not(feature = "embedded-db"))]
223 pool_opt_query: PoolOptions<Db>,
224
225 #[cfg(not(feature = "embedded-db"))]
226 schema: String,
227 reset: bool,
228 migrations: Vec<Migration>,
229 no_migrations: bool,
230 pruner_cfg: Option<PrunerCfg>,
231 archive: bool,
232 pool: Option<Pool<Db>>,
233}
234
235#[cfg(not(feature = "embedded-db"))]
236impl Default for Config {
237 fn default() -> Self {
238 PgConnectOptions::default()
239 .username("postgres")
240 .password("password")
241 .host("localhost")
242 .port(5432)
243 .into()
244 }
245}
246
#[cfg(feature = "embedded-db")]
impl Default for Config {
    /// A SQLite database with the following settings:
    /// - WAL journal mode;
    /// - a 30 second busy timeout;
    /// - incremental auto-vacuum;
    /// - the file created if missing.
    fn default() -> Self {
        use sqlx::sqlite::{SqliteAutoVacuum, SqliteJournalMode};

        // Keep this call order: it matches the order the options were originally applied in.
        let db_opt = SqliteConnectOptions::default()
            .journal_mode(SqliteJournalMode::Wal)
            .busy_timeout(Duration::from_secs(30))
            .auto_vacuum(SqliteAutoVacuum::Incremental)
            .create_if_missing(true);
        Self::from(db_opt)
    }
}
258
#[cfg(feature = "embedded-db")]
impl From<SqliteConnectOptions> for Config {
    /// Wrap SQLite connection options with default pool options and every other setting
    /// off: no reset, no custom migrations, no pruning, no archive mode, no reused pool.
    fn from(db_opt: SqliteConnectOptions) -> Self {
        Config {
            db_opt,
            pool_opt: PoolOptions::default(),
            pool: None,
            migrations: Vec::new(),
            no_migrations: false,
            reset: false,
            archive: false,
            pruner_cfg: None,
        }
    }
}
274
275#[cfg(not(feature = "embedded-db"))]
276impl From<PgConnectOptions> for Config {
277 fn from(db_opt: PgConnectOptions) -> Self {
278 Self {
279 db_opt,
280 pool_opt: PoolOptions::default(),
281 pool_opt_query: PoolOptions::default(),
282 schema: "hotshot".into(),
283 reset: false,
284 migrations: vec![],
285 no_migrations: false,
286 pruner_cfg: None,
287 archive: false,
288 pool: None,
289 }
290 }
291}
292
293#[cfg(not(feature = "embedded-db"))]
294impl FromStr for Config {
295 type Err = <PgConnectOptions as FromStr>::Err;
296
297 fn from_str(s: &str) -> Result<Self, Self::Err> {
298 Ok(PgConnectOptions::from_str(s)?.into())
299 }
300}
301
#[cfg(feature = "embedded-db")]
impl FromStr for Config {
    type Err = <SqliteConnectOptions as FromStr>::Err;

    /// Parse a SQLite connection URL, e.g. `sqlite://data.db`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        s.parse::<SqliteConnectOptions>().map(Self::from)
    }
}
310
#[cfg(feature = "embedded-db")]
impl Config {
    /// Set how long a connection waits on a locked database before giving up.
    pub fn busy_timeout(self, timeout: Duration) -> Self {
        Self {
            db_opt: self.db_opt.busy_timeout(timeout),
            ..self
        }
    }

    /// Set the path of the SQLite database file.
    pub fn db_path(self, path: std::path::PathBuf) -> Self {
        Self {
            db_opt: self.db_opt.filename(path),
            ..self
        }
    }
}
323
324#[cfg(not(feature = "embedded-db"))]
325impl Config {
326 pub fn host(mut self, host: impl Into<String>) -> Self {
330 self.db_opt = self.db_opt.host(&host.into());
331 self
332 }
333
334 pub fn port(mut self, port: u16) -> Self {
338 self.db_opt = self.db_opt.port(port);
339 self
340 }
341
342 pub fn user(mut self, user: &str) -> Self {
344 self.db_opt = self.db_opt.username(user);
345 self
346 }
347
348 pub fn password(mut self, password: &str) -> Self {
350 self.db_opt = self.db_opt.password(password);
351 self
352 }
353
354 pub fn database(mut self, database: &str) -> Self {
356 self.db_opt = self.db_opt.database(database);
357 self
358 }
359
360 pub fn tls(mut self) -> Self {
366 self.db_opt = self.db_opt.ssl_mode(PgSslMode::Require);
367 self
368 }
369
370 pub fn schema(mut self, schema: impl Into<String>) -> Self {
374 self.schema = schema.into();
375 self
376 }
377}
378
379impl Config {
380 pub fn pool(mut self, pool: Pool<Db>) -> Self {
383 self.pool = Some(pool);
384 self
385 }
386
387 pub fn reset_schema(mut self) -> Self {
398 self.reset = true;
399 self
400 }
401
402 pub fn migrations(mut self, migrations: impl IntoIterator<Item = Migration>) -> Self {
404 self.migrations.extend(migrations);
405 self
406 }
407
408 pub fn no_migrations(mut self) -> Self {
410 self.no_migrations = true;
411 self
412 }
413
414 pub fn pruner_cfg(mut self, cfg: PrunerCfg) -> Result<Self, Error> {
418 cfg.validate()?;
419 self.pruner_cfg = Some(cfg);
420 self.archive = false;
421 Ok(self)
422 }
423
424 pub fn archive(mut self) -> Self {
433 self.pruner_cfg = None;
434 self.archive = true;
435 self
436 }
437
438 pub fn idle_connection_timeout(mut self, timeout: Duration) -> Self {
443 self.pool_opt = self.pool_opt.idle_timeout(Some(timeout));
444
445 #[cfg(not(feature = "embedded-db"))]
446 {
447 self.pool_opt_query = self.pool_opt_query.idle_timeout(Some(timeout));
448 }
449
450 self
451 }
452
453 pub fn connection_timeout(mut self, timeout: Duration) -> Self {
460 self.pool_opt = self.pool_opt.max_lifetime(Some(timeout));
461
462 #[cfg(not(feature = "embedded-db"))]
463 {
464 self.pool_opt = self.pool_opt.max_lifetime(Some(timeout));
465 }
466
467 self
468 }
469
470 pub fn min_connections(mut self, min: u32) -> Self {
476 self.pool_opt = self.pool_opt.min_connections(min);
477 self
478 }
479
480 #[cfg(not(feature = "embedded-db"))]
481 pub fn query_min_connections(mut self, min: u32) -> Self {
482 self.pool_opt_query = self.pool_opt_query.min_connections(min);
483 self
484 }
485
486 pub fn max_connections(mut self, max: u32) -> Self {
491 self.pool_opt = self.pool_opt.max_connections(max);
492 self
493 }
494
495 #[cfg(not(feature = "embedded-db"))]
496 pub fn query_max_connections(mut self, max: u32) -> Self {
497 self.pool_opt_query = self.pool_opt_query.max_connections(max);
498 self
499 }
500
501 pub fn slow_statement_threshold(mut self, threshold: Duration) -> Self {
505 self.db_opt = self
506 .db_opt
507 .log_slow_statements(LevelFilter::Warn, threshold);
508 self
509 }
510
511 #[cfg(not(feature = "embedded-db"))]
515 pub fn statement_timeout(mut self, timeout: Duration) -> Self {
516 let timeout_ms = timeout.as_millis();
519 self.db_opt = self
520 .db_opt
521 .options([("statement_timeout", timeout_ms.to_string())]);
522 self
523 }
524
525 #[cfg(feature = "embedded-db")]
527 pub fn statement_timeout(self, _timeout: Duration) -> Self {
528 self
529 }
530}
531
532#[derive(Clone, Debug)]
534pub struct SqlStorage {
535 pool: Pool<Db>,
536 metrics: PrometheusMetrics,
537 pool_metrics: PoolMetrics,
538 pruner_cfg: Option<PrunerCfg>,
539}
540
/// Progress cached across successive calls to `SqlStorage::prune`.
///
/// All fields start out unknown (`None`). `prune` fills them in as it computes them.
#[derive(Debug, Default)]
pub struct Pruner {
    /// Height up to which data has been pruned by this pruner.
    pruned_height: Option<u64>,
    /// Newest height old enough to fall outside the target retention period.
    target_height: Option<u64>,
    /// Newest height old enough to fall outside the minimum retention period.
    minimum_retention_height: Option<u64>,
}
547
/// Which kind of connection `SqlStorage::connect` should open.
///
/// Without `embedded-db`, `Query` connections:
/// - use the separate query pool options;
/// - never reuse a pool supplied through `Config::pool`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum StorageConnectionType {
    /// Sequencer connection: uses the main pool options and may reuse a supplied pool.
    Sequencer,
    /// Query connection: uses the query pool options (Postgres).
    Query,
}
553
554impl SqlStorage {
555 pub fn pool(&self) -> Pool<Db> {
556 self.pool.clone()
557 }
558
559 #[allow(unused_variables)]
561 pub async fn connect(
562 mut config: Config,
563 connection_type: StorageConnectionType,
564 ) -> Result<Self, Error> {
565 let metrics = PrometheusMetrics::default();
566 let pool_metrics = PoolMetrics::new(&*metrics.subgroup("sql".into()));
567
568 #[cfg(feature = "embedded-db")]
569 let pool = config.pool_opt.clone();
570 #[cfg(not(feature = "embedded-db"))]
571 let pool = match connection_type {
572 StorageConnectionType::Sequencer => config.pool_opt.clone(),
573 StorageConnectionType::Query => config.pool_opt_query.clone(),
574 };
575
576 let pruner_cfg = config.pruner_cfg;
577
578 if cfg!(feature = "embedded-db") || connection_type == StorageConnectionType::Sequencer {
580 if let Some(pool) = config.pool {
582 return Ok(Self {
583 metrics,
584 pool_metrics,
585 pool,
586 pruner_cfg,
587 });
588 }
589 } else if config.pool.is_some() {
590 tracing::info!("not reusing existing pool for query connection");
591 }
592
593 #[cfg(not(feature = "embedded-db"))]
594 let schema = config.schema.clone();
595 #[cfg(not(feature = "embedded-db"))]
596 let pool = pool.after_connect(move |conn, _| {
597 let schema = config.schema.clone();
598 async move {
599 query(&format!("SET search_path TO {schema}"))
600 .execute(conn)
601 .await?;
602 Ok(())
603 }
604 .boxed()
605 });
606
607 #[cfg(feature = "embedded-db")]
608 if config.reset {
609 std::fs::remove_file(config.db_opt.get_filename())?;
610 }
611
612 let pool = pool.connect_with(config.db_opt).await?;
613
614 let mut conn = pool.acquire().await?;
616
617 #[cfg(not(feature = "embedded-db"))]
619 query("SET statement_timeout = 0")
620 .execute(conn.as_mut())
621 .await?;
622
623 #[cfg(not(feature = "embedded-db"))]
624 if config.reset {
625 query(&format!("DROP SCHEMA IF EXISTS {schema} CASCADE"))
626 .execute(conn.as_mut())
627 .await?;
628 }
629
630 #[cfg(not(feature = "embedded-db"))]
631 query(&format!("CREATE SCHEMA IF NOT EXISTS {schema}"))
632 .execute(conn.as_mut())
633 .await?;
634
635 validate_migrations(&mut config.migrations)?;
637 let migrations =
638 add_custom_migrations(default_migrations(), config.migrations).collect::<Vec<_>>();
639
640 let runner = refinery::Runner::new(&migrations).set_grouped(true);
643
644 if config.no_migrations {
645 let last_applied = runner
648 .get_last_applied_migration_async(&mut Migrator::from(&mut conn))
649 .await?;
650 let last_expected = migrations.last();
651 if last_applied.as_ref() != last_expected {
652 return Err(Error::msg(format!(
653 "DB is out of date: last applied migration is {last_applied:?}, but expected \
654 {last_expected:?}"
655 )));
656 }
657 } else {
658 match runner.run_async(&mut Migrator::from(&mut conn)).await {
660 Ok(report) => {
661 tracing::info!("ran DB migrations: {report:?}");
662 },
663 Err(err) => {
664 tracing::error!("DB migrations failed: {:?}", err.report());
665 Err(err)?;
666 },
667 }
668 }
669
670 if config.archive {
671 query("DELETE FROM pruned_height WHERE id = 1")
674 .execute(conn.as_mut())
675 .await?;
676 }
677
678 conn.close().await?;
679
680 Ok(Self {
681 pool,
682 pool_metrics,
683 metrics,
684 pruner_cfg,
685 })
686 }
687}
688
689impl PrunerConfig for SqlStorage {
690 fn set_pruning_config(&mut self, cfg: PrunerCfg) {
691 self.pruner_cfg = Some(cfg);
692 }
693
694 fn get_pruning_config(&self) -> Option<PrunerCfg> {
695 self.pruner_cfg.clone()
696 }
697}
698
699impl HasMetrics for SqlStorage {
700 fn metrics(&self) -> &PrometheusMetrics {
701 &self.metrics
702 }
703}
704
705impl SqlStorage {
706 async fn prune_write(&self) -> anyhow::Result<Transaction<Prune>> {
707 Transaction::new(&self.pool, self.pool_metrics.clone()).await
708 }
709
710 async fn get_minimum_height(&self) -> QueryResult<Option<u64>> {
711 let mut tx = self.read().await.map_err(|err| QueryError::Error {
712 message: err.to_string(),
713 })?;
714 let (Some(height),) =
715 query_as::<(Option<i64>,)>("SELECT MIN(height) as height FROM header")
716 .fetch_one(tx.as_mut())
717 .await?
718 else {
719 return Ok(None);
720 };
721 Ok(Some(height as u64))
722 }
723
724 async fn get_height_by_timestamp(&self, timestamp: i64) -> QueryResult<Option<u64>> {
725 let mut tx = self.read().await.map_err(|err| QueryError::Error {
726 message: err.to_string(),
727 })?;
728
729 let Some((height,)) = query_as::<(i64,)>(
736 "SELECT height FROM header
737 WHERE timestamp <= $1
738 ORDER BY timestamp DESC, height DESC
739 LIMIT 1",
740 )
741 .bind(timestamp)
742 .fetch_optional(tx.as_mut())
743 .await?
744 else {
745 return Ok(None);
746 };
747 Ok(Some(height as u64))
748 }
749
750 pub async fn get_vid_share<Types>(&self, block_id: BlockId<Types>) -> QueryResult<VidShare>
752 where
753 Types: NodeType,
754 Header<Types>: QueryableHeader<Types>,
755 {
756 let mut tx = self.read().await.map_err(|err| QueryError::Error {
757 message: err.to_string(),
758 })?;
759 let share = tx.vid_share(block_id).await?;
760 Ok(share)
761 }
762
763 pub async fn get_vid_common<Types: NodeType>(
765 &self,
766 block_id: BlockId<Types>,
767 ) -> QueryResult<VidCommonQueryData<Types>>
768 where
769 <Types as NodeType>::BlockPayload: QueryablePayload<Types>,
770 <Types as NodeType>::BlockHeader: QueryableHeader<Types>,
771 {
772 let mut tx = self.read().await.map_err(|err| QueryError::Error {
773 message: err.to_string(),
774 })?;
775 let common = tx.get_vid_common(block_id).await?;
776 Ok(common)
777 }
778
779 pub async fn get_vid_common_metadata<Types: NodeType>(
781 &self,
782 block_id: BlockId<Types>,
783 ) -> QueryResult<VidCommonMetadata<Types>>
784 where
785 <Types as NodeType>::BlockPayload: QueryablePayload<Types>,
786 <Types as NodeType>::BlockHeader: QueryableHeader<Types>,
787 {
788 let mut tx = self.read().await.map_err(|err| QueryError::Error {
789 message: err.to_string(),
790 })?;
791 let common = tx.get_vid_common_metadata(block_id).await?;
792 Ok(common)
793 }
794}
795
796#[async_trait]
797impl PruneStorage for SqlStorage {
798 type Pruner = Pruner;
799
800 async fn get_disk_usage(&self) -> anyhow::Result<u64> {
801 let mut tx = self.read().await?;
802
803 #[cfg(not(feature = "embedded-db"))]
804 let query = "SELECT pg_database_size(current_database())";
805
806 #[cfg(feature = "embedded-db")]
807 let query = "
808 SELECT( (SELECT page_count FROM pragma_page_count) * (SELECT * FROM pragma_page_size)) \
809 AS total_bytes";
810
811 let row = tx.fetch_one(query).await?;
812 let size: i64 = row.get(0);
813
814 Ok(size as u64)
815 }
816
817 #[cfg(feature = "embedded-db")]
822 async fn vacuum(&self) -> anyhow::Result<()> {
823 let config = self.get_pruning_config().ok_or(QueryError::Error {
824 message: "Pruning config not found".to_string(),
825 })?;
826 let mut conn = self.pool().acquire().await?;
827 query(&format!(
828 "PRAGMA incremental_vacuum({})",
829 config.incremental_vacuum_pages()
830 ))
831 .execute(conn.as_mut())
832 .await?;
833 conn.close().await?;
834 Ok(())
835 }
836
837 async fn prune(&self, pruner: &mut Pruner) -> anyhow::Result<Option<u64>> {
842 let cfg = self.get_pruning_config().ok_or(QueryError::Error {
843 message: "Pruning config not found".to_string(),
844 })?;
845 let batch_size = cfg.batch_size();
846 let max_usage = cfg.max_usage();
847 let state_tables = cfg.state_tables();
848
849 let mut minimum_retention_height = pruner.minimum_retention_height;
854 let mut target_height = pruner.target_height;
855 let pruned_height = match pruner.pruned_height {
856 Some(h) => Some(h),
857 None => {
858 let Some(height) = self.get_minimum_height().await? else {
859 tracing::info!("database is empty, nothing to prune");
860 return Ok(None);
861 };
862
863 if height > 0 { Some(height - 1) } else { None }
864 },
865 };
866
867 if pruner.target_height.is_none() {
869 let th = self
870 .get_height_by_timestamp(
871 Utc::now().timestamp() - (cfg.target_retention().as_secs()) as i64,
872 )
873 .await?;
874 target_height = th;
875 pruner.target_height = target_height;
876 };
877
878 if let Some(th) = target_height
879 && pruned_height < Some(th)
880 {
881 let batch_end = match pruned_height {
882 None => batch_size - 1,
883 Some(h) => h + batch_size,
884 };
885 let to = min(batch_end, th);
886
887 let mut tx = self.write().await?;
890 tx.save_pruned_height(to).await?;
891 tx.commit().await.map_err(|e| QueryError::Error {
892 message: format!("failed to commit save_pruned_height {e}"),
893 })?;
894
895 let mut tx = self.prune_write().await?;
896 tx.delete_batch(to).await?;
897 tx.commit().await.map_err(|e| QueryError::Error {
898 message: format!("failed to commit delete_batch {e}"),
899 })?;
900
901 let mut tx = self.prune_write().await?;
903 tx.delete_state_batch(state_tables, to).await?;
904 tx.commit().await.map_err(|e| QueryError::Error {
905 message: format!("failed to commit {e}"),
906 })?;
907
908 pruner.pruned_height = Some(to);
909 return Ok(Some(to));
910 }
911
912 if let Some(threshold) = cfg.pruning_threshold() {
915 let usage = self.get_disk_usage().await?;
916
917 if usage > threshold {
920 tracing::warn!(
921 "Disk usage {usage} exceeds pruning threshold {:?}",
922 cfg.pruning_threshold()
923 );
924
925 if minimum_retention_height.is_none() {
926 minimum_retention_height = self
927 .get_height_by_timestamp(
928 Utc::now().timestamp() - (cfg.minimum_retention().as_secs()) as i64,
929 )
930 .await?;
931
932 pruner.minimum_retention_height = minimum_retention_height;
933 }
934
935 if let Some(min_retention_height) = minimum_retention_height
936 && (usage as f64 / threshold as f64) > (f64::from(max_usage) / 10000.0)
937 && pruned_height < Some(min_retention_height)
938 {
939 let batch_end = match pruned_height {
940 None => batch_size - 1,
941 Some(h) => h + batch_size,
942 };
943 let to = min(batch_end, min_retention_height);
944 let mut tx = self.write().await?;
947 tx.save_pruned_height(to).await?;
948 tx.commit().await.map_err(|e| QueryError::Error {
949 message: format!("failed to commit save_pruned_height {e}"),
950 })?;
951
952 let mut tx = self.prune_write().await?;
953 tx.delete_batch(to).await?;
954 tx.commit().await.map_err(|e| QueryError::Error {
955 message: format!("failed to commit delete_batch {e}"),
956 })?;
957
958 let mut tx = self.prune_write().await?;
960 tx.delete_state_batch(state_tables, to).await?;
961 tx.commit().await.map_err(|e| QueryError::Error {
962 message: format!("failed to commit {e}"),
963 })?;
964
965 self.vacuum().await?;
966 pruner.pruned_height = Some(to);
967 return Ok(Some(to));
968 }
969 }
970 }
971
972 Ok(None)
973 }
974}
975
976impl VersionedDataSource for SqlStorage {
977 type Transaction<'a>
978 = Transaction<Write>
979 where
980 Self: 'a;
981 type ReadOnly<'a>
982 = Transaction<Read>
983 where
984 Self: 'a;
985
986 async fn write(&self) -> anyhow::Result<Transaction<Write>> {
987 Transaction::new(&self.pool, self.pool_metrics.clone()).await
988 }
989
990 async fn read(&self) -> anyhow::Result<Transaction<Read>> {
991 Transaction::new(&self.pool, self.pool_metrics.clone()).await
992 }
993}
994
#[cfg(all(any(test, feature = "testing"), not(target_os = "windows")))]
pub mod testing {
    //! Helpers for tests that need a real database.
    #![allow(unused_imports)]
    use std::{
        env,
        process::{Command, Stdio},
        str::{self, FromStr},
        time::Duration,
    };

    use refinery::Migration;
    use test_utils::reserve_tcp_port;
    use tokio::{net::TcpStream, time::timeout};

    use super::Config;
    use crate::testing::sleep;

    /// A throwaway database for tests.
    ///
    /// - Postgres builds: runs a `postgres` docker container on a reserved port. On drop the
    ///   container is stopped, and docker also removes it unless it is persistent.
    /// - `embedded-db` builds: uses a temporary SQLite file. On drop the file is deleted
    ///   unless it is persistent.
    #[derive(Debug)]
    pub struct TmpDb {
        #[cfg(not(feature = "embedded-db"))]
        host: String,
        #[cfg(not(feature = "embedded-db"))]
        port: u16,
        #[cfg(not(feature = "embedded-db"))]
        container_id: String,
        #[cfg(feature = "embedded-db")]
        db_path: std::path::PathBuf,
        #[allow(dead_code)]
        persistent: bool,
    }

    impl TmpDb {
        /// Create a `sqlite-*.db` temp file that `tempfile` will not delete automatically.
        #[cfg(feature = "embedded-db")]
        fn init_sqlite_db(persistent: bool) -> Self {
            let temp_file = tempfile::Builder::new()
                .prefix("sqlite-")
                .suffix(".db")
                .tempfile()
                .unwrap();

            // Detach the file from tempfile's cleanup; our `Drop` impl handles removal.
            let (_, db_path) = temp_file.keep().unwrap();

            Self {
                db_path,
                persistent,
            }
        }

        /// Start the backend-specific database.
        async fn init_with(persistent: bool) -> Self {
            #[cfg(feature = "embedded-db")]
            return Self::init_sqlite_db(persistent);

            #[cfg(not(feature = "embedded-db"))]
            Self::init_postgres(persistent).await
        }

        /// A fresh database that is cleaned up on drop.
        pub async fn init() -> Self {
            Self::init_with(false).await
        }

        /// A fresh database whose data outlives this handle.
        pub async fn persistent() -> Self {
            Self::init_with(true).await
        }

        /// Launch a `postgres` docker container and wait until it accepts connections.
        ///
        /// The container is reached at `$DOCKER_HOSTNAME` (default `localhost`).
        #[cfg(not(feature = "embedded-db"))]
        async fn init_postgres(persistent: bool) -> Self {
            let docker_hostname = env::var("DOCKER_HOSTNAME");
            let port = reserve_tcp_port().unwrap();
            let host = docker_hostname.unwrap_or("localhost".to_string());

            let mut docker_run = Command::new("docker");
            docker_run
                .arg("run")
                .arg("-d")
                .args(["-p", &format!("{port}:5432")])
                .args(["-e", "POSTGRES_PASSWORD=password"]);
            if !persistent {
                // Let docker delete the container as soon as it stops.
                docker_run.arg("--rm");
            }

            let output = docker_run.arg("postgres").output().unwrap();
            let stdout = str::from_utf8(&output.stdout).unwrap();
            let stderr = str::from_utf8(&output.stderr).unwrap();
            assert!(
                output.status.success(),
                "failed to start postgres docker: {stderr}"
            );

            let container_id = stdout.trim().to_owned();
            tracing::info!("launched postgres docker {container_id}");

            let db = Self {
                host,
                port,
                container_id,
                persistent,
            };
            db.wait_for_ready().await;
            db
        }

        /// Host name the database is reachable at.
        #[cfg(not(feature = "embedded-db"))]
        pub fn host(&self) -> String {
            self.host.clone()
        }

        /// Host port mapped to the container's Postgres port.
        #[cfg(not(feature = "embedded-db"))]
        pub fn port(&self) -> u16 {
            self.port
        }

        /// Path of the SQLite database file.
        #[cfg(feature = "embedded-db")]
        pub fn path(&self) -> std::path::PathBuf {
            self.db_path.clone()
        }

        /// A [`Config`] for this database.
        ///
        /// It adds one extra migration, `V101`, which creates the `test_tree` Merkle-tree
        /// tables.
        pub fn config(&self) -> Config {
            #[cfg(feature = "embedded-db")]
            let base = Config::default().db_path(self.db_path.clone());

            #[cfg(not(feature = "embedded-db"))]
            let base = Config::default()
                .user("postgres")
                .password("password")
                .host(self.host())
                .port(self.port());

            let test_tree_migration = Migration::unapplied(
                "V101__create_test_merkle_tree_table.sql",
                &TestMerkleTreeMigration::create("test_tree"),
            )
            .unwrap();
            base.migrations(vec![test_tree_migration])
        }

        /// Stop the container. This also runs automatically on drop.
        #[cfg(not(feature = "embedded-db"))]
        pub fn stop_postgres(&mut self) {
            tracing::info!(container = self.container_id, "stopping postgres");
            self.docker_container_command("stop", "error killing postgres docker");
        }

        /// Restart a stopped container and wait until it accepts connections again.
        #[cfg(not(feature = "embedded-db"))]
        pub async fn start_postgres(&mut self) {
            tracing::info!(container = self.container_id, "resuming postgres");
            self.docker_container_command("start", "error starting postgres docker");

            self.wait_for_ready().await;
        }

        /// Run `docker <action> <container>`.
        ///
        /// Panics with `failure`, the container id, and docker's stderr if the command fails.
        #[cfg(not(feature = "embedded-db"))]
        fn docker_container_command(&self, action: &str, failure: &str) {
            let output = Command::new("docker")
                .args([action, self.container_id.as_str()])
                .output()
                .unwrap();
            assert!(
                output.status.success(),
                "{failure} {}: {}",
                self.container_id,
                str::from_utf8(&output.stderr).unwrap()
            );
        }

        /// True if `pg_isready` ran inside the container and reported success.
        #[cfg(not(feature = "embedded-db"))]
        fn pg_isready(&self) -> bool {
            Command::new("docker")
                .args([
                    "exec",
                    self.container_id.as_str(),
                    "pg_isready",
                    "-h",
                    "localhost",
                    "-U",
                    "postgres",
                ])
                .env("PGPASSWORD", "password")
                .stdin(Stdio::null())
                .stdout(Stdio::null())
                .stderr(Stdio::null())
                .status()
                .is_ok_and(|status| status.success())
        }

        /// Wait until Postgres is ready inside the container and reachable from the host.
        ///
        /// Panics if this takes longer than `$SQL_TMP_DB_CONNECT_TIMEOUT` seconds
        /// (default 60).
        #[cfg(not(feature = "embedded-db"))]
        async fn wait_for_ready(&self) {
            let timeout_duration = Duration::from_secs(
                env::var("SQL_TMP_DB_CONNECT_TIMEOUT")
                    .unwrap_or("60".to_string())
                    .parse()
                    .expect("SQL_TMP_DB_CONNECT_TIMEOUT must be an integer number of seconds"),
            );

            let readiness = async {
                // Phase 1: the server inside the container accepts connections.
                while !self.pg_isready() {
                    tracing::warn!("database is not ready");
                    sleep(Duration::from_secs(1)).await;
                }

                // Phase 2: the mapped port is reachable from the host.
                while let Err(err) =
                    TcpStream::connect(format!("{}:{}", self.host, self.port)).await
                {
                    tracing::warn!("database is ready, but port is not available to host: {err:#}");
                    sleep(Duration::from_millis(100)).await;
                }
            };

            if let Err(err) = timeout(timeout_duration, readiness).await {
                panic!(
                    "failed to connect to TmpDb within configured timeout {timeout_duration:?}: \
                     {err:#}\n{}",
                    "Consider increasing the timeout by setting SQL_TMP_DB_CONNECT_TIMEOUT"
                );
            }
        }
    }

    #[cfg(not(feature = "embedded-db"))]
    impl Drop for TmpDb {
        fn drop(&mut self) {
            self.stop_postgres();
        }
    }

    #[cfg(feature = "embedded-db")]
    impl Drop for TmpDb {
        fn drop(&mut self) {
            if self.persistent {
                return;
            }
            std::fs::remove_file(&self.db_path).unwrap();
        }
    }

    /// Builds the SQL for a test-only Merkle-tree migration.
    pub struct TestMerkleTreeMigration;

    impl TestMerkleTreeMigration {
        /// SQL that does three things:
        /// - creates the `hash` table if missing;
        /// - adds a generated `test_merkle_tree_root` column to `header`, taken from its
        ///   `data` JSON;
        /// - creates the node table `name` with an index on `created`.
        fn create(name: &str) -> String {
            let (bit_vec, binary, hash_pk, root_stored_column) =
                if cfg!(not(feature = "embedded-db")) {
                    (
                        "BIT(8)",
                        "BYTEA",
                        "SERIAL PRIMARY KEY",
                        "(data->>'test_merkle_tree_root')",
                    )
                } else {
                    (
                        "TEXT",
                        "BLOB",
                        "INTEGER PRIMARY KEY AUTOINCREMENT",
                        " (json_extract(data, '$.test_merkle_tree_root'))",
                    )
                };

            format!(
                "CREATE TABLE IF NOT EXISTS hash
                (
                    id {hash_pk},
                    value {binary} NOT NULL UNIQUE
                );

                ALTER TABLE header
                ADD column test_merkle_tree_root text
                GENERATED ALWAYS as {root_stored_column} STORED;

                CREATE TABLE {name}
                (
                    path JSONB NOT NULL,
                    created BIGINT NOT NULL,
                    hash_id INT NOT NULL,
                    children JSONB,
                    children_bitvec {bit_vec},
                    idx JSONB,
                    entry JSONB,
                    PRIMARY KEY (path, created)
                );
                CREATE INDEX {name}_created ON {name} (created);"
            )
        }
    }
}
1298
1299#[cfg(all(test, not(target_os = "windows")))]
1301mod test {
1302 use std::time::Duration;
1303
1304 use hotshot_example_types::{
1305 node_types::TEST_VERSIONS,
1306 state_types::{TestInstanceState, TestValidatedState},
1307 };
1308 use jf_merkle_tree_compat::{
1309 MerkleTreeScheme, ToTraversalPath, UniversalMerkleTreeScheme, prelude::UniversalMerkleTree,
1310 };
1311 use tokio::time::sleep;
1312
1313 use super::{testing::TmpDb, *};
1314 use crate::{
1315 availability::{BlockQueryData, LeafQueryData},
1316 data_source::storage::{
1317 MerklizedStateStorage, UpdateAvailabilityStorage, pruning::PrunedHeightStorage,
1318 },
1319 merklized_state::{MerklizedState, Snapshot, UpdateStateData},
1320 testing::mocks::{MockMerkleTree, MockTypes},
1321 };
1322
1323 #[test_log::test(tokio::test(flavor = "multi_thread"))]
1324 async fn test_migrations() {
1325 let db = TmpDb::init().await;
1326 let cfg = db.config();
1327
1328 let connect = |migrations: bool, custom_migrations| {
1329 let cfg = cfg.clone();
1330 async move {
1331 let mut cfg = cfg.migrations(custom_migrations);
1332 if !migrations {
1333 cfg = cfg.no_migrations();
1334 }
1335 let client = SqlStorage::connect(cfg, StorageConnectionType::Query).await?;
1336 Ok::<_, Error>(client)
1337 }
1338 };
1339
1340 let err = connect(false, vec![]).await.unwrap_err();
1343 tracing::info!("connecting without running migrations failed as expected: {err}");
1344
1345 connect(true, vec![]).await.unwrap();
1347 connect(false, vec![]).await.unwrap();
1349
1350 let migrations = vec![
1354 Migration::unapplied(
1355 "V9999__create_test_table.sql",
1356 "ALTER TABLE test ADD COLUMN data INTEGER;",
1357 )
1358 .unwrap(),
1359 Migration::unapplied(
1360 "V9998__create_test_table.sql",
1361 "CREATE TABLE test (x bigint);",
1362 )
1363 .unwrap(),
1364 ];
1365 connect(true, migrations.clone()).await.unwrap();
1366
1367 let err = connect(false, vec![]).await.unwrap_err();
1370 tracing::info!("connecting without running migrations failed as expected: {err}");
1371
1372 connect(true, migrations).await.unwrap();
1374 }
1375
#[test]
#[cfg(not(feature = "embedded-db"))]
fn test_config_from_str() {
    let cfg: Config = "postgresql://user:password@host:8080".parse().unwrap();
    let opts = &cfg.db_opt;
    assert_eq!(opts.get_username(), "user");
    assert_eq!(opts.get_host(), "host");
    assert_eq!(opts.get_port(), 8080);
}
1384
#[test]
#[cfg(feature = "embedded-db")]
fn test_config_from_str() {
    let cfg: Config = "sqlite://data.db".parse().unwrap();
    let filename = cfg.db_opt.get_filename();
    assert_eq!(filename.to_string_lossy(), "data.db");
}
1391
1392 async fn vacuum(storage: &SqlStorage) {
1393 #[cfg(feature = "embedded-db")]
1394 let query = "PRAGMA incremental_vacuum(16000)";
1395 #[cfg(not(feature = "embedded-db"))]
1396 let query = "VACUUM";
1397 storage
1398 .pool
1399 .acquire()
1400 .await
1401 .unwrap()
1402 .execute(query)
1403 .await
1404 .unwrap();
1405 }
1406
1407 #[test_log::test(tokio::test(flavor = "multi_thread"))]
1408 async fn test_target_period_pruning() {
1409 let db = TmpDb::init().await;
1410 let cfg = db.config();
1411
1412 let mut storage = SqlStorage::connect(cfg, StorageConnectionType::Query)
1413 .await
1414 .unwrap();
1415 let mut leaf = LeafQueryData::<MockTypes>::genesis(
1416 &TestValidatedState::default(),
1417 &TestInstanceState::default(),
1418 TEST_VERSIONS.test,
1419 )
1420 .await;
1421 for i in 0..20 {
1423 leaf.leaf.block_header_mut().block_number = i;
1424 leaf.leaf.block_header_mut().timestamp = Utc::now().timestamp() as u64;
1425 let mut tx = storage.write().await.unwrap();
1426 tx.insert_leaf(&leaf).await.unwrap();
1427 tx.commit().await.unwrap();
1428 }
1429
1430 let height_before_pruning = storage.get_minimum_height().await.unwrap().unwrap();
1431
1432 storage.set_pruning_config(PrunerCfg::new());
1434 let pruned_height = storage.prune(&mut Default::default()).await.unwrap();
1436
1437 vacuum(&storage).await;
1441 assert!(pruned_height.is_none());
1443
1444 let height_after_pruning = storage.get_minimum_height().await.unwrap().unwrap();
1445
1446 assert_eq!(
1447 height_after_pruning, height_before_pruning,
1448 "some data has been pruned"
1449 );
1450
1451 storage.set_pruning_config(PrunerCfg::new().with_target_retention(Duration::from_secs(1)));
1453 sleep(Duration::from_secs(2)).await;
1454 let usage_before_pruning = storage.get_disk_usage().await.unwrap();
1455 let pruned_height = storage.prune(&mut Default::default()).await.unwrap();
1458 vacuum(&storage).await;
1462
1463 assert!(pruned_height.is_some());
1465 let usage_after_pruning = storage.get_disk_usage().await.unwrap();
1466 let header_rows = storage
1469 .read()
1470 .await
1471 .unwrap()
1472 .fetch_one("select count(*) as count from header")
1473 .await
1474 .unwrap()
1475 .get::<i64, _>("count");
1476 assert_eq!(header_rows, 0);
1478
1479 let leaf_rows = storage
1483 .read()
1484 .await
1485 .unwrap()
1486 .fetch_one("select count(*) as count from leaf")
1487 .await
1488 .unwrap()
1489 .get::<i64, _>("count");
1490 assert_eq!(leaf_rows, 0);
1492
1493 assert!(
1494 usage_before_pruning > usage_after_pruning,
1495 " disk usage should decrease after pruning"
1496 )
1497 }
1498
1499 #[test_log::test(tokio::test(flavor = "multi_thread"))]
1500 async fn test_merklized_state_pruning() {
1501 let db = TmpDb::init().await;
1502 let storage = SqlStorage::connect(db.config(), StorageConnectionType::Query)
1503 .await
1504 .unwrap();
1505
1506 let num_blocks = 10_000u64;
1507 let mut test_tree: UniversalMerkleTree<_, _, _, 8, _> =
1508 MockMerkleTree::new(MockMerkleTree::tree_height());
1509
1510 let mut tx = storage.write().await.unwrap();
1512 for height in 0..num_blocks {
1513 test_tree.update(height as usize, height as usize).unwrap();
1514
1515 let test_data = serde_json::json!({
1516 MockMerkleTree::header_state_commitment_field():
1517 serde_json::to_value(test_tree.commitment()).unwrap()
1518 });
1519 tx.upsert(
1520 "header",
1521 [
1522 "height",
1523 "hash",
1524 "payload_hash",
1525 "timestamp",
1526 "data",
1527 "ns_table",
1528 ],
1529 ["height"],
1530 [(
1531 height as i64,
1532 format!("hash{height}"),
1533 "ph".to_string(),
1534 0,
1535 test_data,
1536 "ns".to_string(),
1537 )],
1538 )
1539 .await
1540 .unwrap();
1541
1542 let (_, proof) = test_tree.lookup(height as usize).expect_ok().unwrap();
1543 let traversal_path = <usize as ToTraversalPath<8>>::to_traversal_path(
1544 &(height as usize),
1545 test_tree.height(),
1546 );
1547 UpdateStateData::<_, MockMerkleTree, 8>::insert_merkle_nodes(
1548 &mut tx,
1549 proof.clone(),
1550 traversal_path,
1551 height,
1552 )
1553 .await
1554 .unwrap();
1555 }
1556 UpdateStateData::<_, MockMerkleTree, 8>::set_last_state_height(
1557 &mut tx,
1558 num_blocks as usize,
1559 )
1560 .await
1561 .unwrap();
1562 tx.commit().await.unwrap();
1563
1564 let prune_height = 5678u64;
1566 let mut tx = storage.prune_write().await.unwrap();
1567 tx.delete_state_batch(vec!["test_tree".to_string()], prune_height)
1568 .await
1569 .unwrap();
1570 tx.commit().await.unwrap();
1571
1572 let mut tx = storage.read().await.unwrap();
1574 let (duplicates,) = query_as::<(i64,)>(
1575 "SELECT count(*) FROM (SELECT count(*) FROM test_tree WHERE created <= $1 GROUP BY \
1576 path HAVING count(*) > 1) AS s",
1577 )
1578 .bind(prune_height as i64)
1579 .fetch_one(tx.as_mut())
1580 .await
1581 .unwrap();
1582 assert_eq!(
1583 duplicates, 0,
1584 "found {duplicates} paths with duplicate versions at or below prune height"
1585 );
1586
1587 let commitment = test_tree.commitment();
1589 let mut tx = storage.read().await.unwrap();
1590 for key in 0..num_blocks as usize {
1591 let proof = MerklizedStateStorage::<MockTypes, MockMerkleTree, 8>::get_path(
1592 &mut tx,
1593 Snapshot::Index(num_blocks - 1),
1594 key,
1595 )
1596 .await
1597 .unwrap_or_else(|e| panic!("get_path failed for key {key} after pruning: {e:#}"));
1598 assert_eq!(
1599 proof.elem(),
1600 Some(&key),
1601 "proof for key {key} has wrong element: {:?}",
1602 proof.elem()
1603 );
1604 MockMerkleTree::verify(commitment, key, &proof)
1605 .unwrap()
1606 .unwrap();
1607 }
1608 }
1609
1610 #[test_log::test(tokio::test(flavor = "multi_thread"))]
1611 async fn test_minimum_retention_pruning() {
1612 let db = TmpDb::init().await;
1613
1614 let mut storage = SqlStorage::connect(db.config(), StorageConnectionType::Query)
1615 .await
1616 .unwrap();
1617 let mut leaf = LeafQueryData::<MockTypes>::genesis(
1618 &TestValidatedState::default(),
1619 &TestInstanceState::default(),
1620 TEST_VERSIONS.test,
1621 )
1622 .await;
1623 for i in 0..20 {
1625 leaf.leaf.block_header_mut().block_number = i;
1626 leaf.leaf.block_header_mut().timestamp = Utc::now().timestamp() as u64;
1627 let mut tx = storage.write().await.unwrap();
1628 tx.insert_leaf(&leaf).await.unwrap();
1629 tx.commit().await.unwrap();
1630 }
1631
1632 let height_before_pruning = storage.get_minimum_height().await.unwrap().unwrap();
1633 let cfg = PrunerCfg::new();
1634 storage.set_pruning_config(cfg.clone().with_pruning_threshold(1));
1639 println!("{:?}", storage.get_pruning_config().unwrap());
1640 let pruned_height = storage.prune(&mut Default::default()).await.unwrap();
1643 vacuum(&storage).await;
1647
1648 assert!(pruned_height.is_none());
1650
1651 let height_after_pruning = storage.get_minimum_height().await.unwrap().unwrap();
1652
1653 assert_eq!(
1654 height_after_pruning, height_before_pruning,
1655 "some data has been pruned"
1656 );
1657
1658 storage.set_pruning_config(
1660 cfg.with_minimum_retention(Duration::from_secs(1))
1661 .with_pruning_threshold(1),
1662 );
1663 sleep(Duration::from_secs(2)).await;
1665 let pruned_height = storage.prune(&mut Default::default()).await.unwrap();
1667 vacuum(&storage).await;
1671
1672 assert!(pruned_height.is_some());
1674 let header_rows = storage
1677 .read()
1678 .await
1679 .unwrap()
1680 .fetch_one("select count(*) as count from header")
1681 .await
1682 .unwrap()
1683 .get::<i64, _>("count");
1684 assert_eq!(header_rows, 0);
1686 }
1687
1688 #[tokio::test]
1689 #[test_log::test]
1690 async fn test_payload_pruning() {
1691 let db = TmpDb::init().await;
1692 let mut storage = SqlStorage::connect(db.config(), StorageConnectionType::Query)
1693 .await
1694 .unwrap();
1695 storage.set_pruning_config(Default::default());
1696
1697 let mut leaf = LeafQueryData::<MockTypes>::genesis(
1699 &TestValidatedState::default(),
1700 &TestInstanceState::default(),
1701 TEST_VERSIONS.test,
1702 )
1703 .await;
1704 let block = BlockQueryData::<MockTypes>::genesis(
1705 &Default::default(),
1706 &Default::default(),
1707 TEST_VERSIONS.test.base,
1708 )
1709 .await;
1710 let vid = VidCommonQueryData::<MockTypes>::genesis(
1711 &Default::default(),
1712 &Default::default(),
1713 TEST_VERSIONS.test.base,
1714 )
1715 .await;
1716 {
1717 let mut tx = storage.write().await.unwrap();
1718 tx.insert_leaf(&leaf).await.unwrap();
1719 tx.insert_block(&block).await.unwrap();
1720 tx.insert_vid(&vid, None).await.unwrap();
1721 tx.commit().await.unwrap();
1722 }
1723
1724 leaf.leaf.block_header_mut().block_number += 1;
1726 {
1727 let mut tx = storage.write().await.unwrap();
1728 tx.insert_leaf(&leaf).await.unwrap();
1729 tx.commit().await.unwrap();
1730 }
1731 {
1732 let mut tx = storage.read().await.unwrap();
1733 let (num_payloads,): (i64,) = query_as("SELECT count(*) FROM payload")
1734 .fetch_one(tx.as_mut())
1735 .await
1736 .unwrap();
1737 assert_eq!(num_payloads, 1);
1738 let (num_vid,): (i64,) = query_as("SELECT count(*) FROM vid_common")
1739 .fetch_one(tx.as_mut())
1740 .await
1741 .unwrap();
1742 assert_eq!(num_vid, 1);
1743 }
1744
1745 let pruned_height = storage
1747 .prune(&mut Pruner {
1748 pruned_height: None,
1749 target_height: Some(0),
1750 minimum_retention_height: None,
1751 })
1752 .await
1753 .unwrap();
1754 tracing::info!(?pruned_height, "first pruning run complete");
1755 {
1756 let mut tx = storage.read().await.unwrap();
1757
1758 let err = tx
1760 .get_block(BlockId::<MockTypes>::Number(0))
1761 .await
1762 .unwrap_err();
1763 assert!(matches!(err, QueryError::NotFound), "{err:#}");
1764 let err = tx
1765 .get_vid_common(BlockId::<MockTypes>::Number(0))
1766 .await
1767 .unwrap_err();
1768 assert!(matches!(err, QueryError::NotFound), "{err:#}");
1769
1770 assert_eq!(
1772 tx.get_block(BlockId::<MockTypes>::Number(1)).await.unwrap(),
1773 BlockQueryData::new(leaf.header().clone(), block.payload)
1774 );
1775 assert_eq!(
1776 tx.get_vid_common(BlockId::<MockTypes>::Number(1))
1777 .await
1778 .unwrap(),
1779 VidCommonQueryData::new(leaf.header().clone(), vid.common)
1780 );
1781
1782 let (num_payloads,): (i64,) = query_as("SELECT count(*) FROM payload")
1783 .fetch_one(tx.as_mut())
1784 .await
1785 .unwrap();
1786 assert_eq!(num_payloads, 1);
1787
1788 let (num_vid,): (i64,) = query_as("SELECT count(*) FROM vid_common")
1789 .fetch_one(tx.as_mut())
1790 .await
1791 .unwrap();
1792 assert_eq!(num_vid, 1);
1793 }
1794
1795 let pruned_height = storage
1797 .prune(&mut Pruner {
1798 pruned_height,
1799 target_height: Some(1),
1800 minimum_retention_height: None,
1801 })
1802 .await
1803 .unwrap();
1804 tracing::info!(?pruned_height, "second pruning run complete");
1805
1806 let mut tx = storage.read().await.unwrap();
1807 for i in 0..2 {
1808 let err = tx
1809 .get_block(BlockId::<MockTypes>::Number(i))
1810 .await
1811 .unwrap_err();
1812 assert!(matches!(err, QueryError::NotFound), "{err:#}");
1813
1814 let err = tx
1815 .get_vid_common(BlockId::<MockTypes>::Number(i))
1816 .await
1817 .unwrap_err();
1818 assert!(matches!(err, QueryError::NotFound), "{err:#}");
1819 }
1820 let (num_payloads,): (i64,) = query_as("SELECT count(*) FROM payload")
1821 .fetch_one(tx.as_mut())
1822 .await
1823 .unwrap();
1824 assert_eq!(num_payloads, 0);
1825
1826 let (num_vid,): (i64,) = query_as("SELECT count(*) FROM vid_common")
1827 .fetch_one(tx.as_mut())
1828 .await
1829 .unwrap();
1830 assert_eq!(num_vid, 0);
1831 }
1832
1833 #[test_log::test(tokio::test(flavor = "multi_thread"))]
1834 async fn test_pruned_height_storage() {
1835 let db = TmpDb::init().await;
1836 let cfg = db.config();
1837
1838 let storage = SqlStorage::connect(cfg, StorageConnectionType::Query)
1839 .await
1840 .unwrap();
1841 assert!(
1842 storage
1843 .read()
1844 .await
1845 .unwrap()
1846 .load_pruned_height()
1847 .await
1848 .unwrap()
1849 .is_none()
1850 );
1851 for height in [10, 20, 30] {
1852 let mut tx = storage.write().await.unwrap();
1853 tx.save_pruned_height(height).await.unwrap();
1854 tx.commit().await.unwrap();
1855 assert_eq!(
1856 storage
1857 .read()
1858 .await
1859 .unwrap()
1860 .load_pruned_height()
1861 .await
1862 .unwrap(),
1863 Some(height)
1864 );
1865 }
1866 }
1867
1868 #[test_log::test(tokio::test(flavor = "multi_thread"))]
1869 async fn test_transaction_upsert_retries() {
1870 let db = TmpDb::init().await;
1871 let config = db.config();
1872
1873 let storage = SqlStorage::connect(config, StorageConnectionType::Query)
1874 .await
1875 .unwrap();
1876
1877 let mut tx = storage.write().await.unwrap();
1878
1879 tx.upsert("does_not_exist", ["test"], ["test"], [(1_i64,)])
1890 .await
1891 .unwrap_err();
1892 }
1893}