hotshot_query_service/data_source/storage/
sql.rs

1// Copyright (c) 2022 Espresso Systems (espressosys.com)
2// This file is part of the HotShot Query Service library.
3//
4// This program is free software: you can redistribute it and/or modify it under the terms of the GNU
5// General Public License as published by the Free Software Foundation, either version 3 of the
6// License, or (at your option) any later version.
7// This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without
8// even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
9// General Public License for more details.
10// You should have received a copy of the GNU General Public License along with this program. If not,
11// see <https://www.gnu.org/licenses/>.
12
13#![cfg(feature = "sql-data-source")]
14use std::{cmp::min, fmt::Debug, str::FromStr, time::Duration};
15
16use async_trait::async_trait;
17use chrono::Utc;
18#[cfg(not(feature = "embedded-db"))]
19use futures::future::FutureExt;
20use hotshot_types::{
21    data::VidShare,
22    traits::{metrics::Metrics, node_implementation::NodeType},
23};
24use itertools::Itertools;
25use log::LevelFilter;
26#[cfg(not(feature = "embedded-db"))]
27use sqlx::postgres::{PgConnectOptions, PgSslMode};
28#[cfg(feature = "embedded-db")]
29use sqlx::sqlite::SqliteConnectOptions;
30use sqlx::{
31    ConnectOptions, Row,
32    pool::{Pool, PoolOptions},
33};
34
35use crate::{
36    Header, QueryError, QueryResult,
37    availability::{QueryableHeader, QueryablePayload, VidCommonMetadata, VidCommonQueryData},
38    data_source::{
39        VersionedDataSource,
40        storage::pruning::{PruneStorage, PrunerCfg, PrunerConfig},
41        update::Transaction as _,
42    },
43    metrics::PrometheusMetrics,
44    node::BlockId,
45    status::HasMetrics,
46};
47pub extern crate sqlx;
48pub use sqlx::{Database, Sqlite};
49
50mod db;
51mod migrate;
52mod queries;
53mod transaction;
54
55pub use anyhow::Error;
56pub use db::*;
57pub use include_dir::include_dir;
58pub use queries::QueryBuilder;
59pub use refinery::Migration;
60pub use transaction::*;
61
62use self::{migrate::Migrator, transaction::PoolMetrics};
63use super::{AvailabilityStorage, NodeStorage};
64// This needs to be reexported so that we can reference it by absolute path relative to this crate
65// in the expansion of `include_migrations`, even when `include_migrations` is invoked from another
66// crate which doesn't have `include_dir` as a dependency.
67pub use crate::include_migrations;
68
/// Embed migrations from the given directory into the current binary for PostgreSQL or SQLite.
///
/// The macro invocation `include_migrations!(path)` evaluates to an expression of type `impl
/// Iterator<Item = Migration>`. Each migration must be a text file which is an immediate child of
/// `path`, and there must be no non-migration files in `path`. The migration files must have names
/// of the form `V${version}__${name}.sql`, where `version` is a positive integer indicating how the
/// migration is to be ordered relative to other migrations, and `name` is a descriptive name for
/// the migration.
///
/// `path` should be an absolute path. It is possible to give a path relative to the root of the
/// invoking crate by using environment variable expansions and the `CARGO_MANIFEST_DIR` environment
/// variable.
///
/// As an example, this is the invocation used to load the default migrations from the
/// `hotshot-query-service` crate. The migrations live under the crate-level `migrations` directory:
/// - PostgreSQL migrations are in `/migrations/postgres`.
/// - SQLite migrations are in `/migrations/sqlite`.
///
/// ```
/// # use hotshot_query_service::data_source::sql::{include_migrations, Migration};
/// // For PostgreSQL
/// #[cfg(not(feature = "embedded-db"))]
/// let mut migrations: Vec<Migration> =
///     include_migrations!("$CARGO_MANIFEST_DIR/migrations/postgres").collect();
/// // For SQLite
/// #[cfg(feature = "embedded-db")]
/// let mut migrations: Vec<Migration> =
///     include_migrations!("$CARGO_MANIFEST_DIR/migrations/sqlite").collect();
///
/// migrations.sort();
/// assert_eq!(migrations[0].version(), 10);
/// assert_eq!(migrations[0].name(), "init_schema");
/// ```
///
/// Note that a similar macro is available from Refinery:
/// [embed_migrations](https://docs.rs/refinery/0.8.11/refinery/macro.embed_migrations.html). This
/// macro differs in that it evaluates to an iterator of [migrations](Migration), making it an
/// expression macro, while `embed_migrations` is a statement macro that defines a module which
/// provides access to the embedded migrations only indirectly via a
/// [`Runner`](https://docs.rs/refinery/0.8.11/refinery/struct.Runner.html). The direct access to
/// migrations provided by [`include_migrations`] makes this macro easier to use with
/// [`Config::migrations`], for combining custom migrations with [`default_migrations`].
#[macro_export]
macro_rules! include_migrations {
    ($dir:tt) => {
        // Embed the directory's contents at compile time, then map each file to a Refinery
        // `Migration`. Both the file name and the contents must be valid UTF-8; violations are
        // compile-environment bugs, so we panic rather than return an error.
        $crate::data_source::storage::sql::include_dir!($dir)
            .files()
            .map(|file| {
                let path = file.path();
                let name = path
                    .file_name()
                    .and_then(std::ffi::OsStr::to_str)
                    .unwrap_or_else(|| {
                        panic!(
                            "migration file {} must have a non-empty UTF-8 name",
                            path.display()
                        )
                    });
                let sql = file
                    .contents_utf8()
                    .unwrap_or_else(|| panic!("migration file {name} must use UTF-8 encoding"));
                // `unapplied` parses the `V${version}__${name}.sql` convention; an invalid file
                // name is a packaging bug, hence `expect`.
                $crate::data_source::storage::sql::Migration::unapplied(name, sql)
                    .expect("invalid migration")
            })
    };
}
135
136/// The migrations required to build the default schema for this version of [`SqlStorage`].
137pub fn default_migrations() -> Vec<Migration> {
138    #[cfg(not(feature = "embedded-db"))]
139    let mut migrations =
140        include_migrations!("$CARGO_MANIFEST_DIR/migrations/postgres").collect::<Vec<_>>();
141
142    #[cfg(feature = "embedded-db")]
143    let mut migrations =
144        include_migrations!("$CARGO_MANIFEST_DIR/migrations/sqlite").collect::<Vec<_>>();
145
146    // Check version uniqueness and sort by version.
147    validate_migrations(&mut migrations).expect("default migrations are invalid");
148
149    // Check that all migration versions are multiples of 100, so that custom migrations can be
150    // inserted in between.
151    for m in &migrations {
152        if m.version() <= 30 {
153            // An older version of this software used intervals of 10 instead of 100. This was
154            // changed to allow more custom migrations between each default migration, but we must
155            // still accept older migrations that followed the older rule.
156            assert!(
157                m.version() > 0 && m.version() % 10 == 0,
158                "legacy default migration version {} is not a positive multiple of 10",
159                m.version()
160            );
161        } else {
162            assert!(
163                m.version() % 100 == 0,
164                "default migration version {} is not a multiple of 100",
165                m.version()
166            );
167        }
168    }
169
170    migrations
171}
172
173/// Validate and preprocess a sequence of migrations.
174///
175/// * Ensure all migrations have distinct versions
176/// * Ensure migrations are sorted by increasing version
177fn validate_migrations(migrations: &mut [Migration]) -> Result<(), Error> {
178    migrations.sort_by_key(|m| m.version());
179
180    // Check version uniqueness.
181    for (prev, next) in migrations.iter().zip(migrations.iter().skip(1)) {
182        if next <= prev {
183            return Err(Error::msg(format!(
184                "migration versions are not strictly increasing ({prev}->{next})"
185            )));
186        }
187    }
188
189    Ok(())
190}
191
192/// Add custom migrations to a default migration sequence.
193///
194/// Migrations in `custom` replace migrations in `default` with the same version. Otherwise, the two
195/// sequences `default` and `custom` are merged so that the resulting sequence is sorted by
196/// ascending version number. Each of `default` and `custom` is assumed to be the output of
197/// [`validate_migrations`]; that is, each is sorted by version and contains no duplicate versions.
198fn add_custom_migrations(
199    default: impl IntoIterator<Item = Migration>,
200    custom: impl IntoIterator<Item = Migration>,
201) -> impl Iterator<Item = Migration> {
202    default
203        .into_iter()
204        // Merge sorted lists, joining pairs of equal version into `EitherOrBoth::Both`.
205        .merge_join_by(custom, |l, r| l.version().cmp(&r.version()))
206        // Prefer the custom migration for a given version when both default and custom versions
207        // are present.
208        .map(|pair| pair.reduce(|_, custom| custom))
209}
210
/// Configuration for connecting a [`SqlStorage`] instance to its database.
///
/// Build a `Config` from connection options (via `From`/`FromStr`/`Default`) and refine it
/// with the builder-style methods below before calling [`SqlStorage::connect`].
#[derive(Clone)]
pub struct Config {
    /// Low-level SQLite connection options (file path, journal mode, etc.).
    #[cfg(feature = "embedded-db")]
    db_opt: SqliteConnectOptions,

