hotshot_query_service/data_source/storage/
fail_storage.rs

1// Copyright (c) 2022 Espresso Systems (espressosys.com)
2// This file is part of the HotShot Query Service library.
3//
4// This program is free software: you can redistribute it and/or modify it under the terms of the GNU
5// General Public License as published by the Free Software Foundation, either version 3 of the
6// License, or (at your option) any later version.
7// This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without
8// even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
9// General Public License for more details.
10// You should have received a copy of the GNU General Public License along with this program. If not,
11// see <https://www.gnu.org/licenses/>.
12
13#![cfg(any(test, feature = "testing"))]
14
15use std::{ops::RangeBounds, sync::Arc};
16
17use async_lock::Mutex;
18use async_trait::async_trait;
19use futures::future::Future;
20use hotshot_types::{
21    data::VidShare, simple_certificate::CertificatePair, traits::node_implementation::NodeType,
22};
23
24use super::{
25    Aggregate, AggregatesStorage, AvailabilityStorage, NodeStorage, UpdateAggregatesStorage,
26    UpdateAvailabilityStorage,
27    pruning::{PruneStorage, PrunedHeightStorage, PrunerCfg, PrunerConfig},
28};
29use crate::{
30    Header, Payload, QueryError, QueryResult,
31    availability::{
32        BlockId, BlockQueryData, LeafId, LeafQueryData, NamespaceId, PayloadQueryData,
33        QueryableHeader, QueryablePayload, TransactionHash, VidCommonQueryData,
34    },
35    data_source::{
36        VersionedDataSource,
37        storage::{PayloadMetadata, VidCommonMetadata},
38        update,
39    },
40    metrics::PrometheusMetrics,
41    node::{SyncStatusQueryData, TimeWindowQueryData, WindowStart},
42    status::HasMetrics,
43};
44
/// Identifies a storage operation that an injected failure can be aimed at.
///
/// A specific variant only affects calls checked under that same action, whereas
/// [`FailableAction::Any`] affects every checked call.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum FailableAction {
    // TODO currently we implement failable actions for the availability methods, but if needed we
    // can always add more variants for other actions.

    // Single-object lookups.
    GetHeader,
    GetLeaf,
    GetBlock,
    GetPayload,
    GetPayloadMetadata,
    GetVidCommon,
    GetVidCommonMetadata,

    // Range lookups.
    GetHeaderRange,
    GetLeafRange,
    GetBlockRange,
    GetPayloadRange,
    GetPayloadMetadataRange,
    GetVidCommonRange,
    GetVidCommonMetadataRange,

    // Other lookups.
    GetTransaction,
    FirstAvailableLeaf,
    GetStateCert,

    /// Target any action for failure.
    Any,
}

