Skip to main content

hotshot_query_service/data_source/storage/
sql.rs

1// Copyright (c) 2022 Espresso Systems (espressosys.com)
2// This file is part of the HotShot Query Service library.
3//
4// This program is free software: you can redistribute it and/or modify it under the terms of the GNU
5// General Public License as published by the Free Software Foundation, either version 3 of the
6// License, or (at your option) any later version.
7// This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without
8// even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
9// General Public License for more details.
10// You should have received a copy of the GNU General Public License along with this program. If not,
11// see <https://www.gnu.org/licenses/>.
12
13#![cfg(feature = "sql-data-source")]
14use std::{cmp::min, fmt::Debug, str::FromStr, time::Duration};
15
16use async_trait::async_trait;
17use chrono::Utc;
18#[cfg(not(feature = "embedded-db"))]
19use futures::future::FutureExt;
20use hotshot_types::{
21    data::VidShare,
22    traits::{metrics::Metrics, node_implementation::NodeType},
23};
24use itertools::Itertools;
25use log::LevelFilter;
26#[cfg(not(feature = "embedded-db"))]
27use sqlx::postgres::{PgConnectOptions, PgSslMode};
28#[cfg(feature = "embedded-db")]
29use sqlx::sqlite::SqliteConnectOptions;
30use sqlx::{
31    ConnectOptions, Row,
32    pool::{Pool, PoolOptions},
33};
34
35use crate::{
36    Header, QueryError, QueryResult,
37    availability::{QueryableHeader, QueryablePayload, VidCommonMetadata, VidCommonQueryData},
38    data_source::{
39        VersionedDataSource,
40        storage::pruning::{PruneStorage, PrunerCfg, PrunerConfig},
41        update::Transaction as _,
42    },
43    metrics::PrometheusMetrics,
44    node::BlockId,
45    status::HasMetrics,
46};
47pub extern crate sqlx;
48pub use sqlx::{Database, Sqlite};
49
50mod db;
51mod migrate;
52mod queries;
53mod transaction;
54
55pub use anyhow::Error;
56pub use db::*;
57pub use include_dir::include_dir;
58pub use queries::QueryBuilder;
59pub use refinery::Migration;
60pub use transaction::*;
61
62use self::{migrate::Migrator, transaction::PoolMetrics};
63use super::{AvailabilityStorage, NodeStorage};
64// This needs to be reexported so that we can reference it by absolute path relative to this crate
65// in the expansion of `include_migrations`, even when `include_migrations` is invoked from another
66// crate which doesn't have `include_dir` as a dependency.
67pub use crate::include_migrations;
68
69/// Embed migrations from the given directory into the current binary for PostgreSQL or SQLite.
70///
71/// The macro invocation `include_migrations!(path)` evaluates to an expression of type `impl
72/// Iterator<Item = Migration>`. Each migration must be a text file which is an immediate child of
73/// `path`, and there must be no non-migration files in `path`. The migration files must have names
74/// of the form `V${version}__${name}.sql`, where `version` is a positive integer indicating how the
75/// migration is to be ordered relative to other migrations, and `name` is a descriptive name for
76/// the migration.
77///
78/// `path` should be an absolute path. It is possible to give a path relative to the root of the
79/// invoking crate by using environment variable expansions and the `CARGO_MANIFEST_DIR` environment
80/// variable.
81///
82/// As an example, this is the invocation used to load the default migrations from the
83/// `hotshot-query-service` crate. The migrations are located in a directory called `migrations` at
84/// - PostgreSQL migrations are in `/migrations/postgres`.
85/// - SQLite migrations are in `/migrations/sqlite`.
86///
87/// ```
88/// # use hotshot_query_service::data_source::sql::{include_migrations, Migration};
89/// // For PostgreSQL
90/// #[cfg(not(feature = "embedded-db"))]
91///  let mut migrations: Vec<Migration> =
92///     include_migrations!("$CARGO_MANIFEST_DIR/migrations/postgres").collect();
93/// // For SQLite
94/// #[cfg(feature = "embedded-db")]
95/// let mut migrations: Vec<Migration> =
96///     include_migrations!("$CARGO_MANIFEST_DIR/migrations/sqlite").collect();
97///
98///     migrations.sort();
99///     assert_eq!(migrations[0].version(), 10);
100///     assert_eq!(migrations[0].name(), "init_schema");
101/// ```
102///
103/// Note that a similar macro is available from Refinery:
104/// [embed_migrations](https://docs.rs/refinery/0.8.11/refinery/macro.embed_migrations.html). This
105/// macro differs in that it evaluates to an iterator of [migrations](Migration), making it an
106/// expression macro, while `embed_migrations` is a statement macro that defines a module which
107/// provides access to the embedded migrations only indirectly via a
108/// [`Runner`](https://docs.rs/refinery/0.8.11/refinery/struct.Runner.html). The direct access to
109/// migrations provided by [`include_migrations`] makes this macro easier to use with
110/// [`Config::migrations`], for combining custom migrations with [`default_migrations`].
111#[macro_export]
112macro_rules! include_migrations {
113    ($dir:tt) => {
114        $crate::data_source::storage::sql::include_dir!($dir)
115            .files()
116            .map(|file| {
117                let path = file.path();
118                let name = path
119                    .file_name()
120                    .and_then(std::ffi::OsStr::to_str)
121                    .unwrap_or_else(|| {
122                        panic!(
123                            "migration file {} must have a non-empty UTF-8 name",
124                            path.display()
125                        )
126                    });
127                let sql = file
128                    .contents_utf8()
129                    .unwrap_or_else(|| panic!("migration file {name} must use UTF-8 encoding"));
130                $crate::data_source::storage::sql::Migration::unapplied(name, sql)
131                    .expect("invalid migration")
132            })
133    };
134}
135
136/// The migrations required to build the default schema for this version of [`SqlStorage`].
137pub fn default_migrations() -> Vec<Migration> {
138    #[cfg(not(feature = "embedded-db"))]
139    let mut migrations =
140        include_migrations!("$CARGO_MANIFEST_DIR/migrations/postgres").collect::<Vec<_>>();
141
142    #[cfg(feature = "embedded-db")]
143    let mut migrations =
144        include_migrations!("$CARGO_MANIFEST_DIR/migrations/sqlite").collect::<Vec<_>>();
145
146    // Check version uniqueness and sort by version.
147    validate_migrations(&mut migrations).expect("default migrations are invalid");
148
149    // Check that all migration versions are multiples of 100, so that custom migrations can be
150    // inserted in between.
151    for m in &migrations {
152        if m.version() <= 30 {
153            // An older version of this software used intervals of 10 instead of 100. This was
154            // changed to allow more custom migrations between each default migration, but we must
155            // still accept older migrations that followed the older rule.
156            assert!(
157                m.version() > 0 && m.version() % 10 == 0,
158                "legacy default migration version {} is not a positive multiple of 10",
159                m.version()
160            );
161        } else {
162            assert!(
163                m.version() % 100 == 0,
164                "default migration version {} is not a multiple of 100",
165                m.version()
166            );
167        }
168    }
169
170    migrations
171}
172
173/// Validate and preprocess a sequence of migrations.
174///
175/// * Ensure all migrations have distinct versions
176/// * Ensure migrations are sorted by increasing version
177fn validate_migrations(migrations: &mut [Migration]) -> Result<(), Error> {
178    migrations.sort_by_key(|m| m.version());
179
180    // Check version uniqueness.
181    for (prev, next) in migrations.iter().zip(migrations.iter().skip(1)) {
182        if next <= prev {
183            return Err(Error::msg(format!(
184                "migration versions are not strictly increasing ({prev}->{next})"
185            )));
186        }
187    }
188
189    Ok(())
190}
191
192/// Add custom migrations to a default migration sequence.
193///
194/// Migrations in `custom` replace migrations in `default` with the same version. Otherwise, the two
195/// sequences `default` and `custom` are merged so that the resulting sequence is sorted by
196/// ascending version number. Each of `default` and `custom` is assumed to be the output of
197/// [`validate_migrations`]; that is, each is sorted by version and contains no duplicate versions.
198fn add_custom_migrations(
199    default: impl IntoIterator<Item = Migration>,
200    custom: impl IntoIterator<Item = Migration>,
201) -> impl Iterator<Item = Migration> {
202    default
203        .into_iter()
204        // Merge sorted lists, joining pairs of equal version into `EitherOrBoth::Both`.
205        .merge_join_by(custom, |l, r| l.version().cmp(&r.version()))
206        // Prefer the custom migration for a given version when both default and custom versions
207        // are present.
208        .map(|pair| pair.reduce(|_, custom| custom))
209}
210
211#[derive(Clone)]
212pub struct Config {
213    #[cfg(feature = "embedded-db")]
214    db_opt: SqliteConnectOptions,
215
216    #[cfg(not(feature = "embedded-db"))]
217    db_opt: PgConnectOptions,
218
219    pool_opt: PoolOptions<Db>,
220
221    /// Extra pool_opt to allow separately configuring the connection pool for query service
222    #[cfg(not(feature = "embedded-db"))]
223    pool_opt_query: PoolOptions<Db>,
224
225    #[cfg(not(feature = "embedded-db"))]
226    schema: String,
227    reset: bool,
228    migrations: Vec<Migration>,
229    no_migrations: bool,
230    pruner_cfg: Option<PrunerCfg>,
231    archive: bool,
232    pool: Option<Pool<Db>>,
233}
234
235#[cfg(not(feature = "embedded-db"))]
236impl Default for Config {
237    fn default() -> Self {
238        PgConnectOptions::default()
239            .username("postgres")
240            .password("password")
241            .host("localhost")
242            .port(5432)
243            .into()
244    }
245}
246
247#[cfg(feature = "embedded-db")]
248impl Default for Config {
249    fn default() -> Self {
250        SqliteConnectOptions::default()
251            .journal_mode(sqlx::sqlite::SqliteJournalMode::Wal)
252            .busy_timeout(Duration::from_secs(30))
253            .auto_vacuum(sqlx::sqlite::SqliteAutoVacuum::Incremental)
254            .create_if_missing(true)
255            .into()
256    }
257}
258
259#[cfg(feature = "embedded-db")]
260impl From<SqliteConnectOptions> for Config {
261    fn from(db_opt: SqliteConnectOptions) -> Self {
262        Self {
263            db_opt,
264            pool_opt: PoolOptions::default(),
265            reset: false,
266            migrations: vec![],
267            no_migrations: false,
268            pruner_cfg: None,
269            archive: false,
270            pool: None,
271        }
272    }
273}
274
275#[cfg(not(feature = "embedded-db"))]
276impl From<PgConnectOptions> for Config {
277    fn from(db_opt: PgConnectOptions) -> Self {
278        Self {
279            db_opt,
280            pool_opt: PoolOptions::default(),
281            pool_opt_query: PoolOptions::default(),
282            schema: "hotshot".into(),
283            reset: false,
284            migrations: vec![],
285            no_migrations: false,
286            pruner_cfg: None,
287            archive: false,
288            pool: None,
289        }
290    }
291}
292
293#[cfg(not(feature = "embedded-db"))]
294impl FromStr for Config {
295    type Err = <PgConnectOptions as FromStr>::Err;
296
297    fn from_str(s: &str) -> Result<Self, Self::Err> {
298        Ok(PgConnectOptions::from_str(s)?.into())
299    }
300}
301
302#[cfg(feature = "embedded-db")]
303impl FromStr for Config {
304    type Err = <SqliteConnectOptions as FromStr>::Err;
305
306    fn from_str(s: &str) -> Result<Self, Self::Err> {
307        Ok(SqliteConnectOptions::from_str(s)?.into())
308    }
309}
310
311#[cfg(feature = "embedded-db")]
312impl Config {
313    pub fn busy_timeout(mut self, timeout: Duration) -> Self {
314        self.db_opt = self.db_opt.busy_timeout(timeout);
315        self
316    }
317
318    pub fn db_path(mut self, path: std::path::PathBuf) -> Self {
319        self.db_opt = self.db_opt.filename(path);
320        self
321    }
322}
323
324#[cfg(not(feature = "embedded-db"))]
325impl Config {
326    /// Set the hostname of the database server.
327    ///
328    /// The default is `localhost`.
329    pub fn host(mut self, host: impl Into<String>) -> Self {
330        self.db_opt = self.db_opt.host(&host.into());
331        self
332    }
333
334    /// Set the port on which to connect to the database.
335    ///
336    /// The default is 5432, the default Postgres port.
337    pub fn port(mut self, port: u16) -> Self {
338        self.db_opt = self.db_opt.port(port);
339        self
340    }
341
342    /// Set the DB user to connect as.
343    pub fn user(mut self, user: &str) -> Self {
344        self.db_opt = self.db_opt.username(user);
345        self
346    }
347
348    /// Set a password for connecting to the database.
349    pub fn password(mut self, password: &str) -> Self {
350        self.db_opt = self.db_opt.password(password);
351        self
352    }
353
354    /// Set the name of the database to connect to.
355    pub fn database(mut self, database: &str) -> Self {
356        self.db_opt = self.db_opt.database(database);
357        self
358    }
359
360    /// Use TLS for an encrypted connection to the database.
361    ///
362    /// Note that an encrypted connection may be established even if this option is not set, as long
363    /// as both the client and server support it. This option merely causes connection to fail if an
364    /// encrypted stream cannot be established.
365    pub fn tls(mut self) -> Self {
366        self.db_opt = self.db_opt.ssl_mode(PgSslMode::Require);
367        self
368    }
369
370    /// Set the name of the schema to use for queries.
371    ///
372    /// The default schema is named `hotshot` and is created via the default migrations.
373    pub fn schema(mut self, schema: impl Into<String>) -> Self {
374        self.schema = schema.into();
375        self
376    }
377}
378
379impl Config {
380    /// Sets the database connection pool
381    /// This allows reusing an existing connection pool when building a new `SqlStorage` instance.
382    pub fn pool(mut self, pool: Pool<Db>) -> Self {
383        self.pool = Some(pool);
384        self
385    }
386
387    /// Reset the schema on connection.
388    ///
389    /// When this [`Config`] is used to [`connect`](Self::connect) a
390    /// [`SqlDataSource`](crate::data_source::SqlDataSource), if this option is set, the relevant
391    /// [`schema`](Self::schema) will first be dropped and then recreated, yielding a completely
392    /// fresh instance of the query service.
393    ///
394    /// This is a particularly useful capability for development and staging environments. Still, it
395    /// must be used with extreme caution, as using this will irrevocably delete any data pertaining
396    /// to the query service in the database.
397    pub fn reset_schema(mut self) -> Self {
398        self.reset = true;
399        self
400    }
401
402    /// Add custom migrations to run when connecting to the database.
403    pub fn migrations(mut self, migrations: impl IntoIterator<Item = Migration>) -> Self {
404        self.migrations.extend(migrations);
405        self
406    }
407
408    /// Skip all migrations when connecting to the database.
409    pub fn no_migrations(mut self) -> Self {
410        self.no_migrations = true;
411        self
412    }
413
414    /// Enable pruning with a given configuration.
415    ///
416    /// If [`archive`](Self::archive) was previously specified, this will override it.
417    pub fn pruner_cfg(mut self, cfg: PrunerCfg) -> Result<Self, Error> {
418        cfg.validate()?;
419        self.pruner_cfg = Some(cfg);
420        self.archive = false;
421        Ok(self)
422    }
423
424    /// Disable pruning and reconstruct previously pruned data.
425    ///
426    /// While running without pruning is the default behavior, the default will not try to
427    /// reconstruct data that was pruned in a previous run where pruning was enabled. This option
428    /// instructs the service to run without pruning _and_ reconstruct all previously pruned data by
429    /// fetching from peers.
430    ///
431    /// If [`pruner_cfg`](Self::pruner_cfg) was previously specified, this will override it.
432    pub fn archive(mut self) -> Self {
433        self.pruner_cfg = None;
434        self.archive = true;
435        self
436    }
437
438    /// Set the maximum idle time of a connection.
439    ///
440    /// Any connection which has been open and unused longer than this duration will be
441    /// automatically closed to reduce load on the server.
442    pub fn idle_connection_timeout(mut self, timeout: Duration) -> Self {
443        self.pool_opt = self.pool_opt.idle_timeout(Some(timeout));
444
445        #[cfg(not(feature = "embedded-db"))]
446        {
447            self.pool_opt_query = self.pool_opt_query.idle_timeout(Some(timeout));
448        }
449
450        self
451    }
452
453    /// Set the maximum lifetime of a connection.
454    ///
455    /// Any connection which has been open longer than this duration will be automatically closed
456    /// (and, if needed, replaced), even if it is otherwise healthy. It is good practice to refresh
457    /// even healthy connections once in a while (e.g. daily) in case of resource leaks in the
458    /// server implementation.
459    pub fn connection_timeout(mut self, timeout: Duration) -> Self {
460        self.pool_opt = self.pool_opt.max_lifetime(Some(timeout));
461
462        #[cfg(not(feature = "embedded-db"))]
463        {
464            self.pool_opt = self.pool_opt.max_lifetime(Some(timeout));
465        }
466
467        self
468    }
469
470    /// Set the minimum number of connections to maintain at any time.
471    ///
472    /// The data source will, to the best of its ability, maintain at least `min` open connections
473    /// at all times. This can be used to reduce the latency hit of opening new connections when at
474    /// least this many simultaneous connections are frequently needed.
475    pub fn min_connections(mut self, min: u32) -> Self {
476        self.pool_opt = self.pool_opt.min_connections(min);
477        self
478    }
479
480    #[cfg(not(feature = "embedded-db"))]
481    pub fn query_min_connections(mut self, min: u32) -> Self {
482        self.pool_opt_query = self.pool_opt_query.min_connections(min);
483        self
484    }
485
486    /// Set the maximum number of connections to maintain at any time.
487    ///
488    /// Once `max` connections are in use simultaneously, further attempts to acquire a connection
489    /// (or begin a transaction) will block until one of the existing connections is released.
490    pub fn max_connections(mut self, max: u32) -> Self {
491        self.pool_opt = self.pool_opt.max_connections(max);
492        self
493    }
494
495    #[cfg(not(feature = "embedded-db"))]
496    pub fn query_max_connections(mut self, max: u32) -> Self {
497        self.pool_opt_query = self.pool_opt_query.max_connections(max);
498        self
499    }
500
501    /// Log at WARN level any time a SQL statement takes longer than `threshold`.
502    ///
503    /// The default threshold is 1s.
504    pub fn slow_statement_threshold(mut self, threshold: Duration) -> Self {
505        self.db_opt = self
506            .db_opt
507            .log_slow_statements(LevelFilter::Warn, threshold);
508        self
509    }
510
511    /// Set the maximum time a single SQL statement is allowed to run before being canceled.
512    ///
513    /// This helps prevent queries from running indefinitely even when the client is dropped
514    #[cfg(not(feature = "embedded-db"))]
515    pub fn statement_timeout(mut self, timeout: Duration) -> Self {
516        // Format duration as milliseconds
517        // PostgreSQL interprets values without units as milliseconds
518        let timeout_ms = timeout.as_millis();
519        self.db_opt = self
520            .db_opt
521            .options([("statement_timeout", timeout_ms.to_string())]);
522        self
523    }
524
525    /// not supported for SQLite.
526    #[cfg(feature = "embedded-db")]
527    pub fn statement_timeout(self, _timeout: Duration) -> Self {
528        self
529    }
530}
531
532/// Storage for the APIs provided in this crate, backed by a remote PostgreSQL database.
533#[derive(Clone, Debug)]
534pub struct SqlStorage {
535    pool: Pool<Db>,
536    metrics: PrometheusMetrics,
537    pool_metrics: PoolMetrics,
538    pruner_cfg: Option<PrunerCfg>,
539}
540
541#[derive(Debug, Default)]
542pub struct Pruner {
543    pruned_height: Option<u64>,
544    target_height: Option<u64>,
545    minimum_retention_height: Option<u64>,
546}
547
548#[derive(PartialEq)]
549pub enum StorageConnectionType {
550    Sequencer,
551    Query,
552}
553
554impl SqlStorage {
555    pub fn pool(&self) -> Pool<Db> {
556        self.pool.clone()
557    }
558
559    /// Connect to a remote database.
560    #[allow(unused_variables)]
561    pub async fn connect(
562        mut config: Config,
563        connection_type: StorageConnectionType,
564    ) -> Result<Self, Error> {
565        let metrics = PrometheusMetrics::default();
566        let pool_metrics = PoolMetrics::new(&*metrics.subgroup("sql".into()));
567
568        #[cfg(feature = "embedded-db")]
569        let pool = config.pool_opt.clone();
570        #[cfg(not(feature = "embedded-db"))]
571        let pool = match connection_type {
572            StorageConnectionType::Sequencer => config.pool_opt.clone(),
573            StorageConnectionType::Query => config.pool_opt_query.clone(),
574        };
575
576        let pruner_cfg = config.pruner_cfg;
577
578        // Only reuse the same pool if we're using sqlite
579        if cfg!(feature = "embedded-db") || connection_type == StorageConnectionType::Sequencer {
580            // re-use the same pool if present and return early
581            if let Some(pool) = config.pool {
582                return Ok(Self {
583                    metrics,
584                    pool_metrics,
585                    pool,
586                    pruner_cfg,
587                });
588            }
589        } else if config.pool.is_some() {
590            tracing::info!("not reusing existing pool for query connection");
591        }
592
593        #[cfg(not(feature = "embedded-db"))]
594        let schema = config.schema.clone();
595        #[cfg(not(feature = "embedded-db"))]
596        let pool = pool.after_connect(move |conn, _| {
597            let schema = config.schema.clone();
598            async move {
599                query(&format!("SET search_path TO {schema}"))
600                    .execute(conn)
601                    .await?;
602                Ok(())
603            }
604            .boxed()
605        });
606
607        #[cfg(feature = "embedded-db")]
608        if config.reset {
609            std::fs::remove_file(config.db_opt.get_filename())?;
610        }
611
612        let pool = pool.connect_with(config.db_opt).await?;
613
614        // Create or connect to the schema for this query service.
615        let mut conn = pool.acquire().await?;
616
617        // Disable statement timeout for migrations, as they can take a long time
618        #[cfg(not(feature = "embedded-db"))]
619        query("SET statement_timeout = 0")
620            .execute(conn.as_mut())
621            .await?;
622
623        #[cfg(not(feature = "embedded-db"))]
624        if config.reset {
625            query(&format!("DROP SCHEMA IF EXISTS {schema} CASCADE"))
626                .execute(conn.as_mut())
627                .await?;
628        }
629
630        #[cfg(not(feature = "embedded-db"))]
631        query(&format!("CREATE SCHEMA IF NOT EXISTS {schema}"))
632            .execute(conn.as_mut())
633            .await?;
634
635        // Get migrations and interleave with custom migrations, sorting by version number.
636        validate_migrations(&mut config.migrations)?;
637        let migrations =
638            add_custom_migrations(default_migrations(), config.migrations).collect::<Vec<_>>();
639
640        // Get a migration runner. Depending on the config, we can either use this to actually run
641        // the migrations or just check if the database is up to date.
642        let runner = refinery::Runner::new(&migrations).set_grouped(true);
643
644        if config.no_migrations {
645            // We've been asked not to run any migrations. Abort if the DB is not already up to
646            // date.
647            let last_applied = runner
648                .get_last_applied_migration_async(&mut Migrator::from(&mut conn))
649                .await?;
650            let last_expected = migrations.last();
651            if last_applied.as_ref() != last_expected {
652                return Err(Error::msg(format!(
653                    "DB is out of date: last applied migration is {last_applied:?}, but expected \
654                     {last_expected:?}"
655                )));
656            }
657        } else {
658            // Run migrations using `refinery`.
659            match runner.run_async(&mut Migrator::from(&mut conn)).await {
660                Ok(report) => {
661                    tracing::info!("ran DB migrations: {report:?}");
662                },
663                Err(err) => {
664                    tracing::error!("DB migrations failed: {:?}", err.report());
665                    Err(err)?;
666                },
667            }
668        }
669
670        if config.archive {
671            // If running in archive mode, ensure the pruned height is set to 0, so the fetcher will
672            // reconstruct previously pruned data.
673            query("DELETE FROM pruned_height WHERE id = 1")
674                .execute(conn.as_mut())
675                .await?;
676        }
677
678        conn.close().await?;
679
680        Ok(Self {
681            pool,
682            pool_metrics,
683            metrics,
684            pruner_cfg,
685        })
686    }
687}
688
689impl PrunerConfig for SqlStorage {
690    fn set_pruning_config(&mut self, cfg: PrunerCfg) {
691        self.pruner_cfg = Some(cfg);
692    }
693
694    fn get_pruning_config(&self) -> Option<PrunerCfg> {
695        self.pruner_cfg.clone()
696    }
697}
698
699impl HasMetrics for SqlStorage {
700    fn metrics(&self) -> &PrometheusMetrics {
701        &self.metrics
702    }
703}
704
705impl SqlStorage {
706    async fn prune_write(&self) -> anyhow::Result<Transaction<Prune>> {
707        Transaction::new(&self.pool, self.pool_metrics.clone()).await
708    }
709
710    async fn get_minimum_height(&self) -> QueryResult<Option<u64>> {
711        let mut tx = self.read().await.map_err(|err| QueryError::Error {
712            message: err.to_string(),
713        })?;
714        let (Some(height),) =
715            query_as::<(Option<i64>,)>("SELECT MIN(height) as height FROM header")
716                .fetch_one(tx.as_mut())
717                .await?
718        else {
719            return Ok(None);
720        };
721        Ok(Some(height as u64))
722    }
723
724    async fn get_height_by_timestamp(&self, timestamp: i64) -> QueryResult<Option<u64>> {
725        let mut tx = self.read().await.map_err(|err| QueryError::Error {
726            message: err.to_string(),
727        })?;
728
729        // We order by timestamp and then height, even though logically this is no different than
730        // just ordering by height, since timestamps are monotonic. The reason is that this order
731        // allows the query planner to efficiently solve the where clause and presort the results
732        // based on the timestamp index. The remaining sort on height, which guarantees a unique
733        // block if multiple blocks have the same timestamp, is very efficient, because there are
734        // never more than a handful of blocks with the same timestamp.
735        let Some((height,)) = query_as::<(i64,)>(
736            "SELECT height FROM header
737              WHERE timestamp <= $1
738              ORDER BY timestamp DESC, height DESC
739              LIMIT 1",
740        )
741        .bind(timestamp)
742        .fetch_optional(tx.as_mut())
743        .await?
744        else {
745            return Ok(None);
746        };
747        Ok(Some(height as u64))
748    }
749
750    /// Get the stored VID share for a given block, if one exists.
751    pub async fn get_vid_share<Types>(&self, block_id: BlockId<Types>) -> QueryResult<VidShare>
752    where
753        Types: NodeType,
754        Header<Types>: QueryableHeader<Types>,
755    {
756        let mut tx = self.read().await.map_err(|err| QueryError::Error {
757            message: err.to_string(),
758        })?;
759        let share = tx.vid_share(block_id).await?;
760        Ok(share)
761    }
762
763    /// Get the stored VID common data for a given block, if one exists.
764    pub async fn get_vid_common<Types: NodeType>(
765        &self,
766        block_id: BlockId<Types>,
767    ) -> QueryResult<VidCommonQueryData<Types>>
768    where
769        <Types as NodeType>::BlockPayload: QueryablePayload<Types>,
770        <Types as NodeType>::BlockHeader: QueryableHeader<Types>,
771    {
772        let mut tx = self.read().await.map_err(|err| QueryError::Error {
773            message: err.to_string(),
774        })?;
775        let common = tx.get_vid_common(block_id).await?;
776        Ok(common)
777    }
778
779    /// Get the stored VID common metadata for a given block, if one exists.
780    pub async fn get_vid_common_metadata<Types: NodeType>(
781        &self,
782        block_id: BlockId<Types>,
783    ) -> QueryResult<VidCommonMetadata<Types>>
784    where
785        <Types as NodeType>::BlockPayload: QueryablePayload<Types>,
786        <Types as NodeType>::BlockHeader: QueryableHeader<Types>,
787    {
788        let mut tx = self.read().await.map_err(|err| QueryError::Error {
789            message: err.to_string(),
790        })?;
791        let common = tx.get_vid_common_metadata(block_id).await?;
792        Ok(common)
793    }
794}
795
796#[async_trait]
797impl PruneStorage for SqlStorage {
798    type Pruner = Pruner;
799
800    async fn get_disk_usage(&self) -> anyhow::Result<u64> {
801        let mut tx = self.read().await?;
802
803        #[cfg(not(feature = "embedded-db"))]
804        let query = "SELECT pg_database_size(current_database())";
805
806        #[cfg(feature = "embedded-db")]
807        let query = "
808            SELECT( (SELECT page_count FROM pragma_page_count) * (SELECT * FROM pragma_page_size)) \
809                     AS total_bytes";
810
811        let row = tx.fetch_one(query).await?;
812        let size: i64 = row.get(0);
813
814        Ok(size as u64)
815    }
816
817    /// Trigger incremental vacuum to free up space in the SQLite database.
818    /// Note: We don't vacuum the Postgres database,
819    /// as there is no manual trigger for incremental vacuum,
820    /// and a full vacuum can take a lot of time.
821    #[cfg(feature = "embedded-db")]
822    async fn vacuum(&self) -> anyhow::Result<()> {
823        let config = self.get_pruning_config().ok_or(QueryError::Error {
824            message: "Pruning config not found".to_string(),
825        })?;
826        let mut conn = self.pool().acquire().await?;
827        query(&format!(
828            "PRAGMA incremental_vacuum({})",
829            config.incremental_vacuum_pages()
830        ))
831        .execute(conn.as_mut())
832        .await?;
833        conn.close().await?;
834        Ok(())
835    }
836
837    /// Note: The prune operation may not immediately free up space even after rows are deleted.
838    /// This is because a vacuum operation may be necessary to reclaim more space.
839    /// PostgreSQL already performs auto vacuuming, so we are not including it here
840    /// as running a vacuum operation can be resource-intensive.
841    async fn prune(&self, pruner: &mut Pruner) -> anyhow::Result<Option<u64>> {
842        let cfg = self.get_pruning_config().ok_or(QueryError::Error {
843            message: "Pruning config not found".to_string(),
844        })?;
845        let batch_size = cfg.batch_size();
846        let max_usage = cfg.max_usage();
847        let state_tables = cfg.state_tables();
848
849        // If a pruner run was already in progress, some variables may already be set,
850        // depending on whether a batch was deleted and which batch it was (target or minimum retention).
851        // This enables us to resume the pruner run from the exact heights.
852        // If any of these values are not set, they can be loaded from the database if necessary.
853        let mut minimum_retention_height = pruner.minimum_retention_height;
854        let mut target_height = pruner.target_height;
855        let pruned_height = match pruner.pruned_height {
856            Some(h) => Some(h),
857            None => {
858                let Some(height) = self.get_minimum_height().await? else {
859                    tracing::info!("database is empty, nothing to prune");
860                    return Ok(None);
861                };
862
863                if height > 0 { Some(height - 1) } else { None }
864            },
865        };
866
867        // Prune data exceeding target retention in batches
868        if pruner.target_height.is_none() {
869            let th = self
870                .get_height_by_timestamp(
871                    Utc::now().timestamp() - (cfg.target_retention().as_secs()) as i64,
872                )
873                .await?;
874            target_height = th;
875            pruner.target_height = target_height;
876        };
877
878        if let Some(th) = target_height
879            && pruned_height < Some(th)
880        {
881            let batch_end = match pruned_height {
882                None => batch_size - 1,
883                Some(h) => h + batch_size,
884            };
885            let to = min(batch_end, th);
886
887            // Update pruned height first so the fetcher does not
888            // try to fetch data that we are about to delete.
889            let mut tx = self.write().await?;
890            tx.save_pruned_height(to).await?;
891            tx.commit().await.map_err(|e| QueryError::Error {
892                message: format!("failed to commit save_pruned_height {e}"),
893            })?;
894
895            let mut tx = self.prune_write().await?;
896            tx.delete_batch(to).await?;
897            tx.commit().await.map_err(|e| QueryError::Error {
898                message: format!("failed to commit delete_batch {e}"),
899            })?;
900
901            // Prune state tables in a separate transaction.
902            let mut tx = self.prune_write().await?;
903            tx.delete_state_batch(state_tables, to).await?;
904            tx.commit().await.map_err(|e| QueryError::Error {
905                message: format!("failed to commit {e}"),
906            })?;
907
908            pruner.pruned_height = Some(to);
909            return Ok(Some(to));
910        }
911
912        // If threshold is set, prune data exceeding minimum retention in batches
913        // This parameter is needed for SQL storage as there is no direct way to get free space.
914        if let Some(threshold) = cfg.pruning_threshold() {
915            let usage = self.get_disk_usage().await?;
916
917            // Prune data exceeding minimum retention in batches starting from minimum height
918            // until usage is below threshold
919            if usage > threshold {
920                tracing::warn!(
921                    "Disk usage {usage} exceeds pruning threshold {:?}",
922                    cfg.pruning_threshold()
923                );
924
925                if minimum_retention_height.is_none() {
926                    minimum_retention_height = self
927                        .get_height_by_timestamp(
928                            Utc::now().timestamp() - (cfg.minimum_retention().as_secs()) as i64,
929                        )
930                        .await?;
931
932                    pruner.minimum_retention_height = minimum_retention_height;
933                }
934
935                if let Some(min_retention_height) = minimum_retention_height
936                    && (usage as f64 / threshold as f64) > (f64::from(max_usage) / 10000.0)
937                    && pruned_height < Some(min_retention_height)
938                {
939                    let batch_end = match pruned_height {
940                        None => batch_size - 1,
941                        Some(h) => h + batch_size,
942                    };
943                    let to = min(batch_end, min_retention_height);
944                    // Update pruned height first so the fetcher does not
945                    // try to fetch data that we are about to delete.
946                    let mut tx = self.write().await?;
947                    tx.save_pruned_height(to).await?;
948                    tx.commit().await.map_err(|e| QueryError::Error {
949                        message: format!("failed to commit save_pruned_height {e}"),
950                    })?;
951
952                    let mut tx = self.prune_write().await?;
953                    tx.delete_batch(to).await?;
954                    tx.commit().await.map_err(|e| QueryError::Error {
955                        message: format!("failed to commit delete_batch {e}"),
956                    })?;
957
958                    // Prune state tables in a separate transaction.
959                    let mut tx = self.prune_write().await?;
960                    tx.delete_state_batch(state_tables, to).await?;
961                    tx.commit().await.map_err(|e| QueryError::Error {
962                        message: format!("failed to commit {e}"),
963                    })?;
964
965                    self.vacuum().await?;
966                    pruner.pruned_height = Some(to);
967                    return Ok(Some(to));
968                }
969            }
970        }
971
972        Ok(None)
973    }
974}
975
976impl VersionedDataSource for SqlStorage {
977    type Transaction<'a>
978        = Transaction<Write>
979    where
980        Self: 'a;
981    type ReadOnly<'a>
982        = Transaction<Read>
983    where
984        Self: 'a;
985
986    async fn write(&self) -> anyhow::Result<Transaction<Write>> {
987        Transaction::new(&self.pool, self.pool_metrics.clone()).await
988    }
989
990    async fn read(&self) -> anyhow::Result<Transaction<Read>> {
991        Transaction::new(&self.pool, self.pool_metrics.clone()).await
992    }
993}
994
995// These tests run the `postgres` Docker image, which doesn't work on Windows.
996#[cfg(all(any(test, feature = "testing"), not(target_os = "windows")))]
997pub mod testing {
998    #![allow(unused_imports)]
999    use std::{
1000        env,
1001        process::{Command, Stdio},
1002        str::{self, FromStr},
1003        time::Duration,
1004    };
1005
1006    use refinery::Migration;
1007    use test_utils::reserve_tcp_port;
1008    use tokio::{net::TcpStream, time::timeout};
1009
1010    use super::Config;
1011    use crate::testing::sleep;
1012    #[derive(Debug)]
1013    pub struct TmpDb {
1014        #[cfg(not(feature = "embedded-db"))]
1015        host: String,
1016        #[cfg(not(feature = "embedded-db"))]
1017        port: u16,
1018        #[cfg(not(feature = "embedded-db"))]
1019        container_id: String,
1020        #[cfg(feature = "embedded-db")]
1021        db_path: std::path::PathBuf,
1022        #[allow(dead_code)]
1023        persistent: bool,
1024    }
1025    impl TmpDb {
1026        #[cfg(feature = "embedded-db")]
1027        fn init_sqlite_db(persistent: bool) -> Self {
1028            let file = tempfile::Builder::new()
1029                .prefix("sqlite-")
1030                .suffix(".db")
1031                .tempfile()
1032                .unwrap();
1033
1034            let (_, db_path) = file.keep().unwrap();
1035
1036            Self {
1037                db_path,
1038                persistent,
1039            }
1040        }
1041        pub async fn init() -> Self {
1042            #[cfg(feature = "embedded-db")]
1043            return Self::init_sqlite_db(false);
1044
1045            #[cfg(not(feature = "embedded-db"))]
1046            Self::init_postgres(false).await
1047        }
1048
1049        pub async fn persistent() -> Self {
1050            #[cfg(feature = "embedded-db")]
1051            return Self::init_sqlite_db(true);
1052
1053            #[cfg(not(feature = "embedded-db"))]
1054            Self::init_postgres(true).await
1055        }
1056
1057        #[cfg(not(feature = "embedded-db"))]
1058        async fn init_postgres(persistent: bool) -> Self {
1059            let docker_hostname = env::var("DOCKER_HOSTNAME");
1060            // This picks an unused port on the current system.  If docker is
1061            // configured to run on a different host then this may not find a
1062            // "free" port on that system.
1063            // We *might* be able to get away with this as any remote docker
1064            // host should hopefully be pretty open with it's port space.
1065            let port = reserve_tcp_port().unwrap();
1066            let host = docker_hostname.unwrap_or("localhost".to_string());
1067
1068            let mut cmd = Command::new("docker");
1069            cmd.arg("run")
1070                .arg("-d")
1071                .args(["-p", &format!("{port}:5432")])
1072                .args(["-e", "POSTGRES_PASSWORD=password"]);
1073
1074            if !persistent {
1075                cmd.arg("--rm");
1076            }
1077
1078            let output = cmd.arg("postgres").output().unwrap();
1079            let stdout = str::from_utf8(&output.stdout).unwrap();
1080            let stderr = str::from_utf8(&output.stderr).unwrap();
1081            if !output.status.success() {
1082                panic!("failed to start postgres docker: {stderr}");
1083            }
1084
1085            // Create the TmpDb object immediately after starting the Docker container, so if
1086            // anything panics after this `drop` will be called and we will clean up.
1087            let container_id = stdout.trim().to_owned();
1088            tracing::info!("launched postgres docker {container_id}");
1089            let db = Self {
1090                host,
1091                port,
1092                container_id: container_id.clone(),
1093                persistent,
1094            };
1095
1096            db.wait_for_ready().await;
1097            db
1098        }
1099
1100        #[cfg(not(feature = "embedded-db"))]
1101        pub fn host(&self) -> String {
1102            self.host.clone()
1103        }
1104
1105        #[cfg(not(feature = "embedded-db"))]
1106        pub fn port(&self) -> u16 {
1107            self.port
1108        }
1109
1110        #[cfg(feature = "embedded-db")]
1111        pub fn path(&self) -> std::path::PathBuf {
1112            self.db_path.clone()
1113        }
1114
1115        pub fn config(&self) -> Config {
1116            #[cfg(feature = "embedded-db")]
1117            let mut cfg = Config::default().db_path(self.db_path.clone());
1118
1119            #[cfg(not(feature = "embedded-db"))]
1120            let mut cfg = Config::default()
1121                .user("postgres")
1122                .password("password")
1123                .host(self.host())
1124                .port(self.port());
1125
1126            cfg = cfg.migrations(vec![
1127                Migration::unapplied(
1128                    "V101__create_test_merkle_tree_table.sql",
1129                    &TestMerkleTreeMigration::create("test_tree"),
1130                )
1131                .unwrap(),
1132            ]);
1133
1134            cfg
1135        }
1136
1137        #[cfg(not(feature = "embedded-db"))]
1138        pub fn stop_postgres(&mut self) {
1139            tracing::info!(container = self.container_id, "stopping postgres");
1140            let output = Command::new("docker")
1141                .args(["stop", self.container_id.as_str()])
1142                .output()
1143                .unwrap();
1144            assert!(
1145                output.status.success(),
1146                "error killing postgres docker {}: {}",
1147                self.container_id,
1148                str::from_utf8(&output.stderr).unwrap()
1149            );
1150        }
1151
1152        #[cfg(not(feature = "embedded-db"))]
1153        pub async fn start_postgres(&mut self) {
1154            tracing::info!(container = self.container_id, "resuming postgres");
1155            let output = Command::new("docker")
1156                .args(["start", self.container_id.as_str()])
1157                .output()
1158                .unwrap();
1159            assert!(
1160                output.status.success(),
1161                "error starting postgres docker {}: {}",
1162                self.container_id,
1163                str::from_utf8(&output.stderr).unwrap()
1164            );
1165
1166            self.wait_for_ready().await;
1167        }
1168
1169        #[cfg(not(feature = "embedded-db"))]
1170        async fn wait_for_ready(&self) {
1171            let timeout_duration = Duration::from_secs(
1172                env::var("SQL_TMP_DB_CONNECT_TIMEOUT")
1173                    .unwrap_or("60".to_string())
1174                    .parse()
1175                    .expect("SQL_TMP_DB_CONNECT_TIMEOUT must be an integer number of seconds"),
1176            );
1177
1178            if let Err(err) = timeout(timeout_duration, async {
1179                while Command::new("docker")
1180                    .args([
1181                        "exec",
1182                        &self.container_id,
1183                        "pg_isready",
1184                        "-h",
1185                        "localhost",
1186                        "-U",
1187                        "postgres",
1188                    ])
1189                    .env("PGPASSWORD", "password")
1190                    // Null input so the command terminates as soon as it manages to connect.
1191                    .stdin(Stdio::null())
1192                    // Discard command output.
1193                    .stdout(Stdio::null())
1194                    .stderr(Stdio::null())
1195                    .status()
1196                    // We should ensure the exit status. A simple `unwrap`
1197                    // would panic on unrelated errors (such as network
1198                    // connection failures)
1199                    .and_then(|status| {
1200                        status
1201                            .success()
1202                            .then_some(true)
1203                            // Any ol' Error will do
1204                            .ok_or(std::io::Error::from_raw_os_error(666))
1205                    })
1206                    .is_err()
1207                {
1208                    tracing::warn!("database is not ready");
1209                    sleep(Duration::from_secs(1)).await;
1210                }
1211
1212                // The above command ensures the database is ready inside the Docker container.
1213                // However, on some systems, there is a slight delay before the port is exposed via
1214                // host networking. We don't need to check again that the database is ready on the
1215                // host (and maybe can't, because the host might not have pg_isready installed), but
1216                // we can ensure the port is open by just establishing a TCP connection.
1217                while let Err(err) =
1218                    TcpStream::connect(format!("{}:{}", self.host, self.port)).await
1219                {
1220                    tracing::warn!("database is ready, but port is not available to host: {err:#}");
1221                    sleep(Duration::from_millis(100)).await;
1222                }
1223            })
1224            .await
1225            {
1226                panic!(
1227                    "failed to connect to TmpDb within configured timeout {timeout_duration:?}: \
1228                     {err:#}\n{}",
1229                    "Consider increasing the timeout by setting SQL_TMP_DB_CONNECT_TIMEOUT"
1230                );
1231            }
1232        }
1233    }
1234
1235    #[cfg(not(feature = "embedded-db"))]
1236    impl Drop for TmpDb {
1237        fn drop(&mut self) {
1238            self.stop_postgres();
1239        }
1240    }
1241
1242    #[cfg(feature = "embedded-db")]
1243    impl Drop for TmpDb {
1244        fn drop(&mut self) {
1245            if !self.persistent {
1246                std::fs::remove_file(self.db_path.clone()).unwrap();
1247            }
1248        }
1249    }
1250
1251    pub struct TestMerkleTreeMigration;
1252
1253    impl TestMerkleTreeMigration {
1254        fn create(name: &str) -> String {
1255            let (bit_vec, binary, hash_pk, root_stored_column) = if cfg!(feature = "embedded-db") {
1256                (
1257                    "TEXT",
1258                    "BLOB",
1259                    "INTEGER PRIMARY KEY AUTOINCREMENT",
1260                    " (json_extract(data, '$.test_merkle_tree_root'))",
1261                )
1262            } else {
1263                (
1264                    "BIT(8)",
1265                    "BYTEA",
1266                    "SERIAL PRIMARY KEY",
1267                    "(data->>'test_merkle_tree_root')",
1268                )
1269            };
1270
1271            format!(
1272                "CREATE TABLE IF NOT EXISTS hash
1273            (
1274                id {hash_pk},
1275                value {binary}  NOT NULL UNIQUE
1276            );
1277
1278            ALTER TABLE header
1279            ADD column test_merkle_tree_root text
1280            GENERATED ALWAYS as {root_stored_column} STORED;
1281
1282            CREATE TABLE {name}
1283            (
1284                path JSONB NOT NULL,
1285                created BIGINT NOT NULL,
1286                hash_id INT NOT NULL,
1287                children JSONB,
1288                children_bitvec {bit_vec},
1289                idx JSONB,
1290                entry JSONB,
1291                PRIMARY KEY (path, created)
1292            );
1293            CREATE INDEX {name}_created ON {name} (created);"
1294            )
1295        }
1296    }
1297}
1298
1299// These tests run the `postgres` Docker image, which doesn't work on Windows.
1300#[cfg(all(test, not(target_os = "windows")))]
1301mod test {
1302    use std::time::Duration;
1303
1304    use hotshot_example_types::{
1305        node_types::TEST_VERSIONS,
1306        state_types::{TestInstanceState, TestValidatedState},
1307    };
1308    use jf_merkle_tree_compat::{
1309        MerkleTreeScheme, ToTraversalPath, UniversalMerkleTreeScheme, prelude::UniversalMerkleTree,
1310    };
1311    use tokio::time::sleep;
1312
1313    use super::{testing::TmpDb, *};
1314    use crate::{
1315        availability::{BlockQueryData, LeafQueryData},
1316        data_source::storage::{
1317            MerklizedStateStorage, UpdateAvailabilityStorage, pruning::PrunedHeightStorage,
1318        },
1319        merklized_state::{MerklizedState, Snapshot, UpdateStateData},
1320        testing::mocks::{MockMerkleTree, MockTypes},
1321    };
1322
1323    #[test_log::test(tokio::test(flavor = "multi_thread"))]
1324    async fn test_migrations() {
1325        let db = TmpDb::init().await;
1326        let cfg = db.config();
1327
1328        let connect = |migrations: bool, custom_migrations| {
1329            let cfg = cfg.clone();
1330            async move {
1331                let mut cfg = cfg.migrations(custom_migrations);
1332                if !migrations {
1333                    cfg = cfg.no_migrations();
1334                }
1335                let client = SqlStorage::connect(cfg, StorageConnectionType::Query).await?;
1336                Ok::<_, Error>(client)
1337            }
1338        };
1339
1340        // Connecting with migrations disabled should fail if the database is not already up to date
1341        // (since we've just created a fresh database, it isn't).
1342        let err = connect(false, vec![]).await.unwrap_err();
1343        tracing::info!("connecting without running migrations failed as expected: {err}");
1344
1345        // Now connect and run migrations to bring the database up to date.
1346        connect(true, vec![]).await.unwrap();
1347        // Now connecting without migrations should work.
1348        connect(false, vec![]).await.unwrap();
1349
1350        // Connect with some custom migrations, to advance the schema even further. Pass in the
1351        // custom migrations out of order; they should still execute in order of version number.
1352        // The SQL commands used here will fail if not run in order.
1353        let migrations = vec![
1354            Migration::unapplied(
1355                "V9999__create_test_table.sql",
1356                "ALTER TABLE test ADD COLUMN data INTEGER;",
1357            )
1358            .unwrap(),
1359            Migration::unapplied(
1360                "V9998__create_test_table.sql",
1361                "CREATE TABLE test (x bigint);",
1362            )
1363            .unwrap(),
1364        ];
1365        connect(true, migrations.clone()).await.unwrap();
1366
1367        // Connect using the default schema (no custom migrations) and not running migrations. This
1368        // should fail because the database is _ahead_ of the client in terms of schema.
1369        let err = connect(false, vec![]).await.unwrap_err();
1370        tracing::info!("connecting without running migrations failed as expected: {err}");
1371
1372        // Connecting with the customized schema should work even without running migrations.
1373        connect(true, migrations).await.unwrap();
1374    }
1375
1376    #[test]
1377    #[cfg(not(feature = "embedded-db"))]
1378    fn test_config_from_str() {
1379        let cfg = Config::from_str("postgresql://user:password@host:8080").unwrap();
1380        assert_eq!(cfg.db_opt.get_username(), "user");
1381        assert_eq!(cfg.db_opt.get_host(), "host");
1382        assert_eq!(cfg.db_opt.get_port(), 8080);
1383    }
1384
1385    #[test]
1386    #[cfg(feature = "embedded-db")]
1387    fn test_config_from_str() {
1388        let cfg = Config::from_str("sqlite://data.db").unwrap();
1389        assert_eq!(cfg.db_opt.get_filename().to_string_lossy(), "data.db");
1390    }
1391
1392    async fn vacuum(storage: &SqlStorage) {
1393        #[cfg(feature = "embedded-db")]
1394        let query = "PRAGMA incremental_vacuum(16000)";
1395        #[cfg(not(feature = "embedded-db"))]
1396        let query = "VACUUM";
1397        storage
1398            .pool
1399            .acquire()
1400            .await
1401            .unwrap()
1402            .execute(query)
1403            .await
1404            .unwrap();
1405    }
1406
1407    #[test_log::test(tokio::test(flavor = "multi_thread"))]
1408    async fn test_target_period_pruning() {
1409        let db = TmpDb::init().await;
1410        let cfg = db.config();
1411
1412        let mut storage = SqlStorage::connect(cfg, StorageConnectionType::Query)
1413            .await
1414            .unwrap();
1415        let mut leaf = LeafQueryData::<MockTypes>::genesis(
1416            &TestValidatedState::default(),
1417            &TestInstanceState::default(),
1418            TEST_VERSIONS.test,
1419        )
1420        .await;
1421        // insert some mock data
1422        for i in 0..20 {
1423            leaf.leaf.block_header_mut().block_number = i;
1424            leaf.leaf.block_header_mut().timestamp = Utc::now().timestamp() as u64;
1425            let mut tx = storage.write().await.unwrap();
1426            tx.insert_leaf(&leaf).await.unwrap();
1427            tx.commit().await.unwrap();
1428        }
1429
1430        let height_before_pruning = storage.get_minimum_height().await.unwrap().unwrap();
1431
1432        // Set pruner config to default which has minimum retention set to 1 day
1433        storage.set_pruning_config(PrunerCfg::new());
1434        // No data will be pruned
1435        let pruned_height = storage.prune(&mut Default::default()).await.unwrap();
1436
1437        // Vacuum the database to reclaim space.
1438        // This is necessary to ensure the test passes.
1439        // Note: We don't perform a vacuum after each pruner run in production because the auto vacuum job handles it automatically.
1440        vacuum(&storage).await;
1441        // Pruned height should be none
1442        assert!(pruned_height.is_none());
1443
1444        let height_after_pruning = storage.get_minimum_height().await.unwrap().unwrap();
1445
1446        assert_eq!(
1447            height_after_pruning, height_before_pruning,
1448            "some data has been pruned"
1449        );
1450
1451        // Set pruner config to target retention set to 1s
1452        storage.set_pruning_config(PrunerCfg::new().with_target_retention(Duration::from_secs(1)));
1453        sleep(Duration::from_secs(2)).await;
1454        let usage_before_pruning = storage.get_disk_usage().await.unwrap();
1455        // All of the data is now older than 1s.
1456        // This would prune all the data as the target retention is set to 1s
1457        let pruned_height = storage.prune(&mut Default::default()).await.unwrap();
1458        // Vacuum the database to reclaim space.
1459        // This is necessary to ensure the test passes.
1460        // Note: We don't perform a vacuum after each pruner run in production because the auto vacuum job handles it automatically.
1461        vacuum(&storage).await;
1462
1463        // Pruned height should be some
1464        assert!(pruned_height.is_some());
1465        let usage_after_pruning = storage.get_disk_usage().await.unwrap();
1466        // All the tables should be empty
1467        // counting rows in header table
1468        let header_rows = storage
1469            .read()
1470            .await
1471            .unwrap()
1472            .fetch_one("select count(*) as count from header")
1473            .await
1474            .unwrap()
1475            .get::<i64, _>("count");
1476        // the table should be empty
1477        assert_eq!(header_rows, 0);
1478
1479        // counting rows in leaf table.
1480        // Deleting rows from header table would delete rows in all the tables
1481        // as each of table implement "ON DELETE CASCADE" fk constraint with the header table.
1482        let leaf_rows = storage
1483            .read()
1484            .await
1485            .unwrap()
1486            .fetch_one("select count(*) as count from leaf")
1487            .await
1488            .unwrap()
1489            .get::<i64, _>("count");
1490        // the table should be empty
1491        assert_eq!(leaf_rows, 0);
1492
1493        assert!(
1494            usage_before_pruning > usage_after_pruning,
1495            " disk usage should decrease after pruning"
1496        )
1497    }
1498
1499    #[test_log::test(tokio::test(flavor = "multi_thread"))]
1500    async fn test_merklized_state_pruning() {
1501        let db = TmpDb::init().await;
1502        let storage = SqlStorage::connect(db.config(), StorageConnectionType::Query)
1503            .await
1504            .unwrap();
1505
1506        let num_blocks = 10_000u64;
1507        let mut test_tree: UniversalMerkleTree<_, _, _, 8, _> =
1508            MockMerkleTree::new(MockMerkleTree::tree_height());
1509
1510        // Insert entries and merkle nodes for each block height.
1511        let mut tx = storage.write().await.unwrap();
1512        for height in 0..num_blocks {
1513            test_tree.update(height as usize, height as usize).unwrap();
1514
1515            let test_data = serde_json::json!({
1516                MockMerkleTree::header_state_commitment_field():
1517                    serde_json::to_value(test_tree.commitment()).unwrap()
1518            });
1519            tx.upsert(
1520                "header",
1521                [
1522                    "height",
1523                    "hash",
1524                    "payload_hash",
1525                    "timestamp",
1526                    "data",
1527                    "ns_table",
1528                ],
1529                ["height"],
1530                [(
1531                    height as i64,
1532                    format!("hash{height}"),
1533                    "ph".to_string(),
1534                    0,
1535                    test_data,
1536                    "ns".to_string(),
1537                )],
1538            )
1539            .await
1540            .unwrap();
1541
1542            let (_, proof) = test_tree.lookup(height as usize).expect_ok().unwrap();
1543            let traversal_path = <usize as ToTraversalPath<8>>::to_traversal_path(
1544                &(height as usize),
1545                test_tree.height(),
1546            );
1547            UpdateStateData::<_, MockMerkleTree, 8>::insert_merkle_nodes(
1548                &mut tx,
1549                proof.clone(),
1550                traversal_path,
1551                height,
1552            )
1553            .await
1554            .unwrap();
1555        }
1556        UpdateStateData::<_, MockMerkleTree, 8>::set_last_state_height(
1557            &mut tx,
1558            num_blocks as usize,
1559        )
1560        .await
1561        .unwrap();
1562        tx.commit().await.unwrap();
1563
1564        // Prune up to height 500, keeping only the newest version of each node.
1565        let prune_height = 5678u64;
1566        let mut tx = storage.prune_write().await.unwrap();
1567        tx.delete_state_batch(vec!["test_tree".to_string()], prune_height)
1568            .await
1569            .unwrap();
1570        tx.commit().await.unwrap();
1571
1572        // Verify no paths have multiple versions at or below the prune height.
1573        let mut tx = storage.read().await.unwrap();
1574        let (duplicates,) = query_as::<(i64,)>(
1575            "SELECT count(*) FROM (SELECT count(*) FROM test_tree WHERE created <= $1 GROUP BY \
1576             path HAVING count(*) > 1) AS s",
1577        )
1578        .bind(prune_height as i64)
1579        .fetch_one(tx.as_mut())
1580        .await
1581        .unwrap();
1582        assert_eq!(
1583            duplicates, 0,
1584            "found {duplicates} paths with duplicate versions at or below prune height"
1585        );
1586
1587        // Verify get_path still works for the latest snapshot and returns correct proofs.
1588        let commitment = test_tree.commitment();
1589        let mut tx = storage.read().await.unwrap();
1590        for key in 0..num_blocks as usize {
1591            let proof = MerklizedStateStorage::<MockTypes, MockMerkleTree, 8>::get_path(
1592                &mut tx,
1593                Snapshot::Index(num_blocks - 1),
1594                key,
1595            )
1596            .await
1597            .unwrap_or_else(|e| panic!("get_path failed for key {key} after pruning: {e:#}"));
1598            assert_eq!(
1599                proof.elem(),
1600                Some(&key),
1601                "proof for key {key} has wrong element: {:?}",
1602                proof.elem()
1603            );
1604            MockMerkleTree::verify(commitment, key, &proof)
1605                .unwrap()
1606                .unwrap();
1607        }
1608    }
1609
1610    #[test_log::test(tokio::test(flavor = "multi_thread"))]
1611    async fn test_minimum_retention_pruning() {
1612        let db = TmpDb::init().await;
1613
1614        let mut storage = SqlStorage::connect(db.config(), StorageConnectionType::Query)
1615            .await
1616            .unwrap();
1617        let mut leaf = LeafQueryData::<MockTypes>::genesis(
1618            &TestValidatedState::default(),
1619            &TestInstanceState::default(),
1620            TEST_VERSIONS.test,
1621        )
1622        .await;
1623        // insert some mock data
1624        for i in 0..20 {
1625            leaf.leaf.block_header_mut().block_number = i;
1626            leaf.leaf.block_header_mut().timestamp = Utc::now().timestamp() as u64;
1627            let mut tx = storage.write().await.unwrap();
1628            tx.insert_leaf(&leaf).await.unwrap();
1629            tx.commit().await.unwrap();
1630        }
1631
1632        let height_before_pruning = storage.get_minimum_height().await.unwrap().unwrap();
1633        let cfg = PrunerCfg::new();
1634        // Set pruning_threshold to 1
1635        // SQL storage size is more than 1000 bytes even without any data indexed
1636        // This would mean that the threshold would always be greater than the disk usage
1637        // However, minimum retention is set to 24 hours by default so the data would not be pruned
1638        storage.set_pruning_config(cfg.clone().with_pruning_threshold(1));
1639        println!("{:?}", storage.get_pruning_config().unwrap());
1640        // Pruning would not delete any data
1641        // All the data is younger than minimum retention period even though the usage > threshold
1642        let pruned_height = storage.prune(&mut Default::default()).await.unwrap();
1643        // Vacuum the database to reclaim space.
1644        // This is necessary to ensure the test passes.
1645        // Note: We don't perform a vacuum after each pruner run in production because the auto vacuum job handles it automatically.
1646        vacuum(&storage).await;
1647
1648        // Pruned height should be none
1649        assert!(pruned_height.is_none());
1650
1651        let height_after_pruning = storage.get_minimum_height().await.unwrap().unwrap();
1652
1653        assert_eq!(
1654            height_after_pruning, height_before_pruning,
1655            "some data has been pruned"
1656        );
1657
1658        // Change minimum retention to 1s
1659        storage.set_pruning_config(
1660            cfg.with_minimum_retention(Duration::from_secs(1))
1661                .with_pruning_threshold(1),
1662        );
1663        // sleep for 2s to make sure the data is older than minimum retention
1664        sleep(Duration::from_secs(2)).await;
1665        // This would prune all the data
1666        let pruned_height = storage.prune(&mut Default::default()).await.unwrap();
1667        // Vacuum the database to reclaim space.
1668        // This is necessary to ensure the test passes.
1669        // Note: We don't perform a vacuum after each pruner run in production because the auto vacuum job handles it automatically.
1670        vacuum(&storage).await;
1671
1672        // Pruned height should be some
1673        assert!(pruned_height.is_some());
1674        // All the tables should be empty
1675        // counting rows in header table
1676        let header_rows = storage
1677            .read()
1678            .await
1679            .unwrap()
1680            .fetch_one("select count(*) as count from header")
1681            .await
1682            .unwrap()
1683            .get::<i64, _>("count");
1684        // the table should be empty
1685        assert_eq!(header_rows, 0);
1686    }
1687
1688    #[tokio::test]
1689    #[test_log::test]
1690    async fn test_payload_pruning() {
1691        let db = TmpDb::init().await;
1692        let mut storage = SqlStorage::connect(db.config(), StorageConnectionType::Query)
1693            .await
1694            .unwrap();
1695        storage.set_pruning_config(Default::default());
1696
1697        // Insert some mock data.
1698        let mut leaf = LeafQueryData::<MockTypes>::genesis(
1699            &TestValidatedState::default(),
1700            &TestInstanceState::default(),
1701            TEST_VERSIONS.test,
1702        )
1703        .await;
1704        let block = BlockQueryData::<MockTypes>::genesis(
1705            &Default::default(),
1706            &Default::default(),
1707            TEST_VERSIONS.test.base,
1708        )
1709        .await;
1710        let vid = VidCommonQueryData::<MockTypes>::genesis(
1711            &Default::default(),
1712            &Default::default(),
1713            TEST_VERSIONS.test.base,
1714        )
1715        .await;
1716        {
1717            let mut tx = storage.write().await.unwrap();
1718            tx.insert_leaf(&leaf).await.unwrap();
1719            tx.insert_block(&block).await.unwrap();
1720            tx.insert_vid(&vid, None).await.unwrap();
1721            tx.commit().await.unwrap();
1722        }
1723
1724        // Insert a second leaf sharing the same payload.
1725        leaf.leaf.block_header_mut().block_number += 1;
1726        {
1727            let mut tx = storage.write().await.unwrap();
1728            tx.insert_leaf(&leaf).await.unwrap();
1729            tx.commit().await.unwrap();
1730        }
1731        {
1732            let mut tx = storage.read().await.unwrap();
1733            let (num_payloads,): (i64,) = query_as("SELECT count(*) FROM payload")
1734                .fetch_one(tx.as_mut())
1735                .await
1736                .unwrap();
1737            assert_eq!(num_payloads, 1);
1738            let (num_vid,): (i64,) = query_as("SELECT count(*) FROM vid_common")
1739                .fetch_one(tx.as_mut())
1740                .await
1741                .unwrap();
1742            assert_eq!(num_vid, 1);
1743        }
1744
1745        // Prune the first leaf but not the second (and thus not the payload or VID).
1746        let pruned_height = storage
1747            .prune(&mut Pruner {
1748                pruned_height: None,
1749                target_height: Some(0),
1750                minimum_retention_height: None,
1751            })
1752            .await
1753            .unwrap();
1754        tracing::info!(?pruned_height, "first pruning run complete");
1755        {
1756            let mut tx = storage.read().await.unwrap();
1757
1758            // First block is pruned.
1759            let err = tx
1760                .get_block(BlockId::<MockTypes>::Number(0))
1761                .await
1762                .unwrap_err();
1763            assert!(matches!(err, QueryError::NotFound), "{err:#}");
1764            let err = tx
1765                .get_vid_common(BlockId::<MockTypes>::Number(0))
1766                .await
1767                .unwrap_err();
1768            assert!(matches!(err, QueryError::NotFound), "{err:#}");
1769
1770            // Second block is still available.
1771            assert_eq!(
1772                tx.get_block(BlockId::<MockTypes>::Number(1)).await.unwrap(),
1773                BlockQueryData::new(leaf.header().clone(), block.payload)
1774            );
1775            assert_eq!(
1776                tx.get_vid_common(BlockId::<MockTypes>::Number(1))
1777                    .await
1778                    .unwrap(),
1779                VidCommonQueryData::new(leaf.header().clone(), vid.common)
1780            );
1781
1782            let (num_payloads,): (i64,) = query_as("SELECT count(*) FROM payload")
1783                .fetch_one(tx.as_mut())
1784                .await
1785                .unwrap();
1786            assert_eq!(num_payloads, 1);
1787
1788            let (num_vid,): (i64,) = query_as("SELECT count(*) FROM vid_common")
1789                .fetch_one(tx.as_mut())
1790                .await
1791                .unwrap();
1792            assert_eq!(num_vid, 1);
1793        }
1794
1795        // Now prune the second leaf, ensuring the payload and VID get deleted as well.
1796        let pruned_height = storage
1797            .prune(&mut Pruner {
1798                pruned_height,
1799                target_height: Some(1),
1800                minimum_retention_height: None,
1801            })
1802            .await
1803            .unwrap();
1804        tracing::info!(?pruned_height, "second pruning run complete");
1805
1806        let mut tx = storage.read().await.unwrap();
1807        for i in 0..2 {
1808            let err = tx
1809                .get_block(BlockId::<MockTypes>::Number(i))
1810                .await
1811                .unwrap_err();
1812            assert!(matches!(err, QueryError::NotFound), "{err:#}");
1813
1814            let err = tx
1815                .get_vid_common(BlockId::<MockTypes>::Number(i))
1816                .await
1817                .unwrap_err();
1818            assert!(matches!(err, QueryError::NotFound), "{err:#}");
1819        }
1820        let (num_payloads,): (i64,) = query_as("SELECT count(*) FROM payload")
1821            .fetch_one(tx.as_mut())
1822            .await
1823            .unwrap();
1824        assert_eq!(num_payloads, 0);
1825
1826        let (num_vid,): (i64,) = query_as("SELECT count(*) FROM vid_common")
1827            .fetch_one(tx.as_mut())
1828            .await
1829            .unwrap();
1830        assert_eq!(num_vid, 0);
1831    }
1832
1833    #[test_log::test(tokio::test(flavor = "multi_thread"))]
1834    async fn test_pruned_height_storage() {
1835        let db = TmpDb::init().await;
1836        let cfg = db.config();
1837
1838        let storage = SqlStorage::connect(cfg, StorageConnectionType::Query)
1839            .await
1840            .unwrap();
1841        assert!(
1842            storage
1843                .read()
1844                .await
1845                .unwrap()
1846                .load_pruned_height()
1847                .await
1848                .unwrap()
1849                .is_none()
1850        );
1851        for height in [10, 20, 30] {
1852            let mut tx = storage.write().await.unwrap();
1853            tx.save_pruned_height(height).await.unwrap();
1854            tx.commit().await.unwrap();
1855            assert_eq!(
1856                storage
1857                    .read()
1858                    .await
1859                    .unwrap()
1860                    .load_pruned_height()
1861                    .await
1862                    .unwrap(),
1863                Some(height)
1864            );
1865        }
1866    }
1867
1868    #[test_log::test(tokio::test(flavor = "multi_thread"))]
1869    async fn test_transaction_upsert_retries() {
1870        let db = TmpDb::init().await;
1871        let config = db.config();
1872
1873        let storage = SqlStorage::connect(config, StorageConnectionType::Query)
1874            .await
1875            .unwrap();
1876
1877        let mut tx = storage.write().await.unwrap();
1878
1879        // Try to upsert into a table that does not exist.
1880        // This will fail, so our `upsert` function will enter the retry loop.
1881        // Since the table does not exist, all retries will eventually
1882        // fail and we expect an error to be returned.
1883        //
1884        // Previously, this case would cause  a panic because we were calling
1885        // methods on `QueryBuilder` after `.build()` without first
1886        // calling `.reset()`and according to the sqlx docs, that always panics.
1887        // Now, since we are properly calling `.reset()` inside `upsert()` for
1888        // the query builder, the function returns an error instead of panicking.
1889        tx.upsert("does_not_exist", ["test"], ["test"], [(1_i64,)])
1890            .await
1891            .unwrap_err();
1892    }
1893}