    /// Low-level Postgres connection options (host, port, credentials, TLS, etc.).
    #[cfg(not(feature = "embedded-db"))]
    db_opt: PgConnectOptions,

    /// Options for the primary (sequencer) connection pool.
    pool_opt: PoolOptions<Db>,

    /// Extra pool_opt to allow separately configuring the connection pool for query service
    #[cfg(not(feature = "embedded-db"))]
    pool_opt_query: PoolOptions<Db>,

    /// Postgres schema that queries run against; defaults to `hotshot`.
    #[cfg(not(feature = "embedded-db"))]
    schema: String,
    /// If set, drop and recreate the schema (Postgres) or delete the DB file (SQLite) on connect.
    reset: bool,
    /// Custom migrations to merge with [`default_migrations`] when connecting.
    migrations: Vec<Migration>,
    /// If set, do not run migrations; fail instead if the DB is not already up to date.
    no_migrations: bool,
    /// Pruning policy; `None` disables pruning.
    pruner_cfg: Option<PrunerCfg>,
    /// If set, clear the stored pruned height on connect so pruned data is re-fetched.
    archive: bool,
    /// Pre-existing pool to reuse instead of opening new connections, when permitted.
    pool: Option<Pool<Db>>,
}
234
235#[cfg(not(feature = "embedded-db"))]
236impl Default for Config {
237    fn default() -> Self {
238        PgConnectOptions::default()
239            .username("postgres")
240            .password("password")
241            .host("localhost")
242            .port(5432)
243            .into()
244    }
245}
246
247#[cfg(feature = "embedded-db")]
248impl Default for Config {
249    fn default() -> Self {
250        SqliteConnectOptions::default()
251            .journal_mode(sqlx::sqlite::SqliteJournalMode::Wal)
252            .busy_timeout(Duration::from_secs(30))
253            .auto_vacuum(sqlx::sqlite::SqliteAutoVacuum::Incremental)
254            .create_if_missing(true)
255            .into()
256    }
257}
258
#[cfg(feature = "embedded-db")]
impl From<SqliteConnectOptions> for Config {
    /// Wrap SQLite connection options in a [`Config`] with all other settings neutral:
    /// default pool, no reset, no custom migrations, pruning disabled.
    fn from(db_opt: SqliteConnectOptions) -> Self {
        Config {
            db_opt,
            pool_opt: PoolOptions::default(),
            pool: None,
            reset: false,
            no_migrations: false,
            migrations: Vec::new(),
            pruner_cfg: None,
            archive: false,
        }
    }
}
274
275#[cfg(not(feature = "embedded-db"))]
276impl From<PgConnectOptions> for Config {
277    fn from(db_opt: PgConnectOptions) -> Self {
278        Self {
279            db_opt,
280            pool_opt: PoolOptions::default(),
281            pool_opt_query: PoolOptions::default(),
282            schema: "hotshot".into(),
283            reset: false,
284            migrations: vec![],
285            no_migrations: false,
286            pruner_cfg: None,
287            archive: false,
288            pool: None,
289        }
290    }
291}
292
293#[cfg(not(feature = "embedded-db"))]
294impl FromStr for Config {
295    type Err = <PgConnectOptions as FromStr>::Err;
296
297    fn from_str(s: &str) -> Result<Self, Self::Err> {
298        Ok(PgConnectOptions::from_str(s)?.into())
299    }
300}
301
#[cfg(feature = "embedded-db")]
impl FromStr for Config {
    type Err = <SqliteConnectOptions as FromStr>::Err;

    /// Parse an SQLite connection string (e.g. `sqlite://path/to/db`) into a [`Config`].
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        SqliteConnectOptions::from_str(s).map(Self::from)
    }
}
310
#[cfg(feature = "embedded-db")]
impl Config {
    /// Set how long SQLite waits on a locked database before a statement fails.
    pub fn busy_timeout(mut self, timeout: Duration) -> Self {
        let opts = self.db_opt.busy_timeout(timeout);
        self.db_opt = opts;
        self
    }

    /// Point the configuration at the SQLite database file at `path`.
    pub fn db_path(mut self, path: std::path::PathBuf) -> Self {
        let opts = self.db_opt.filename(path);
        self.db_opt = opts;
        self
    }
}
323
324#[cfg(not(feature = "embedded-db"))]
325impl Config {
326    /// Set the hostname of the database server.
327    ///
328    /// The default is `localhost`.
329    pub fn host(mut self, host: impl Into<String>) -> Self {
330        self.db_opt = self.db_opt.host(&host.into());
331        self
332    }
333
334    /// Set the port on which to connect to the database.
335    ///
336    /// The default is 5432, the default Postgres port.
337    pub fn port(mut self, port: u16) -> Self {
338        self.db_opt = self.db_opt.port(port);
339        self
340    }
341
342    /// Set the DB user to connect as.
343    pub fn user(mut self, user: &str) -> Self {
344        self.db_opt = self.db_opt.username(user);
345        self
346    }
347
348    /// Set a password for connecting to the database.
349    pub fn password(mut self, password: &str) -> Self {
350        self.db_opt = self.db_opt.password(password);
351        self
352    }
353
354    /// Set the name of the database to connect to.
355    pub fn database(mut self, database: &str) -> Self {
356        self.db_opt = self.db_opt.database(database);
357        self
358    }
359
360    /// Use TLS for an encrypted connection to the database.
361    ///
362    /// Note that an encrypted connection may be established even if this option is not set, as long
363    /// as both the client and server support it. This option merely causes connection to fail if an
364    /// encrypted stream cannot be established.
365    pub fn tls(mut self) -> Self {
366        self.db_opt = self.db_opt.ssl_mode(PgSslMode::Require);
367        self
368    }
369
370    /// Set the name of the schema to use for queries.
371    ///
372    /// The default schema is named `hotshot` and is created via the default migrations.
373    pub fn schema(mut self, schema: impl Into<String>) -> Self {
374        self.schema = schema.into();
375        self
376    }
377}
378
379impl Config {
380    /// Sets the database connection pool
381    /// This allows reusing an existing connection pool when building a new `SqlStorage` instance.
382    pub fn pool(mut self, pool: Pool<Db>) -> Self {
383        self.pool = Some(pool);
384        self
385    }
386
387    /// Reset the schema on connection.
388    ///
389    /// When this [`Config`] is used to [`connect`](Self::connect) a
390    /// [`SqlDataSource`](crate::data_source::SqlDataSource), if this option is set, the relevant
391    /// [`schema`](Self::schema) will first be dropped and then recreated, yielding a completely
392    /// fresh instance of the query service.
393    ///
394    /// This is a particularly useful capability for development and staging environments. Still, it
395    /// must be used with extreme caution, as using this will irrevocably delete any data pertaining
396    /// to the query service in the database.
397    pub fn reset_schema(mut self) -> Self {
398        self.reset = true;
399        self
400    }
401
402    /// Add custom migrations to run when connecting to the database.
403    pub fn migrations(mut self, migrations: impl IntoIterator<Item = Migration>) -> Self {
404        self.migrations.extend(migrations);
405        self
406    }
407
408    /// Skip all migrations when connecting to the database.
409    pub fn no_migrations(mut self) -> Self {
410        self.no_migrations = true;
411        self
412    }
413
414    /// Enable pruning with a given configuration.
415    ///
416    /// If [`archive`](Self::archive) was previously specified, this will override it.
417    pub fn pruner_cfg(mut self, cfg: PrunerCfg) -> Result<Self, Error> {
418        cfg.validate()?;
419        self.pruner_cfg = Some(cfg);
420        self.archive = false;
421        Ok(self)
422    }
423
424    /// Disable pruning and reconstruct previously pruned data.
425    ///
426    /// While running without pruning is the default behavior, the default will not try to
427    /// reconstruct data that was pruned in a previous run where pruning was enabled. This option
428    /// instructs the service to run without pruning _and_ reconstruct all previously pruned data by
429    /// fetching from peers.
430    ///
431    /// If [`pruner_cfg`](Self::pruner_cfg) was previously specified, this will override it.
432    pub fn archive(mut self) -> Self {
433        self.pruner_cfg = None;
434        self.archive = true;
435        self
436    }
437
438    /// Set the maximum idle time of a connection.
439    ///
440    /// Any connection which has been open and unused longer than this duration will be
441    /// automatically closed to reduce load on the server.
442    pub fn idle_connection_timeout(mut self, timeout: Duration) -> Self {
443        self.pool_opt = self.pool_opt.idle_timeout(Some(timeout));
444
445        #[cfg(not(feature = "embedded-db"))]
446        {
447            self.pool_opt_query = self.pool_opt_query.idle_timeout(Some(timeout));
448        }
449
450        self
451    }
452
453    /// Set the maximum lifetime of a connection.
454    ///
455    /// Any connection which has been open longer than this duration will be automatically closed
456    /// (and, if needed, replaced), even if it is otherwise healthy. It is good practice to refresh
457    /// even healthy connections once in a while (e.g. daily) in case of resource leaks in the
458    /// server implementation.
459    pub fn connection_timeout(mut self, timeout: Duration) -> Self {
460        self.pool_opt = self.pool_opt.max_lifetime(Some(timeout));
461
462        #[cfg(not(feature = "embedded-db"))]
463        {
464            self.pool_opt = self.pool_opt.max_lifetime(Some(timeout));
465        }
466
467        self
468    }
469
470    /// Set the minimum number of connections to maintain at any time.
471    ///
472    /// The data source will, to the best of its ability, maintain at least `min` open connections
473    /// at all times. This can be used to reduce the latency hit of opening new connections when at
474    /// least this many simultaneous connections are frequently needed.
475    pub fn min_connections(mut self, min: u32) -> Self {
476        self.pool_opt = self.pool_opt.min_connections(min);
477        self
478    }
479
480    #[cfg(not(feature = "embedded-db"))]
481    pub fn query_min_connections(mut self, min: u32) -> Self {
482        self.pool_opt_query = self.pool_opt_query.min_connections(min);
483        self
484    }
485
486    /// Set the maximum number of connections to maintain at any time.
487    ///
488    /// Once `max` connections are in use simultaneously, further attempts to acquire a connection
489    /// (or begin a transaction) will block until one of the existing connections is released.
490    pub fn max_connections(mut self, max: u32) -> Self {
491        self.pool_opt = self.pool_opt.max_connections(max);
492        self
493    }
494
495    #[cfg(not(feature = "embedded-db"))]
496    pub fn query_max_connections(mut self, max: u32) -> Self {
497        self.pool_opt_query = self.pool_opt_query.max_connections(max);
498        self
499    }
500
501    /// Log at WARN level any time a SQL statement takes longer than `threshold`.
502    ///
503    /// The default threshold is 1s.
504    pub fn slow_statement_threshold(mut self, threshold: Duration) -> Self {
505        self.db_opt = self
506            .db_opt
507            .log_slow_statements(LevelFilter::Warn, threshold);
508        self
509    }
510
511    /// Set the maximum time a single SQL statement is allowed to run before being canceled.
512    ///
513    /// This helps prevent queries from running indefinitely even when the client is dropped
514    #[cfg(not(feature = "embedded-db"))]
515    pub fn statement_timeout(mut self, timeout: Duration) -> Self {
516        // Format duration as milliseconds
517        // PostgreSQL interprets values without units as milliseconds
518        let timeout_ms = timeout.as_millis();
519        self.db_opt = self
520            .db_opt
521            .options([("statement_timeout", timeout_ms.to_string())]);
522        self
523    }
524
525    /// not supported for SQLite.
526    #[cfg(feature = "embedded-db")]
527    pub fn statement_timeout(self, _timeout: Duration) -> Self {
528        self
529    }
530}
531
/// Storage for the APIs provided in this crate, backed by a remote PostgreSQL database.
#[derive(Clone, Debug)]
pub struct SqlStorage {
    /// Connection pool shared by all transactions created from this storage.
    pool: Pool<Db>,
    /// Prometheus registry for this storage's metrics; exposed via [`HasMetrics`].
    metrics: PrometheusMetrics,
    /// Counters tracking connection-pool usage (see `transaction::PoolMetrics`).
    pool_metrics: PoolMetrics,
    /// Pruning policy; `None` means pruning is disabled.
    pruner_cfg: Option<PrunerCfg>,
}
540
/// Resumable state for an in-progress pruner run.
///
/// Each field caches a height computed during an earlier batch so a subsequent call to
/// [`PruneStorage::prune`] can resume from the exact point it left off instead of
/// recomputing from the database.
#[derive(Debug, Default)]
pub struct Pruner {
    /// Height up to which data has already been pruned, if known.
    pruned_height: Option<u64>,
    /// Cutoff height for the target-retention pruning pass, if already computed.
    target_height: Option<u64>,
    /// Cutoff height for the minimum-retention pruning pass, if already computed.
    minimum_retention_height: Option<u64>,
}
547
/// Selects which pool configuration [`SqlStorage::connect`] uses.
///
/// `Sequencer` connections use the primary pool options ([`Config::min_connections`] etc.),
/// while `Query` connections use the separately configured query-service pool options
/// (Postgres only; SQLite uses a single pool for both).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum StorageConnectionType {
    Sequencer,
    Query,
}
553
impl SqlStorage {
    /// Get a clone of the underlying connection pool.
    ///
    /// Useful for sharing one pool between storage instances via [`Config::pool`].
    pub fn pool(&self) -> Pool<Db> {
        self.pool.clone()
    }