impl FailableAction {
    /// Decide whether targeting `self` for failure means that `action` must fail.
    ///
    /// `self` is the configured target and `action` is the operation being attempted. The check
    /// is asymmetric: a target of `Any` matches everything, but a specific target does not match
    /// an attempted action of `Any`.
    fn matches(self, action: Self) -> bool {
        match self {
            // A wildcard target fails whatever is being attempted.
            Self::Any => true,
            // Otherwise only the exact same action is affected.
            target => target == action,
        }
    }
}
80
81#[derive(Clone, Copy, Debug, Default)]
82enum FailureMode {
83    #[default]
84    Never,
85    Once(FailableAction),
86    Always(FailableAction),
87}
88
89impl FailureMode {
90    fn maybe_fail(&mut self, action: FailableAction) -> QueryResult<()> {
91        match self {
92            Self::Once(fail_action) if fail_action.matches(action) => {
93                *self = Self::Never;
94            },
95            Self::Always(fail_action) if fail_action.matches(action) => {},
96            _ => return Ok(()),
97        }
98
99        Err(QueryError::Error {
100            message: "injected error".into(),
101        })
102    }
103}
104
105#[derive(Debug, Default)]
106struct Failure {
107    on_read: FailureMode,
108    on_write: FailureMode,
109    on_commit: FailureMode,
110    on_begin_writable: FailureMode,
111    on_begin_read_only: FailureMode,
112}
113
114/// Storage wrapper for error injection.
115#[derive(Clone, Debug)]
116pub struct FailStorage<S> {
117    inner: S,
118    failure: Arc<Mutex<Failure>>,
119}
120
121impl<S> From<S> for FailStorage<S> {
122    fn from(inner: S) -> Self {
123        Self {
124            inner,
125            failure: Default::default(),
126        }
127    }
128}
129
130impl<S> FailStorage<S> {
131    pub async fn fail_reads(&self, action: FailableAction) {
132        self.failure.lock().await.on_read = FailureMode::Always(action);
133    }
134
135    pub async fn fail_writes(&self, action: FailableAction) {
136        self.failure.lock().await.on_write = FailureMode::Always(action);
137    }
138
139    pub async fn fail_commits(&self, action: FailableAction) {
140        self.failure.lock().await.on_commit = FailureMode::Always(action);
141    }
142
143    pub async fn fail_begins_writable(&self, action: FailableAction) {
144        self.failure.lock().await.on_begin_writable = FailureMode::Always(action);
145    }
146
147    pub async fn fail_begins_read_only(&self, action: FailableAction) {
148        self.failure.lock().await.on_begin_read_only = FailureMode::Always(action);
149    }
150
151    pub async fn fail(&self, action: FailableAction) {
152        let mut failure = self.failure.lock().await;
153        failure.on_read = FailureMode::Always(action);
154        failure.on_write = FailureMode::Always(action);
155        failure.on_commit = FailureMode::Always(action);
156        failure.on_begin_writable = FailureMode::Always(action);
157        failure.on_begin_read_only = FailureMode::Always(action);
158    }
159
160    pub async fn pass_reads(&self) {
161        self.failure.lock().await.on_read = FailureMode::Never;
162    }
163
164    pub async fn pass_writes(&self) {
165        self.failure.lock().await.on_write = FailureMode::Never;
166    }
167
168    pub async fn pass_commits(&self) {
169        self.failure.lock().await.on_commit = FailureMode::Never;
170    }
171
172    pub async fn pass_begins_writable(&self) {
173        self.failure.lock().await.on_begin_writable = FailureMode::Never;
174    }
175
176    pub async fn pass_begins_read_only(&self) {
177        self.failure.lock().await.on_begin_read_only = FailureMode::Never;
178    }
179
180    pub async fn pass(&self) {
181        let mut failure = self.failure.lock().await;
182        failure.on_read = FailureMode::Never;
183        failure.on_write = FailureMode::Never;
184        failure.on_commit = FailureMode::Never;
185        failure.on_begin_writable = FailureMode::Never;
186        failure.on_begin_read_only = FailureMode::Never;
187    }
188
189    pub async fn fail_one_read(&self, action: FailableAction) {
190        self.failure.lock().await.on_read = FailureMode::Once(action);
191    }
192
193    pub async fn fail_one_write(&self, action: FailableAction) {
194        self.failure.lock().await.on_write = FailureMode::Once(action);
195    }
196
197    pub async fn fail_one_commit(&self, action: FailableAction) {
198        self.failure.lock().await.on_commit = FailureMode::Once(action);
199    }
200
201    pub async fn fail_one_begin_writable(&self, action: FailableAction) {
202        self.failure.lock().await.on_begin_writable = FailureMode::Once(action);
203    }
204
205    pub async fn fail_one_begin_read_only(&self, action: FailableAction) {
206        self.failure.lock().await.on_begin_read_only = FailureMode::Once(action);
207    }
208}
209
210impl<S> VersionedDataSource for FailStorage<S>
211where
212    S: VersionedDataSource,
213{
214    type Transaction<'a>
215        = Transaction<S::Transaction<'a>>
216    where
217        Self: 'a;
218    type ReadOnly<'a>
219        = Transaction<S::ReadOnly<'a>>
220    where
221        Self: 'a;
222
223    async fn write(&self) -> anyhow::Result<<Self as VersionedDataSource>::Transaction<'_>> {
224        self.failure
225            .lock()
226            .await
227            .on_begin_writable
228            .maybe_fail(FailableAction::Any)?;
229        Ok(Transaction {
230            inner: self.inner.write().await?,
231            failure: self.failure.clone(),
232        })
233    }
234
235    async fn read(&self) -> anyhow::Result<<Self as VersionedDataSource>::ReadOnly<'_>> {
236        self.failure
237            .lock()
238            .await
239            .on_begin_read_only
240            .maybe_fail(FailableAction::Any)?;
241        Ok(Transaction {
242            inner: self.inner.read().await?,
243            failure: self.failure.clone(),
244        })
245    }
246}
247
248impl<S> PrunerConfig for FailStorage<S>
249where
250    S: PrunerConfig,
251{
252    fn set_pruning_config(&mut self, cfg: PrunerCfg) {
253        self.inner.set_pruning_config(cfg);
254    }
255
256    fn get_pruning_config(&self) -> Option<PrunerCfg> {
257        self.inner.get_pruning_config()
258    }
259}
260
261#[async_trait]
262impl<S> PruneStorage for FailStorage<S>
263where
264    S: PruneStorage + Sync,
265{
266    type Pruner = S::Pruner;
267
268    async fn get_disk_usage(&self) -> anyhow::Result<u64> {
269        self.inner.get_disk_usage().await
270    }
271
272    async fn prune(&self, pruner: &mut Self::Pruner) -> anyhow::Result<Option<u64>> {
273        self.inner.prune(pruner).await
274    }
275}
276
277impl<S> HasMetrics for FailStorage<S>
278where
279    S: HasMetrics,
280{
281    fn metrics(&self) -> &PrometheusMetrics {
282        self.inner.metrics()
283    }
284}
285
286#[derive(Debug)]
287pub struct Transaction<T> {
288    inner: T,
289    failure: Arc<Mutex<Failure>>,
290}
291
292impl<T> Transaction<T> {
293    async fn maybe_fail_read(&self, action: FailableAction) -> QueryResult<()> {
294        self.failure.lock().await.on_read.maybe_fail(action)
295    }
296
297    async fn maybe_fail_write(&self, action: FailableAction) -> QueryResult<()> {
298        self.failure.lock().await.on_write.maybe_fail(action)
299    }
300
301    async fn maybe_fail_commit(&self, action: FailableAction) -> QueryResult<()> {
302        self.failure.lock().await.on_commit.maybe_fail(action)
303    }
304}
305
306impl<T> update::Transaction for Transaction<T>
307where
308    T: update::Transaction,
309{
310    async fn commit(self) -> anyhow::Result<()> {
311        self.maybe_fail_commit(FailableAction::Any).await?;
312        self.inner.commit().await
313    }
314
315    fn revert(self) -> impl Future + Send {
316        self.inner.revert()
317    }
318}
319
320#[async_trait]
321impl<Types, T> AvailabilityStorage<Types> for Transaction<T>
322where
323    Types: NodeType,
324    Header<Types>: QueryableHeader<Types>,
325    Payload<Types>: QueryablePayload<Types>,
326    T: AvailabilityStorage<Types>,
327{
328    async fn get_leaf(&mut self, id: LeafId<Types>) -> QueryResult<LeafQueryData<Types>> {
329        self.maybe_fail_read(FailableAction::GetLeaf).await?;
330        self.inner.get_leaf(id).await
331    }
332
333    async fn get_block(&mut self, id: BlockId<Types>) -> QueryResult<BlockQueryData<Types>> {
334        self.maybe_fail_read(FailableAction::GetBlock).await?;
335        self.inner.get_block(id).await
336    }
337
338    async fn get_header(&mut self, id: BlockId<Types>) -> QueryResult<Header<Types>> {
339        self.maybe_fail_read(FailableAction::GetHeader).await?;
340        self.inner.get_header(id).await
341    }
342
343    async fn get_payload(&mut self, id: BlockId<Types>) -> QueryResult<PayloadQueryData<Types>> {
344        self.maybe_fail_read(FailableAction::GetPayload).await?;
345        self.inner.get_payload(id).await
346    }
347
348    async fn get_payload_metadata(
349        &mut self,
350        id: BlockId<Types>,
351    ) -> QueryResult<PayloadMetadata<Types>> {
352        self.maybe_fail_read(FailableAction::GetPayloadMetadata)
353            .await?;
354        self.inner.get_payload_metadata(id).await
355    }
356
357    async fn get_vid_common(
358        &mut self,
359        id: BlockId<Types>,
360    ) -> QueryResult<VidCommonQueryData<Types>> {
361        self.maybe_fail_read(FailableAction::GetVidCommon).await?;
362        self.inner.get_vid_common(id).await
363    }
364
365    async fn get_vid_common_metadata(
366        &mut self,
367        id: BlockId<Types>,
368    ) -> QueryResult<VidCommonMetadata<Types>> {
369        self.maybe_fail_read(FailableAction::GetVidCommonMetadata)
370            .await?;
371        self.inner.get_vid_common_metadata(id).await
372    }
373
374    async fn get_leaf_range<R>(
375        &mut self,
376        range: R,
377    ) -> QueryResult<Vec<QueryResult<LeafQueryData<Types>>>>
378    where
379        R: RangeBounds<usize> + Send + 'static,
380    {
381        self.maybe_fail_read(FailableAction::GetLeafRange).await?;
382        self.inner.get_leaf_range(range).await
383    }
384
385    async fn get_block_range<R>(
386        &mut self,
387        range: R,
388    ) -> QueryResult<Vec<QueryResult<BlockQueryData<Types>>>>
389    where
390        R: RangeBounds<usize> + Send + 'static,
391    {
392        self.maybe_fail_read(FailableAction::GetBlockRange).await?;
393        self.inner.get_block_range(range).await
394    }
395
396    async fn get_payload_range<R>(
397        &mut self,
398        range: R,
399    ) -> QueryResult<Vec<QueryResult<PayloadQueryData<Types>>>>
400    where
401        R: RangeBounds<usize> + Send + 'static,
402    {
403        self.maybe_fail_read(FailableAction::GetPayloadRange)
404            .await?;
405        self.inner.get_payload_range(range).await
406    }
407
408    async fn get_payload_metadata_range<R>(
409        &mut self,
410        range: R,
411    ) -> QueryResult<Vec<QueryResult<PayloadMetadata<Types>>>>
412    where
413        R: RangeBounds<usize> + Send + 'static,
414    {
415        self.maybe_fail_read(FailableAction::GetPayloadMetadataRange)
416            .await?;
417        self.inner.get_payload_metadata_range(range).await
418    }
419
420    async fn get_vid_common_range<R>(
421        &mut self,
422        range: R,
423    ) -> QueryResult<Vec<QueryResult<VidCommonQueryData<Types>>>>
424    where
425        R: RangeBounds<usize> + Send + 'static,
426    {
427        self.maybe_fail_read(FailableAction::GetVidCommonRange)
428            .await?;
429        self.inner.get_vid_common_range(range).await
430    }
431
432    async fn get_vid_common_metadata_range<R>(
433        &mut self,
434        range: R,
435    ) -> QueryResult<Vec<QueryResult<VidCommonMetadata<Types>>>>
436    where
437        R: RangeBounds<usize> + Send + 'static,
438    {
439        self.maybe_fail_read(FailableAction::GetVidCommonMetadataRange)
440            .await?;
441        self.inner.get_vid_common_metadata_range(range).await
442    }
443
444    async fn get_block_with_transaction(
445        &mut self,
446        hash: TransactionHash<Types>,
447    ) -> QueryResult<BlockQueryData<Types>> {
448        self.maybe_fail_read(FailableAction::GetTransaction).await?;
449        self.inner.get_block_with_transaction(hash).await
450    }
451
452    async fn first_available_leaf(&mut self, from: u64) -> QueryResult<LeafQueryData<Types>> {
453        self.maybe_fail_read(FailableAction::FirstAvailableLeaf)
454            .await?;
455        self.inner.first_available_leaf(from).await
456    }
457}
458
459impl<Types, T> UpdateAvailabilityStorage<Types> for Transaction<T>
460where
461    Types: NodeType,
462    Header<Types>: QueryableHeader<Types>,
463    Payload<Types>: QueryablePayload<Types>,
464    T: UpdateAvailabilityStorage<Types> + Send + Sync,
465{
466    async fn insert_qc_chain(
467        &mut self,
468        height: u64,
469        qc_chain: Option<[CertificatePair<Types>; 2]>,
470    ) -> anyhow::Result<()> {
471        self.maybe_fail_write(FailableAction::Any).await?;
472        self.inner.insert_qc_chain(height, qc_chain).await
473    }
474
475    async fn insert_leaf_range<'a>(
476        &mut self,
477        leaves: impl Send + IntoIterator<IntoIter: Send, Item = &'a LeafQueryData<Types>>,
478    ) -> anyhow::Result<()> {
479        self.maybe_fail_write(FailableAction::Any).await?;
480        self.inner.insert_leaf_range(leaves).await
481    }
482
483    async fn insert_block_range<'a>(
484        &mut self,
485        blocks: impl Send + IntoIterator<IntoIter: Send, Item = &'a BlockQueryData<Types>>,
486    ) -> anyhow::Result<()> {
487        self.maybe_fail_write(FailableAction::Any).await?;
488        self.inner.insert_block_range(blocks).await
489    }
490
491    async fn insert_vid_range<'a>(
492        &mut self,
493        vid: impl Send
494        + IntoIterator<
495            IntoIter: Send,
496            Item = (&'a VidCommonQueryData<Types>, Option<&'a VidShare>),
497        >,
498    ) -> anyhow::Result<()> {
499        self.maybe_fail_write(FailableAction::Any).await?;
500        self.inner.insert_vid_range(vid).await
501    }
502}
503
504#[async_trait]
505impl<T> PrunedHeightStorage for Transaction<T>
506where
507    T: PrunedHeightStorage + Send + Sync,
508{
509    async fn load_pruned_height(&mut self) -> anyhow::Result<Option<u64>> {
510        self.maybe_fail_read(FailableAction::Any).await?;
511        self.inner.load_pruned_height().await
512    }
513}
514
515#[async_trait]
516impl<Types, T> NodeStorage<Types> for Transaction<T>
517where
518    Types: NodeType,
519    Header<Types>: QueryableHeader<Types>,
520    T: NodeStorage<Types> + Send + Sync,
521{
522    async fn block_height(&mut self) -> QueryResult<usize> {
523        self.maybe_fail_read(FailableAction::Any).await?;
524        self.inner.block_height().await
525    }
526
527    async fn count_transactions_in_range(
528        &mut self,
529        range: impl RangeBounds<usize> + Send,
530        namespace: Option<NamespaceId<Types>>,
531    ) -> QueryResult<usize> {
532        self.maybe_fail_read(FailableAction::Any).await?;
533        self.inner
534            .count_transactions_in_range(range, namespace)
535            .await
536    }
537
538    async fn payload_size_in_range(
539        &mut self,
540        range: impl RangeBounds<usize> + Send,
541        namespace: Option<NamespaceId<Types>>,
542    ) -> QueryResult<usize> {
543        self.maybe_fail_read(FailableAction::Any).await?;
544        self.inner.payload_size_in_range(range, namespace).await
545    }
546
547    async fn vid_share<ID>(&mut self, id: ID) -> QueryResult<VidShare>
548    where
549        ID: Into<BlockId<Types>> + Send + Sync,
550    {
551        self.maybe_fail_read(FailableAction::Any).await?;
552        self.inner.vid_share(id).await
553    }
554
555    async fn sync_status_for_range(
556        &mut self,
557        start: usize,
558        end: usize,
559    ) -> QueryResult<SyncStatusQueryData> {
560        self.maybe_fail_read(FailableAction::Any).await?;
561        self.inner.sync_status_for_range(start, end).await
562    }
563
564    async fn get_header_window(
565        &mut self,
566        start: impl Into<WindowStart<Types>> + Send + Sync,
567        end: u64,
568        limit: usize,
569    ) -> QueryResult<TimeWindowQueryData<Header<Types>>> {
570        self.maybe_fail_read(FailableAction::Any).await?;
571        self.inner.get_header_window(start, end, limit).await
572    }
573
574    async fn latest_qc_chain(&mut self) -> QueryResult<Option<[CertificatePair<Types>; 2]>> {
575        self.maybe_fail_read(FailableAction::Any).await?;
576        self.inner.latest_qc_chain().await
577    }
578}
579
580impl<Types, T> AggregatesStorage<Types> for Transaction<T>
581where
582    Types: NodeType,
583    Header<Types>: QueryableHeader<Types>,
584    T: AggregatesStorage<Types> + Send + Sync,
585{
586    async fn aggregates_height(&mut self) -> anyhow::Result<usize> {
587        self.maybe_fail_read(FailableAction::Any).await?;
588        self.inner.aggregates_height().await
589    }
590
591    async fn load_prev_aggregate(&mut self) -> anyhow::Result<Option<Aggregate<Types>>> {
592        self.maybe_fail_read(FailableAction::Any).await?;
593        self.inner.load_prev_aggregate().await
594    }
595}
596
597impl<T, Types> UpdateAggregatesStorage<Types> for Transaction<T>
598where
599    Types: NodeType,
600    Header<Types>: QueryableHeader<Types>,
601    T: UpdateAggregatesStorage<Types> + Send + Sync,
602{
603    async fn update_aggregates(
604        &mut self,
605        prev: Aggregate<Types>,
606        blocks: &[PayloadMetadata<Types>],
607    ) -> anyhow::Result<Aggregate<Types>> {
608        self.maybe_fail_write(FailableAction::Any).await?;
609        self.inner.update_aggregates(prev, blocks).await
610    }
611}