    /// Connect to a remote database.
    ///
    /// In order, this: selects pool options by `connection_type` (Postgres only; SQLite has a
    /// single pool), reuses an existing pool from the config when allowed, installs a Postgres
    /// `after_connect` hook to set the search path, optionally resets the schema/database file,
    /// runs (or verifies) migrations, and clears the pruned height in archive mode.
    #[allow(unused_variables)]
    pub async fn connect(
        mut config: Config,
        connection_type: StorageConnectionType,
    ) -> Result<Self, Error> {
        let metrics = PrometheusMetrics::default();
        let pool_metrics = PoolMetrics::new(&*metrics.subgroup("sql".into()));

        // Pick the pool options: SQLite always uses the single configured pool; Postgres
        // chooses between the sequencer pool and the separately tuned query-service pool.
        #[cfg(feature = "embedded-db")]
        let pool = config.pool_opt.clone();
        #[cfg(not(feature = "embedded-db"))]
        let pool = match connection_type {
            StorageConnectionType::Sequencer => config.pool_opt.clone(),
            StorageConnectionType::Query => config.pool_opt_query.clone(),
        };

        let pruner_cfg = config.pruner_cfg;

        // Only reuse the same pool if we're using sqlite
        if cfg!(feature = "embedded-db") || connection_type == StorageConnectionType::Sequencer {
            // re-use the same pool if present and return early
            if let Some(pool) = config.pool {
                return Ok(Self {
                    metrics,
                    pool_metrics,
                    pool,
                    pruner_cfg,
                });
            }
        } else if config.pool.is_some() {
            tracing::info!("not reusing existing pool for query connection");
        }

        #[cfg(not(feature = "embedded-db"))]
        let schema = config.schema.clone();
        // Ensure every new pooled connection operates in the configured schema. The closure
        // only needs `config.schema` (edition-2021 disjoint capture), so `config` stays usable
        // below — NOTE(review): relies on disjoint closure capture; confirm edition >= 2021.
        #[cfg(not(feature = "embedded-db"))]
        let pool = pool.after_connect(move |conn, _| {
            let schema = config.schema.clone();
            async move {
                query(&format!("SET search_path TO {schema}"))
                    .execute(conn)
                    .await?;
                Ok(())
            }
            .boxed()
        });

        // For SQLite, a reset means deleting the database file itself.
        #[cfg(feature = "embedded-db")]
        if config.reset {
            std::fs::remove_file(config.db_opt.get_filename())?;
        }

        let pool = pool.connect_with(config.db_opt).await?;

        // Create or connect to the schema for this query service.
        let mut conn = pool.acquire().await?;

        // Disable statement timeout for migrations, as they can take a long time
        #[cfg(not(feature = "embedded-db"))]
        query("SET statement_timeout = 0")
            .execute(conn.as_mut())
            .await?;

        // For Postgres, a reset means dropping and recreating the schema.
        #[cfg(not(feature = "embedded-db"))]
        if config.reset {
            query(&format!("DROP SCHEMA IF EXISTS {schema} CASCADE"))
                .execute(conn.as_mut())
                .await?;
        }

        #[cfg(not(feature = "embedded-db"))]
        query(&format!("CREATE SCHEMA IF NOT EXISTS {schema}"))
            .execute(conn.as_mut())
            .await?;

        // Get migrations and interleave with custom migrations, sorting by version number.
        validate_migrations(&mut config.migrations)?;
        let migrations =
            add_custom_migrations(default_migrations(), config.migrations).collect::<Vec<_>>();

        // Get a migration runner. Depending on the config, we can either use this to actually run
        // the migrations or just check if the database is up to date.
        let runner = refinery::Runner::new(&migrations).set_grouped(true);

        if config.no_migrations {
            // We've been asked not to run any migrations. Abort if the DB is not already up to
            // date.
            let last_applied = runner
                .get_last_applied_migration_async(&mut Migrator::from(&mut conn))
                .await?;
            let last_expected = migrations.last();
            if last_applied.as_ref() != last_expected {
                return Err(Error::msg(format!(
                    "DB is out of date: last applied migration is {last_applied:?}, but expected \
                     {last_expected:?}"
                )));
            }
        } else {
            // Run migrations using `refinery`.
            match runner.run_async(&mut Migrator::from(&mut conn)).await {
                Ok(report) => {
                    tracing::info!("ran DB migrations: {report:?}");
                },
                Err(err) => {
                    tracing::error!("DB migrations failed: {:?}", err.report());
                    Err(err)?;
                },
            }
        }

        if config.archive {
            // If running in archive mode, ensure the pruned height is set to 0, so the fetcher will
            // reconstruct previously pruned data.
            query("DELETE FROM pruned_height WHERE id = 1")
                .execute(conn.as_mut())
                .await?;
        }

        conn.close().await?;

        Ok(Self {
            pool,
            pool_metrics,
            metrics,
            pruner_cfg,
        })
    }
}
688
689impl PrunerConfig for SqlStorage {
690    fn set_pruning_config(&mut self, cfg: PrunerCfg) {
691        self.pruner_cfg = Some(cfg);
692    }
693
694    fn get_pruning_config(&self) -> Option<PrunerCfg> {
695        self.pruner_cfg.clone()
696    }
697}
698
impl HasMetrics for SqlStorage {
    /// Expose the Prometheus metrics registry backing this storage instance.
    fn metrics(&self) -> &PrometheusMetrics {
        &self.metrics
    }
}
704
impl SqlStorage {
    /// Return the minimum block height present in the `header` table, or `None` if the
    /// table is empty.
    async fn get_minimum_height(&self) -> QueryResult<Option<u64>> {
        let mut tx = self.read().await.map_err(|err| QueryError::Error {
            message: err.to_string(),
        })?;
        // `MIN(height)` always yields exactly one row, whose column is NULL when the table
        // is empty — hence decoding as `Option<i64>`.
        let (Some(height),) =
            query_as::<(Option<i64>,)>("SELECT MIN(height) as height FROM header")
                .fetch_one(tx.as_mut())
                .await?
        else {
            return Ok(None);
        };
        Ok(Some(height as u64))
    }

    /// Return the height of the highest block whose timestamp is at most `timestamp`, or
    /// `None` if no block is that old.
    async fn get_height_by_timestamp(&self, timestamp: i64) -> QueryResult<Option<u64>> {
        let mut tx = self.read().await.map_err(|err| QueryError::Error {
            message: err.to_string(),
        })?;

        // We order by timestamp and then height, even though logically this is no different than
        // just ordering by height, since timestamps are monotonic. The reason is that this order
        // allows the query planner to efficiently solve the where clause and presort the results
        // based on the timestamp index. The remaining sort on height, which guarantees a unique
        // block if multiple blocks have the same timestamp, is very efficient, because there are
        // never more than a handful of blocks with the same timestamp.
        let Some((height,)) = query_as::<(i64,)>(
            "SELECT height FROM header
              WHERE timestamp <= $1
              ORDER BY timestamp DESC, height DESC
              LIMIT 1",
        )
        .bind(timestamp)
        .fetch_optional(tx.as_mut())
        .await?
        else {
            return Ok(None);
        };
        Ok(Some(height as u64))
    }

    /// Get the stored VID share for a given block, if one exists.
    pub async fn get_vid_share<Types>(&self, block_id: BlockId<Types>) -> QueryResult<VidShare>
    where
        Types: NodeType,
        Header<Types>: QueryableHeader<Types>,
    {
        // Read-only transaction; lookup is delegated to `AvailabilityStorage::vid_share`.
        let mut tx = self.read().await.map_err(|err| QueryError::Error {
            message: err.to_string(),
        })?;
        let share = tx.vid_share(block_id).await?;
        Ok(share)
    }

    /// Get the stored VID common data for a given block, if one exists.
    pub async fn get_vid_common<Types: NodeType>(
        &self,
        block_id: BlockId<Types>,
    ) -> QueryResult<VidCommonQueryData<Types>>
    where
        <Types as NodeType>::BlockPayload: QueryablePayload<Types>,
        <Types as NodeType>::BlockHeader: QueryableHeader<Types>,
    {
        // Read-only transaction; lookup is delegated to the transaction's `get_vid_common`.
        let mut tx = self.read().await.map_err(|err| QueryError::Error {
            message: err.to_string(),
        })?;
        let common = tx.get_vid_common(block_id).await?;
        Ok(common)
    }

    /// Get the stored VID common metadata for a given block, if one exists.
    pub async fn get_vid_common_metadata<Types: NodeType>(
        &self,
        block_id: BlockId<Types>,
    ) -> QueryResult<VidCommonMetadata<Types>>
    where
        <Types as NodeType>::BlockPayload: QueryablePayload<Types>,
        <Types as NodeType>::BlockHeader: QueryableHeader<Types>,
    {
        // Read-only transaction; lookup is delegated to the transaction's
        // `get_vid_common_metadata`.
        let mut tx = self.read().await.map_err(|err| QueryError::Error {
            message: err.to_string(),
        })?;
        let common = tx.get_vid_common_metadata(block_id).await?;
        Ok(common)
    }
}
791
#[async_trait]
impl PruneStorage for SqlStorage {
    type Pruner = Pruner;

    /// Return the total on-disk size of the database, in bytes.
    ///
    /// Postgres: asks the server via `pg_database_size(current_database())`.
    /// SQLite: computes `page_count * page_size` from the PRAGMA pseudo-tables.
    async fn get_disk_usage(&self) -> anyhow::Result<u64> {
        let mut tx = self.read().await?;

        #[cfg(not(feature = "embedded-db"))]
        let query = "SELECT pg_database_size(current_database())";

        #[cfg(feature = "embedded-db")]
        let query = "
            SELECT( (SELECT page_count FROM pragma_page_count) * (SELECT * FROM pragma_page_size)) \
                     AS total_bytes";

        let row = tx.fetch_one(query).await?;
        // Both backends report the size as a signed integer; a real database
        // size is never negative, so the cast to u64 below is safe in practice.
        let size: i64 = row.get(0);

        Ok(size as u64)
    }

    /// Trigger incremental vacuum to free up space in the SQLite database.
    /// Note: We don't vacuum the Postgres database,
    /// as there is no manual trigger for incremental vacuum,
    /// and a full vacuum can take a lot of time.
    #[cfg(feature = "embedded-db")]
    async fn vacuum(&self) -> anyhow::Result<()> {
        let config = self.get_pruning_config().ok_or(QueryError::Error {
            message: "Pruning config not found".to_string(),
        })?;
        let mut conn = self.pool().acquire().await?;
        // Frees up to `incremental_vacuum_pages()` pages from the SQLite
        // freelist per pruner run, bounding the cost of each vacuum.
        query(&format!(
            "PRAGMA incremental_vacuum({})",
            config.incremental_vacuum_pages()
        ))
        .execute(conn.as_mut())
        .await?;
        // NOTE(review): this closes the connection outright rather than
        // returning it to the pool — presumably intentional; confirm.
        conn.close().await?;
        Ok(())
    }

    /// Delete at most one batch of old data, if pruning is warranted.
    ///
    /// Progress (pruned/target/minimum-retention heights) is recorded in
    /// `pruner`, so repeated calls resume where the previous one left off.
    /// Data older than the target retention period is always eligible for
    /// deletion; data older than the minimum retention period is additionally
    /// deleted when disk usage exceeds the configured pruning threshold.
    ///
    /// Returns the new pruned height if a batch was deleted, or `None` when
    /// there was nothing to prune this round.
    ///
    /// Note: The prune operation may not immediately free up space even after rows are deleted.
    /// This is because a vacuum operation may be necessary to reclaim more space.
    /// PostgreSQL already performs auto vacuuming, so we are not including it here
    /// as running a vacuum operation can be resource-intensive.
    async fn prune(&self, pruner: &mut Pruner) -> anyhow::Result<Option<u64>> {
        let cfg = self.get_pruning_config().ok_or(QueryError::Error {
            message: "Pruning config not found".to_string(),
        })?;
        let batch_size = cfg.batch_size();
        let max_usage = cfg.max_usage();
        let state_tables = cfg.state_tables();

        // If a pruner run was already in progress, some variables may already be set,
        // depending on whether a batch was deleted and which batch it was (target or minimum retention).
        // This enables us to resume the pruner run from the exact heights.
        // If any of these values are not set, they can be loaded from the database if necessary.
        let mut minimum_retention_height = pruner.minimum_retention_height;
        let mut target_height = pruner.target_height;
        let pruned_height = match pruner.pruned_height {
            Some(h) => Some(h),
            None => {
                // No recorded progress: derive the effective pruned height
                // from the lowest block still present in the database.
                let Some(height) = self.get_minimum_height().await? else {
                    tracing::info!("database is empty, nothing to prune");
                    return Ok(None);
                };

                // Everything below the minimum stored height has already been
                // pruned (or never existed).
                if height > 0 { Some(height - 1) } else { None }
            },
        };

        // Prune data exceeding target retention in batches
        if pruner.target_height.is_none() {
            // Highest block older than the target retention window.
            let th = self
                .get_height_by_timestamp(
                    Utc::now().timestamp() - (cfg.target_retention().as_secs()) as i64,
                )
                .await?;
            target_height = th;
            pruner.target_height = target_height;
        };

        if let Some(th) = target_height
            && pruned_height < Some(th)
        {
            // Next batch covers `batch_size` heights past the pruned height
            // (starting from height 0 on the first run), capped at the target.
            let batch_end = match pruned_height {
                None => batch_size - 1,
                Some(h) => h + batch_size,
            };
            let to = min(batch_end, th);

            // Update pruned height first so the fetcher does not
            // try to fetch data that we are about to delete.
            let mut tx = self.write().await?;
            tx.save_pruned_height(to).await?;
            tx.commit().await.map_err(|e| QueryError::Error {
                message: format!("failed to commit save_pruned_height {e}"),
            })?;

            let mut tx = self.write().await?;
            tx.delete_batch(to).await?;
            tx.commit().await.map_err(|e| QueryError::Error {
                message: format!("failed to commit delete_batch {e}"),
            })?;

            // Prune state tables in a separate transaction.
            let mut tx = self.write().await?;
            tx.delete_state_batch(state_tables, to).await?;
            tx.commit().await.map_err(|e| QueryError::Error {
                message: format!("failed to commit {e}"),
            })?;

            pruner.pruned_height = Some(to);
            return Ok(Some(to));
        }

        // If threshold is set, prune data exceeding minimum retention in batches
        // This parameter is needed for SQL storage as there is no direct way to get free space.
        if let Some(threshold) = cfg.pruning_threshold() {
            let usage = self.get_disk_usage().await?;

            // Prune data exceeding minimum retention in batches starting from minimum height
            // until usage is below threshold
            if usage > threshold {
                tracing::warn!(
                    "Disk usage {usage} exceeds pruning threshold {:?}",
                    cfg.pruning_threshold()
                );

                if minimum_retention_height.is_none() {
                    // Highest block older than the minimum retention window.
                    minimum_retention_height = self
                        .get_height_by_timestamp(
                            Utc::now().timestamp() - (cfg.minimum_retention().as_secs()) as i64,
                        )
                        .await?;

                    pruner.minimum_retention_height = minimum_retention_height;
                }

                // NOTE(review): `max_usage` appears to be expressed in basis
                // points (1/100 of a percent), given the division by 10000 —
                // confirm against `PrunerCfg::max_usage` documentation.
                if let Some(min_retention_height) = minimum_retention_height
                    && (usage as f64 / threshold as f64) > (f64::from(max_usage) / 10000.0)
                    && pruned_height < Some(min_retention_height)
                {
                    let batch_end = match pruned_height {
                        None => batch_size - 1,
                        Some(h) => h + batch_size,
                    };
                    let to = min(batch_end, min_retention_height);
                    // Update pruned height first so the fetcher does not
                    // try to fetch data that we are about to delete.
                    let mut tx = self.write().await?;
                    tx.save_pruned_height(to).await?;
                    tx.commit().await.map_err(|e| QueryError::Error {
                        message: format!("failed to commit save_pruned_height {e}"),
                    })?;

                    let mut tx = self.write().await?;
                    tx.delete_batch(to).await?;
                    tx.commit().await.map_err(|e| QueryError::Error {
                        message: format!("failed to commit delete_batch {e}"),
                    })?;

                    // Prune state tables in a separate transaction.
                    let mut tx = self.write().await?;
                    tx.delete_state_batch(state_tables, to).await?;
                    tx.commit().await.map_err(|e| QueryError::Error {
                        message: format!("failed to commit {e}"),
                    })?;

                    // Reclaim freed pages right away since this branch only
                    // runs when disk usage is over the threshold. (The
                    // override above exists only for SQLite; presumably the
                    // trait provides a no-op default for Postgres — confirm.)
                    self.vacuum().await?;
                    pruner.pruned_height = Some(to);
                    return Ok(Some(to));
                }
            }
        }

        Ok(None)
    }
}
971
972impl VersionedDataSource for SqlStorage {
973    type Transaction<'a>
974        = Transaction<Write>
975    where
976        Self: 'a;
977    type ReadOnly<'a>
978        = Transaction<Read>
979    where
980        Self: 'a;
981
982    async fn write(&self) -> anyhow::Result<Transaction<Write>> {
983        Transaction::new(&self.pool, self.pool_metrics.clone()).await
984    }
985
986    async fn read(&self) -> anyhow::Result<Transaction<Read>> {
987        Transaction::new(&self.pool, self.pool_metrics.clone()).await
988    }
989}
990
991// These tests run the `postgres` Docker image, which doesn't work on Windows.
992#[cfg(all(any(test, feature = "testing"), not(target_os = "windows")))]
993pub mod testing {
994    #![allow(unused_imports)]
995    use std::{
996        env,
997        process::{Command, Stdio},
998        str::{self, FromStr},
999        time::Duration,
1000    };
1001
1002    use refinery::Migration;
1003    use test_utils::reserve_tcp_port;
1004    use tokio::{net::TcpStream, time::timeout};
1005
1006    use super::Config;
1007    use crate::testing::sleep;
    /// A temporary database for tests: a Dockerized Postgres instance by
    /// default, or a throwaway SQLite file when the `embedded-db` feature is
    /// enabled. Cleanup happens in `Drop` (see the `Drop` impls below).
    #[derive(Debug)]
    pub struct TmpDb {
        // Hostname where the Postgres container is reachable (defaults to
        // localhost, overridable via DOCKER_HOSTNAME).
        #[cfg(not(feature = "embedded-db"))]
        host: String,
        // Host port mapped to the container's Postgres port 5432.
        #[cfg(not(feature = "embedded-db"))]
        port: u16,
        // Docker container ID, used to stop/start the instance.
        #[cfg(not(feature = "embedded-db"))]
        container_id: String,
        // Path of the temporary SQLite database file.
        #[cfg(feature = "embedded-db")]
        db_path: std::path::PathBuf,
        // If true, the underlying database is kept when this handle is dropped
        // (SQLite file not deleted; Postgres container started without --rm).
        #[allow(dead_code)]
        persistent: bool,
    }
    impl TmpDb {
        /// Create a fresh SQLite database file in the system temp directory.
        #[cfg(feature = "embedded-db")]
        fn init_sqlite_db(persistent: bool) -> Self {
            let file = tempfile::Builder::new()
                .prefix("sqlite-")
                .suffix(".db")
                .tempfile()
                .unwrap();

            // Persist the temp file so it survives this scope; removal is
            // handled by `Drop` (unless `persistent` is set).
            let (_, db_path) = file.keep().unwrap();

            Self {
                db_path,
                persistent,
            }
        }
        /// Spin up a non-persistent temporary database (cleaned up on drop).
        pub async fn init() -> Self {
            #[cfg(feature = "embedded-db")]
            return Self::init_sqlite_db(false);

            #[cfg(not(feature = "embedded-db"))]
            Self::init_postgres(false).await
        }

        /// Spin up a persistent temporary database: the SQLite file (or the
        /// stopped Postgres container) is kept after this handle is dropped.
        pub async fn persistent() -> Self {
            #[cfg(feature = "embedded-db")]
            return Self::init_sqlite_db(true);

            #[cfg(not(feature = "embedded-db"))]
            Self::init_postgres(true).await
        }

        /// Launch a Postgres Docker container on a free host port and wait for
        /// it to accept connections.
        #[cfg(not(feature = "embedded-db"))]
        async fn init_postgres(persistent: bool) -> Self {
            let docker_hostname = env::var("DOCKER_HOSTNAME");
            // This picks an unused port on the current system.  If docker is
            // configured to run on a different host then this may not find a
            // "free" port on that system.
            // We *might* be able to get away with this as any remote docker
            // host should hopefully be pretty open with its port space.
            let port = reserve_tcp_port().unwrap();
            let host = docker_hostname.unwrap_or("localhost".to_string());

            let mut cmd = Command::new("docker");
            cmd.arg("run")
                .arg("-d")
                .args(["-p", &format!("{port}:5432")])
                .args(["-e", "POSTGRES_PASSWORD=password"]);

            // --rm makes Docker delete the container once it stops; omit it
            // for persistent databases so the container can be restarted.
            if !persistent {
                cmd.arg("--rm");
            }

            let output = cmd.arg("postgres").output().unwrap();
            let stdout = str::from_utf8(&output.stdout).unwrap();
            let stderr = str::from_utf8(&output.stderr).unwrap();
            if !output.status.success() {
                panic!("failed to start postgres docker: {stderr}");
            }

            // Create the TmpDb object immediately after starting the Docker container, so if
            // anything panics after this `drop` will be called and we will clean up.
            let container_id = stdout.trim().to_owned();
            tracing::info!("launched postgres docker {container_id}");
            let db = Self {
                host,
                port,
                container_id: container_id.clone(),
                persistent,
            };

            db.wait_for_ready().await;
            db
        }

        /// Hostname where the Postgres container is reachable.
        #[cfg(not(feature = "embedded-db"))]
        pub fn host(&self) -> String {
            self.host.clone()
        }

        /// Host port mapped to the container's Postgres port.
        #[cfg(not(feature = "embedded-db"))]
        pub fn port(&self) -> u16 {
            self.port
        }

        /// Path of the temporary SQLite database file.
        #[cfg(feature = "embedded-db")]
        pub fn path(&self) -> std::path::PathBuf {
            self.db_path.clone()
        }

        /// Build a connection `Config` for this database, including the test
        /// merkle tree migration on top of the standard migrations.
        pub fn config(&self) -> Config {
            #[cfg(feature = "embedded-db")]
            let mut cfg = Config::default().db_path(self.db_path.clone());

            #[cfg(not(feature = "embedded-db"))]
            let mut cfg = Config::default()
                .user("postgres")
                .password("password")
                .host(self.host())
                .port(self.port());

            cfg = cfg.migrations(vec![
                Migration::unapplied(
                    "V101__create_test_merkle_tree_table.sql",
                    &TestMerkleTreeMigration::create("test_tree"),
                )
                .unwrap(),
            ]);

            cfg
        }

        /// Stop the Postgres container (used by tests that simulate outages).
        #[cfg(not(feature = "embedded-db"))]
        pub fn stop_postgres(&mut self) {
            tracing::info!(container = self.container_id, "stopping postgres");
            let output = Command::new("docker")
                .args(["stop", self.container_id.as_str()])
                .output()
                .unwrap();
            assert!(
                output.status.success(),
                "error killing postgres docker {}: {}",
                self.container_id,
                str::from_utf8(&output.stderr).unwrap()
            );
        }

        /// Restart a previously stopped Postgres container and wait for it to
        /// accept connections again.
        #[cfg(not(feature = "embedded-db"))]
        pub async fn start_postgres(&mut self) {
            tracing::info!(container = self.container_id, "resuming postgres");
            let output = Command::new("docker")
                .args(["start", self.container_id.as_str()])
                .output()
                .unwrap();
            assert!(
                output.status.success(),
                "error starting postgres docker {}: {}",
                self.container_id,
                str::from_utf8(&output.stderr).unwrap()
            );

            self.wait_for_ready().await;
        }

        /// Poll `pg_isready` inside the container (then the host port mapping)
        /// until the database accepts connections, panicking after the timeout
        /// configured by SQL_TMP_DB_CONNECT_TIMEOUT (default 60s).
        #[cfg(not(feature = "embedded-db"))]
        async fn wait_for_ready(&self) {
            let timeout_duration = Duration::from_secs(
                env::var("SQL_TMP_DB_CONNECT_TIMEOUT")
                    .unwrap_or("60".to_string())
                    .parse()
                    .expect("SQL_TMP_DB_CONNECT_TIMEOUT must be an integer number of seconds"),
            );

            if let Err(err) = timeout(timeout_duration, async {
                while Command::new("docker")
                    .args([
                        "exec",
                        &self.container_id,
                        "pg_isready",
                        "-h",
                        "localhost",
                        "-U",
                        "postgres",
                    ])
                    .env("PGPASSWORD", "password")
                    // Null input so the command terminates as soon as it manages to connect.
                    .stdin(Stdio::null())
                    // Discard command output.
                    .stdout(Stdio::null())
                    .stderr(Stdio::null())
                    .status()
                    // We should ensure the exit status. A simple `unwrap`
                    // would panic on unrelated errors (such as network
                    // connection failures)
                    .and_then(|status| {
                        status
                            .success()
                            .then_some(true)
                            // Any ol' Error will do
                            .ok_or(std::io::Error::from_raw_os_error(666))
                    })
                    .is_err()
                {
                    tracing::warn!("database is not ready");
                    sleep(Duration::from_secs(1)).await;
                }

                // The above command ensures the database is ready inside the Docker container.
                // However, on some systems, there is a slight delay before the port is exposed via
                // host networking. We don't need to check again that the database is ready on the
                // host (and maybe can't, because the host might not have pg_isready installed), but
                // we can ensure the port is open by just establishing a TCP connection.
                while let Err(err) =
                    TcpStream::connect(format!("{}:{}", self.host, self.port)).await
                {
                    tracing::warn!("database is ready, but port is not available to host: {err:#}");
                    sleep(Duration::from_millis(100)).await;
                }
            })
            .await
            {
                panic!(
                    "failed to connect to TmpDb within configured timeout {timeout_duration:?}: \
                     {err:#}\n{}",
                    "Consider increasing the timeout by setting SQL_TMP_DB_CONNECT_TIMEOUT"
                );
            }
        }
    }
1230
    /// Stop the Postgres container when the handle is dropped. Containers
    /// started without `--rm` (persistent databases) are stopped but remain
    /// available to be restarted.
    #[cfg(not(feature = "embedded-db"))]
    impl Drop for TmpDb {
        fn drop(&mut self) {
            self.stop_postgres();
        }
    }
1237
1238    #[cfg(feature = "embedded-db")]
1239    impl Drop for TmpDb {
1240        fn drop(&mut self) {
1241            if !self.persistent {
1242                std::fs::remove_file(self.db_path.clone()).unwrap();
1243            }
1244        }
1245    }
1246
1247    pub struct TestMerkleTreeMigration;
1248
1249    impl TestMerkleTreeMigration {
1250        fn create(name: &str) -> String {
1251            let (bit_vec, binary, hash_pk, root_stored_column) = if cfg!(feature = "embedded-db") {
1252                (
1253                    "TEXT",
1254                    "BLOB",
1255                    "INTEGER PRIMARY KEY AUTOINCREMENT",
1256                    " (json_extract(data, '$.test_merkle_tree_root'))",
1257                )
1258            } else {
1259                (
1260                    "BIT(8)",
1261                    "BYTEA",
1262                    "SERIAL PRIMARY KEY",
1263                    "(data->>'test_merkle_tree_root')",
1264                )
1265            };
1266
1267            format!(
1268                "CREATE TABLE IF NOT EXISTS hash
1269            (
1270                id {hash_pk},
1271                value {binary}  NOT NULL UNIQUE
1272            );
1273
1274            ALTER TABLE header
1275            ADD column test_merkle_tree_root text
1276            GENERATED ALWAYS as {root_stored_column} STORED;
1277
1278            CREATE TABLE {name}
1279            (
1280                path JSONB NOT NULL,
1281                created BIGINT NOT NULL,
1282                hash_id INT NOT NULL,
1283                children JSONB,
1284                children_bitvec {bit_vec},
1285                idx JSONB,
1286                entry JSONB,
1287                PRIMARY KEY (path, created)
1288            );
1289            CREATE INDEX {name}_created ON {name} (created);"
1290            )
1291        }
1292    }
1293}
1294
1295// These tests run the `postgres` Docker image, which doesn't work on Windows.
1296#[cfg(all(test, not(target_os = "windows")))]
1297mod test {
1298    use std::time::Duration;
1299
1300    use hotshot_example_types::{
1301        node_types::TEST_VERSIONS,
1302        state_types::{TestInstanceState, TestValidatedState},
1303    };
1304    use jf_merkle_tree_compat::{
1305        MerkleTreeScheme, ToTraversalPath, UniversalMerkleTreeScheme, prelude::UniversalMerkleTree,
1306    };
1307    use tokio::time::sleep;
1308
1309    use super::{testing::TmpDb, *};
1310    use crate::{
1311        availability::{BlockQueryData, LeafQueryData},
1312        data_source::storage::{
1313            MerklizedStateStorage, UpdateAvailabilityStorage, pruning::PrunedHeightStorage,
1314        },
1315        merklized_state::{MerklizedState, Snapshot, UpdateStateData},
1316        testing::mocks::{MockMerkleTree, MockTypes},
1317    };
1318
    /// End-to-end check of the migration machinery: connecting with migrations
    /// disabled fails on an out-of-date schema, succeeds when up to date, and
    /// custom migrations are applied in version order regardless of the order
    /// they are supplied in.
    #[test_log::test(tokio::test(flavor = "multi_thread"))]
    async fn test_migrations() {
        let db = TmpDb::init().await;
        let cfg = db.config();

        // Helper: connect with or without running migrations, optionally
        // adding custom migrations on top of the defaults.
        let connect = |migrations: bool, custom_migrations| {
            let cfg = cfg.clone();
            async move {
                let mut cfg = cfg.migrations(custom_migrations);
                if !migrations {
                    cfg = cfg.no_migrations();
                }
                let client = SqlStorage::connect(cfg, StorageConnectionType::Query).await?;
                Ok::<_, Error>(client)
            }
        };

        // Connecting with migrations disabled should fail if the database is not already up to date
        // (since we've just created a fresh database, it isn't).
        let err = connect(false, vec![]).await.unwrap_err();
        tracing::info!("connecting without running migrations failed as expected: {err}");

        // Now connect and run migrations to bring the database up to date.
        connect(true, vec![]).await.unwrap();
        // Now connecting without migrations should work.
        connect(false, vec![]).await.unwrap();

        // Connect with some custom migrations, to advance the schema even further. Pass in the
        // custom migrations out of order; they should still execute in order of version number.
        // The SQL commands used here will fail if not run in order.
        let migrations = vec![
            Migration::unapplied(
                "V9999__create_test_table.sql",
                "ALTER TABLE test ADD COLUMN data INTEGER;",
            )
            .unwrap(),
            Migration::unapplied(
                "V9998__create_test_table.sql",
                "CREATE TABLE test (x bigint);",
            )
            .unwrap(),
        ];
        connect(true, migrations.clone()).await.unwrap();

        // Connect using the default schema (no custom migrations) and not running migrations. This
        // should fail because the database is _ahead_ of the client in terms of schema.
        let err = connect(false, vec![]).await.unwrap_err();
        tracing::info!("connecting without running migrations failed as expected: {err}");

        // Connecting with the customized schema should work even without running migrations.
        connect(true, migrations).await.unwrap();
    }
1371
    /// Parsing a Postgres connection URL should populate user, host, and port.
    #[test]
    #[cfg(not(feature = "embedded-db"))]
    fn test_config_from_str() {
        let cfg = Config::from_str("postgresql://user:password@host:8080").unwrap();
        assert_eq!(cfg.db_opt.get_username(), "user");
        assert_eq!(cfg.db_opt.get_host(), "host");
        assert_eq!(cfg.db_opt.get_port(), 8080);
    }
1380
    /// Parsing a SQLite connection URL should populate the database filename.
    #[test]
    #[cfg(feature = "embedded-db")]
    fn test_config_from_str() {
        let cfg = Config::from_str("sqlite://data.db").unwrap();
        assert_eq!(cfg.db_opt.get_filename().to_string_lossy(), "data.db");
    }
1387
1388    async fn vacuum(storage: &SqlStorage) {
1389        #[cfg(feature = "embedded-db")]
1390        let query = "PRAGMA incremental_vacuum(16000)";
1391        #[cfg(not(feature = "embedded-db"))]
1392        let query = "VACUUM";
1393        storage
1394            .pool
1395            .acquire()
1396            .await
1397            .unwrap()
1398            .execute(query)
1399            .await
1400            .unwrap();
1401    }
1402
    /// The pruner should delete nothing while data is within the retention
    /// period, and delete everything (cascading from the header table) once
    /// all data is older than a short target retention.
    #[test_log::test(tokio::test(flavor = "multi_thread"))]
    async fn test_target_period_pruning() {
        let db = TmpDb::init().await;
        let cfg = db.config();

        let mut storage = SqlStorage::connect(cfg, StorageConnectionType::Query)
            .await
            .unwrap();
        let mut leaf = LeafQueryData::<MockTypes>::genesis(
            &TestValidatedState::default(),
            &TestInstanceState::default(),
            TEST_VERSIONS.test,
        )
        .await;
        // insert some mock data
        for i in 0..20 {
            leaf.leaf.block_header_mut().block_number = i;
            leaf.leaf.block_header_mut().timestamp = Utc::now().timestamp() as u64;
            let mut tx = storage.write().await.unwrap();
            tx.insert_leaf(&leaf).await.unwrap();
            tx.commit().await.unwrap();
        }

        let height_before_pruning = storage.get_minimum_height().await.unwrap().unwrap();

        // Set pruner config to default which has minimum retention set to 1 day
        storage.set_pruning_config(PrunerCfg::new());
        // No data will be pruned
        let pruned_height = storage.prune(&mut Default::default()).await.unwrap();

        // Vacuum the database to reclaim space.
        // This is necessary to ensure the test passes.
        // Note: We don't perform a vacuum after each pruner run in production because the auto vacuum job handles it automatically.
        vacuum(&storage).await;
        // Pruned height should be none
        assert!(pruned_height.is_none());

        let height_after_pruning = storage.get_minimum_height().await.unwrap().unwrap();

        assert_eq!(
            height_after_pruning, height_before_pruning,
            "some data has been pruned"
        );

        // Set pruner config to target retention set to 1s
        storage.set_pruning_config(PrunerCfg::new().with_target_retention(Duration::from_secs(1)));
        sleep(Duration::from_secs(2)).await;
        let usage_before_pruning = storage.get_disk_usage().await.unwrap();
        // All of the data is now older than 1s.
        // This would prune all the data as the target retention is set to 1s
        let pruned_height = storage.prune(&mut Default::default()).await.unwrap();
        // Vacuum the database to reclaim space.
        // This is necessary to ensure the test passes.
        // Note: We don't perform a vacuum after each pruner run in production because the auto vacuum job handles it automatically.
        vacuum(&storage).await;

        // Pruned height should be some
        assert!(pruned_height.is_some());
        let usage_after_pruning = storage.get_disk_usage().await.unwrap();
        // All the tables should be empty
        // counting rows in header table
        let header_rows = storage
            .read()
            .await
            .unwrap()
            .fetch_one("select count(*) as count from header")
            .await
            .unwrap()
            .get::<i64, _>("count");
        // the table should be empty
        assert_eq!(header_rows, 0);

        // counting rows in leaf table.
        // Deleting rows from header table would delete rows in all the tables
        // as each of table implement "ON DELETE CASCADE" fk constraint with the header table.
        let leaf_rows = storage
            .read()
            .await
            .unwrap()
            .fetch_one("select count(*) as count from leaf")
            .await
            .unwrap()
            .get::<i64, _>("count");
        // the table should be empty
        assert_eq!(leaf_rows, 0);

        assert!(
            usage_before_pruning > usage_after_pruning,
            " disk usage should decrease after pruning"
        )
    }
1494
    /// Pruning state tables should collapse old node versions (keeping only
    /// the newest version of each path at or below the prune height) while
    /// leaving merkle proofs for the latest snapshot intact and verifiable.
    #[test_log::test(tokio::test(flavor = "multi_thread"))]
    async fn test_merklized_state_pruning() {
        let db = TmpDb::init().await;
        let storage = SqlStorage::connect(db.config(), StorageConnectionType::Query)
            .await
            .unwrap();

        let num_blocks = 10_000u64;
        let mut test_tree: UniversalMerkleTree<_, _, _, 8, _> =
            MockMerkleTree::new(MockMerkleTree::tree_height());

        // Insert entries and merkle nodes for each block height.
        let mut tx = storage.write().await.unwrap();
        for height in 0..num_blocks {
            test_tree.update(height as usize, height as usize).unwrap();

            // Each header records the tree's root commitment at this height.
            let test_data = serde_json::json!({
                MockMerkleTree::header_state_commitment_field():
                    serde_json::to_value(test_tree.commitment()).unwrap()
            });
            tx.upsert(
                "header",
                [
                    "height",
                    "hash",
                    "payload_hash",
                    "timestamp",
                    "data",
                    "ns_table",
                ],
                ["height"],
                [(
                    height as i64,
                    format!("hash{height}"),
                    "ph".to_string(),
                    0,
                    test_data,
                    "ns".to_string(),
                )],
            )
            .await
            .unwrap();

            // Store the merkle nodes along the path of the newly updated key.
            let (_, proof) = test_tree.lookup(height as usize).expect_ok().unwrap();
            let traversal_path = <usize as ToTraversalPath<8>>::to_traversal_path(
                &(height as usize),
                test_tree.height(),
            );
            UpdateStateData::<_, MockMerkleTree, 8>::insert_merkle_nodes(
                &mut tx,
                proof.clone(),
                traversal_path,
                height,
            )
            .await
            .unwrap();
        }
        UpdateStateData::<_, MockMerkleTree, 8>::set_last_state_height(
            &mut tx,
            num_blocks as usize,
        )
        .await
        .unwrap();
        tx.commit().await.unwrap();

        // Prune up to `prune_height`, keeping only the newest version of each node.
        let prune_height = 5678u64;
        let mut tx = storage.write().await.unwrap();
        tx.delete_state_batch(vec!["test_tree".to_string()], prune_height)
            .await
            .unwrap();
        tx.commit().await.unwrap();

        // Verify no paths have multiple versions at or below the prune height.
        let mut tx = storage.read().await.unwrap();
        let (duplicates,) = query_as::<(i64,)>(
            "SELECT count(*) FROM (SELECT count(*) FROM test_tree WHERE created <= $1 GROUP BY \
             path HAVING count(*) > 1) AS s",
        )
        .bind(prune_height as i64)
        .fetch_one(tx.as_mut())
        .await
        .unwrap();
        assert_eq!(
            duplicates, 0,
            "found {duplicates} paths with duplicate versions at or below prune height"
        );

        // Verify get_path still works for the latest snapshot and returns correct proofs.
        let commitment = test_tree.commitment();
        let mut tx = storage.read().await.unwrap();
        for key in 0..num_blocks as usize {
            let proof = MerklizedStateStorage::<MockTypes, MockMerkleTree, 8>::get_path(
                &mut tx,
                Snapshot::Index(num_blocks - 1),
                key,
            )
            .await
            .unwrap_or_else(|e| panic!("get_path failed for key {key} after pruning: {e:#}"));
            assert_eq!(
                proof.elem(),
                Some(&key),
                "proof for key {key} has wrong element: {:?}",
                proof.elem()
            );
            // Proofs must still verify against the latest root commitment.
            MockMerkleTree::verify(commitment, key, &proof)
                .unwrap()
                .unwrap();
        }
    }
1605
1606    #[test_log::test(tokio::test(flavor = "multi_thread"))]
1607    async fn test_minimum_retention_pruning() {
1608        let db = TmpDb::init().await;
1609
1610        let mut storage = SqlStorage::connect(db.config(), StorageConnectionType::Query)
1611            .await
1612            .unwrap();
1613        let mut leaf = LeafQueryData::<MockTypes>::genesis(
1614            &TestValidatedState::default(),
1615            &TestInstanceState::default(),
1616            TEST_VERSIONS.test,
1617        )
1618        .await;
1619        // insert some mock data
1620        for i in 0..20 {
1621            leaf.leaf.block_header_mut().block_number = i;
1622            leaf.leaf.block_header_mut().timestamp = Utc::now().timestamp() as u64;
1623            let mut tx = storage.write().await.unwrap();
1624            tx.insert_leaf(&leaf).await.unwrap();
1625            tx.commit().await.unwrap();
1626        }
1627
1628        let height_before_pruning = storage.get_minimum_height().await.unwrap().unwrap();
1629        let cfg = PrunerCfg::new();
1630        // Set pruning_threshold to 1
1631        // SQL storage size is more than 1000 bytes even without any data indexed
1632        // This would mean that the threshold would always be greater than the disk usage
1633        // However, minimum retention is set to 24 hours by default so the data would not be pruned
1634        storage.set_pruning_config(cfg.clone().with_pruning_threshold(1));
1635        println!("{:?}", storage.get_pruning_config().unwrap());
1636        // Pruning would not delete any data
1637        // All the data is younger than minimum retention period even though the usage > threshold
1638        let pruned_height = storage.prune(&mut Default::default()).await.unwrap();
1639        // Vacuum the database to reclaim space.
1640        // This is necessary to ensure the test passes.
1641        // Note: We don't perform a vacuum after each pruner run in production because the auto vacuum job handles it automatically.
1642        vacuum(&storage).await;
1643
1644        // Pruned height should be none
1645        assert!(pruned_height.is_none());
1646
1647        let height_after_pruning = storage.get_minimum_height().await.unwrap().unwrap();
1648
1649        assert_eq!(
1650            height_after_pruning, height_before_pruning,
1651            "some data has been pruned"
1652        );
1653
1654        // Change minimum retention to 1s
1655        storage.set_pruning_config(
1656            cfg.with_minimum_retention(Duration::from_secs(1))
1657                .with_pruning_threshold(1),
1658        );
1659        // sleep for 2s to make sure the data is older than minimum retention
1660        sleep(Duration::from_secs(2)).await;
1661        // This would prune all the data
1662        let pruned_height = storage.prune(&mut Default::default()).await.unwrap();
1663        // Vacuum the database to reclaim space.
1664        // This is necessary to ensure the test passes.
1665        // Note: We don't perform a vacuum after each pruner run in production because the auto vacuum job handles it automatically.
1666        vacuum(&storage).await;
1667
1668        // Pruned height should be some
1669        assert!(pruned_height.is_some());
1670        // All the tables should be empty
1671        // counting rows in header table
1672        let header_rows = storage
1673            .read()
1674            .await
1675            .unwrap()
1676            .fetch_one("select count(*) as count from header")
1677            .await
1678            .unwrap()
1679            .get::<i64, _>("count");
1680        // the table should be empty
1681        assert_eq!(header_rows, 0);
1682    }
1683
1684    #[tokio::test]
1685    #[test_log::test]
1686    async fn test_payload_pruning() {
1687        let db = TmpDb::init().await;
1688        let mut storage = SqlStorage::connect(db.config(), StorageConnectionType::Query)
1689            .await
1690            .unwrap();
1691        storage.set_pruning_config(Default::default());
1692
1693        // Insert some mock data.
1694        let mut leaf = LeafQueryData::<MockTypes>::genesis(
1695            &TestValidatedState::default(),
1696            &TestInstanceState::default(),
1697            TEST_VERSIONS.test,
1698        )
1699        .await;
1700        let block = BlockQueryData::<MockTypes>::genesis(
1701            &Default::default(),
1702            &Default::default(),
1703            TEST_VERSIONS.test.base,
1704        )
1705        .await;
1706        let vid = VidCommonQueryData::<MockTypes>::genesis(
1707            &Default::default(),
1708            &Default::default(),
1709            TEST_VERSIONS.test.base,
1710        )
1711        .await;
1712        {
1713            let mut tx = storage.write().await.unwrap();
1714            tx.insert_leaf(&leaf).await.unwrap();
1715            tx.insert_block(&block).await.unwrap();
1716            tx.insert_vid(&vid, None).await.unwrap();
1717            tx.commit().await.unwrap();
1718        }
1719
1720        // Insert a second leaf sharing the same payload.
1721        leaf.leaf.block_header_mut().block_number += 1;
1722        {
1723            let mut tx = storage.write().await.unwrap();
1724            tx.insert_leaf(&leaf).await.unwrap();
1725            tx.commit().await.unwrap();
1726        }
1727        {
1728            let mut tx = storage.read().await.unwrap();
1729            let (num_payloads,): (i64,) = query_as("SELECT count(*) FROM payload")
1730                .fetch_one(tx.as_mut())
1731                .await
1732                .unwrap();
1733            assert_eq!(num_payloads, 1);
1734            let (num_vid,): (i64,) = query_as("SELECT count(*) FROM vid_common")
1735                .fetch_one(tx.as_mut())
1736                .await
1737                .unwrap();
1738            assert_eq!(num_vid, 1);
1739        }
1740
1741        // Prune the first leaf but not the second (and thus not the payload or VID).
1742        let pruned_height = storage
1743            .prune(&mut Pruner {
1744                pruned_height: None,
1745                target_height: Some(0),
1746                minimum_retention_height: None,
1747            })
1748            .await
1749            .unwrap();
1750        tracing::info!(?pruned_height, "first pruning run complete");
1751        {
1752            let mut tx = storage.read().await.unwrap();
1753
1754            // First block is pruned.
1755            let err = tx
1756                .get_block(BlockId::<MockTypes>::Number(0))
1757                .await
1758                .unwrap_err();
1759            assert!(matches!(err, QueryError::NotFound), "{err:#}");
1760            let err = tx
1761                .get_vid_common(BlockId::<MockTypes>::Number(0))
1762                .await
1763                .unwrap_err();
1764            assert!(matches!(err, QueryError::NotFound), "{err:#}");
1765
1766            // Second block is still available.
1767            assert_eq!(
1768                tx.get_block(BlockId::<MockTypes>::Number(1)).await.unwrap(),
1769                BlockQueryData::new(leaf.header().clone(), block.payload)
1770            );
1771            assert_eq!(
1772                tx.get_vid_common(BlockId::<MockTypes>::Number(1))
1773                    .await
1774                    .unwrap(),
1775                VidCommonQueryData::new(leaf.header().clone(), vid.common)
1776            );
1777
1778            let (num_payloads,): (i64,) = query_as("SELECT count(*) FROM payload")
1779                .fetch_one(tx.as_mut())
1780                .await
1781                .unwrap();
1782            assert_eq!(num_payloads, 1);
1783
1784            let (num_vid,): (i64,) = query_as("SELECT count(*) FROM vid_common")
1785                .fetch_one(tx.as_mut())
1786                .await
1787                .unwrap();
1788            assert_eq!(num_vid, 1);
1789        }
1790
1791        // Now prune the second leaf, ensuring the payload and VID get deleted as well.
1792        let pruned_height = storage
1793            .prune(&mut Pruner {
1794                pruned_height,
1795                target_height: Some(1),
1796                minimum_retention_height: None,
1797            })
1798            .await
1799            .unwrap();
1800        tracing::info!(?pruned_height, "second pruning run complete");
1801
1802        let mut tx = storage.read().await.unwrap();
1803        for i in 0..2 {
1804            let err = tx
1805                .get_block(BlockId::<MockTypes>::Number(i))
1806                .await
1807                .unwrap_err();
1808            assert!(matches!(err, QueryError::NotFound), "{err:#}");
1809
1810            let err = tx
1811                .get_vid_common(BlockId::<MockTypes>::Number(i))
1812                .await
1813                .unwrap_err();
1814            assert!(matches!(err, QueryError::NotFound), "{err:#}");
1815        }
1816        let (num_payloads,): (i64,) = query_as("SELECT count(*) FROM payload")
1817            .fetch_one(tx.as_mut())
1818            .await
1819            .unwrap();
1820        assert_eq!(num_payloads, 0);
1821
1822        let (num_vid,): (i64,) = query_as("SELECT count(*) FROM vid_common")
1823            .fetch_one(tx.as_mut())
1824            .await
1825            .unwrap();
1826        assert_eq!(num_vid, 0);
1827    }
1828
1829    #[test_log::test(tokio::test(flavor = "multi_thread"))]
1830    async fn test_pruned_height_storage() {
1831        let db = TmpDb::init().await;
1832        let cfg = db.config();
1833
1834        let storage = SqlStorage::connect(cfg, StorageConnectionType::Query)
1835            .await
1836            .unwrap();
1837        assert!(
1838            storage
1839                .read()
1840                .await
1841                .unwrap()
1842                .load_pruned_height()
1843                .await
1844                .unwrap()
1845                .is_none()
1846        );
1847        for height in [10, 20, 30] {
1848            let mut tx = storage.write().await.unwrap();
1849            tx.save_pruned_height(height).await.unwrap();
1850            tx.commit().await.unwrap();
1851            assert_eq!(
1852                storage
1853                    .read()
1854                    .await
1855                    .unwrap()
1856                    .load_pruned_height()
1857                    .await
1858                    .unwrap(),
1859                Some(height)
1860            );
1861        }
1862    }
1863
1864    #[test_log::test(tokio::test(flavor = "multi_thread"))]
1865    async fn test_transaction_upsert_retries() {
1866        let db = TmpDb::init().await;
1867        let config = db.config();
1868
1869        let storage = SqlStorage::connect(config, StorageConnectionType::Query)
1870            .await
1871            .unwrap();
1872
1873        let mut tx = storage.write().await.unwrap();
1874
1875        // Try to upsert into a table that does not exist.
1876        // This will fail, so our `upsert` function will enter the retry loop.
1877        // Since the table does not exist, all retries will eventually
1878        // fail and we expect an error to be returned.
1879        //
1880        // Previously, this case would cause  a panic because we were calling
1881        // methods on `QueryBuilder` after `.build()` without first
1882        // calling `.reset()`and according to the sqlx docs, that always panics.
1883        // Now, since we are properly calling `.reset()` inside `upsert()` for
1884        // the query builder, the function returns an error instead of panicking.
1885        tx.upsert("does_not_exist", ["test"], ["test"], [(1_i64,)])
1886            .await
1887            .unwrap_err();
1888    }
1889}