1#![cfg(any(test, feature = "testing"))]
14
15use std::{ops::RangeBounds, sync::Arc};
16
17use async_lock::Mutex;
18use async_trait::async_trait;
19use futures::future::Future;
20use hotshot_types::{
21 data::VidShare, simple_certificate::CertificatePair, traits::node_implementation::NodeType,
22};
23
24use super::{
25 Aggregate, AggregatesStorage, AvailabilityStorage, NodeStorage, UpdateAggregatesStorage,
26 UpdateAvailabilityStorage,
27 pruning::{PruneStorage, PrunedHeightStorage, PrunerCfg, PrunerConfig},
28};
29use crate::{
30 Header, Payload, QueryError, QueryResult,
31 availability::{
32 BlockId, BlockQueryData, LeafId, LeafQueryData, NamespaceId, PayloadQueryData,
33 QueryableHeader, QueryablePayload, TransactionHash, VidCommonQueryData,
34 },
35 data_source::{
36 VersionedDataSource,
37 storage::{PayloadMetadata, VidCommonMetadata},
38 update,
39 },
40 metrics::PrometheusMetrics,
41 node::{SyncStatusQueryData, TimeWindowQueryData, WindowStart},
42 status::HasMetrics,
43};
44
/// A storage operation that a failure can be targeted at.
///
/// Read methods on [`Transaction`] check the read failure mode using their own specific variant
/// (for example [`FailableAction::GetLeaf`]). Transaction begins, writes, commits and bookkeeping
/// reads check their failure modes using [`FailableAction::Any`].
///
/// NOTE(review): `GetHeaderRange` and `GetStateCert` are not referenced by any wrapper in this
/// file; confirm they are still used elsewhere.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum FailableAction {
    // ---- Single-object reads ----
    GetHeader,
    GetLeaf,
    GetBlock,
    GetPayload,
    GetPayloadMetadata,
    GetVidCommon,
    GetVidCommonMetadata,

    // ---- Range reads ----
    GetHeaderRange,
    GetLeafRange,
    GetBlockRange,
    GetPayloadRange,
    GetPayloadMetadataRange,
    GetVidCommonRange,
    GetVidCommonMetadataRange,

    // ---- Other reads ----
    GetTransaction,
    FirstAvailableLeaf,
    GetStateCert,

    /// Wildcard: a failure configured for `Any` matches every checked action.
    Any,
}

impl FailableAction {
    /// Returns `true` when a failure configured for `self` should fire on a check for `action`.
    ///
    /// A configured `Any` matches everything. Any other configured value matches only an
    /// identical `action`. A specific configured action therefore does NOT match a check that is
    /// performed with `Any`.
    fn matches(self, action: Self) -> bool {
        matches!(self, Self::Any) || action == self
    }
}
80
81#[derive(Clone, Copy, Debug, Default)]
82enum FailureMode {
83 #[default]
84 Never,
85 Once(FailableAction),
86 Always(FailableAction),
87}
88
89impl FailureMode {
90 fn maybe_fail(&mut self, action: FailableAction) -> QueryResult<()> {
91 match self {
92 Self::Once(fail_action) if fail_action.matches(action) => {
93 *self = Self::Never;
94 },
95 Self::Always(fail_action) if fail_action.matches(action) => {},
96 _ => return Ok(()),
97 }
98
99 Err(QueryError::Error {
100 message: "injected error".into(),
101 })
102 }
103}
104
/// Per-phase failure configuration.
///
/// A [`FailStorage`] holds this behind an `Arc<Mutex<_>>` and shares it with every transaction
/// it opens. The default value has every phase set to `FailureMode::Never`.
#[derive(Debug, Default)]
struct Failure {
    /// Checked by read operations on a transaction.
    on_read: FailureMode,
    /// Checked by insert/update operations on a transaction.
    on_write: FailureMode,
    /// Checked when a transaction is committed.
    on_commit: FailureMode,
    /// Checked when a writable transaction is opened.
    on_begin_writable: FailureMode,
    /// Checked when a read-only transaction is opened.
    on_begin_read_only: FailureMode,
}
113
/// Storage wrapper that injects errors into operations on an inner storage `S`, for tests.
///
/// Clones of a `FailStorage` share one failure configuration (the `Arc` is cloned, not the
/// contents). Every transaction opened from it shares that configuration too. Transactions
/// re-read the configuration on each checked operation, so changes made through the storage
/// affect transactions that are already open.
#[derive(Clone, Debug)]
pub struct FailStorage<S> {
    // The wrapped storage that real work is delegated to.
    inner: S,
    // Shared, mutable failure configuration.
    failure: Arc<Mutex<Failure>>,
}
120
121impl<S> From<S> for FailStorage<S> {
122 fn from(inner: S) -> Self {
123 Self {
124 inner,
125 failure: Default::default(),
126 }
127 }
128}
129
130impl<S> FailStorage<S> {
131 pub async fn fail_reads(&self, action: FailableAction) {
132 self.failure.lock().await.on_read = FailureMode::Always(action);
133 }
134
135 pub async fn fail_writes(&self, action: FailableAction) {
136 self.failure.lock().await.on_write = FailureMode::Always(action);
137 }
138
139 pub async fn fail_commits(&self, action: FailableAction) {
140 self.failure.lock().await.on_commit = FailureMode::Always(action);
141 }
142
143 pub async fn fail_begins_writable(&self, action: FailableAction) {
144 self.failure.lock().await.on_begin_writable = FailureMode::Always(action);
145 }
146
147 pub async fn fail_begins_read_only(&self, action: FailableAction) {
148 self.failure.lock().await.on_begin_read_only = FailureMode::Always(action);
149 }
150
151 pub async fn fail(&self, action: FailableAction) {
152 let mut failure = self.failure.lock().await;
153 failure.on_read = FailureMode::Always(action);
154 failure.on_write = FailureMode::Always(action);
155 failure.on_commit = FailureMode::Always(action);
156 failure.on_begin_writable = FailureMode::Always(action);
157 failure.on_begin_read_only = FailureMode::Always(action);
158 }
159
160 pub async fn pass_reads(&self) {
161 self.failure.lock().await.on_read = FailureMode::Never;
162 }
163
164 pub async fn pass_writes(&self) {
165 self.failure.lock().await.on_write = FailureMode::Never;
166 }
167
168 pub async fn pass_commits(&self) {
169 self.failure.lock().await.on_commit = FailureMode::Never;
170 }
171
172 pub async fn pass_begins_writable(&self) {
173 self.failure.lock().await.on_begin_writable = FailureMode::Never;
174 }
175
176 pub async fn pass_begins_read_only(&self) {
177 self.failure.lock().await.on_begin_read_only = FailureMode::Never;
178 }
179
180 pub async fn pass(&self) {
181 let mut failure = self.failure.lock().await;
182 failure.on_read = FailureMode::Never;
183 failure.on_write = FailureMode::Never;
184 failure.on_commit = FailureMode::Never;
185 failure.on_begin_writable = FailureMode::Never;
186 failure.on_begin_read_only = FailureMode::Never;
187 }
188
189 pub async fn fail_one_read(&self, action: FailableAction) {
190 self.failure.lock().await.on_read = FailureMode::Once(action);
191 }
192
193 pub async fn fail_one_write(&self, action: FailableAction) {
194 self.failure.lock().await.on_write = FailureMode::Once(action);
195 }
196
197 pub async fn fail_one_commit(&self, action: FailableAction) {
198 self.failure.lock().await.on_commit = FailureMode::Once(action);
199 }
200
201 pub async fn fail_one_begin_writable(&self, action: FailableAction) {
202 self.failure.lock().await.on_begin_writable = FailureMode::Once(action);
203 }
204
205 pub async fn fail_one_begin_read_only(&self, action: FailableAction) {
206 self.failure.lock().await.on_begin_read_only = FailureMode::Once(action);
207 }
208}
209
210impl<S> VersionedDataSource for FailStorage<S>
211where
212 S: VersionedDataSource,
213{
214 type Transaction<'a>
215 = Transaction<S::Transaction<'a>>
216 where
217 Self: 'a;
218 type ReadOnly<'a>
219 = Transaction<S::ReadOnly<'a>>
220 where
221 Self: 'a;
222
223 async fn write(&self) -> anyhow::Result<<Self as VersionedDataSource>::Transaction<'_>> {
224 self.failure
225 .lock()
226 .await
227 .on_begin_writable
228 .maybe_fail(FailableAction::Any)?;
229 Ok(Transaction {
230 inner: self.inner.write().await?,
231 failure: self.failure.clone(),
232 })
233 }
234
235 async fn read(&self) -> anyhow::Result<<Self as VersionedDataSource>::ReadOnly<'_>> {
236 self.failure
237 .lock()
238 .await
239 .on_begin_read_only
240 .maybe_fail(FailableAction::Any)?;
241 Ok(Transaction {
242 inner: self.inner.read().await?,
243 failure: self.failure.clone(),
244 })
245 }
246}
247
248impl<S> PrunerConfig for FailStorage<S>
249where
250 S: PrunerConfig,
251{
252 fn set_pruning_config(&mut self, cfg: PrunerCfg) {
253 self.inner.set_pruning_config(cfg);
254 }
255
256 fn get_pruning_config(&self) -> Option<PrunerCfg> {
257 self.inner.get_pruning_config()
258 }
259}
260
261#[async_trait]
262impl<S> PruneStorage for FailStorage<S>
263where
264 S: PruneStorage + Sync,
265{
266 type Pruner = S::Pruner;
267
268 async fn get_disk_usage(&self) -> anyhow::Result<u64> {
269 self.inner.get_disk_usage().await
270 }
271
272 async fn prune(&self, pruner: &mut Self::Pruner) -> anyhow::Result<Option<u64>> {
273 self.inner.prune(pruner).await
274 }
275}
276
277impl<S> HasMetrics for FailStorage<S>
278where
279 S: HasMetrics,
280{
281 fn metrics(&self) -> &PrometheusMetrics {
282 self.inner.metrics()
283 }
284}
285
/// A transaction on the inner storage that can have errors injected into it.
///
/// Before delegating reads, writes and commits to `inner`, it consults the failure configuration
/// shared with the [`FailStorage`] that opened it. Reverts are delegated without any check.
#[derive(Debug)]
pub struct Transaction<T> {
    // The wrapped transaction that real work is delegated to.
    inner: T,
    // Failure configuration shared with the originating `FailStorage`.
    failure: Arc<Mutex<Failure>>,
}
291
292impl<T> Transaction<T> {
293 async fn maybe_fail_read(&self, action: FailableAction) -> QueryResult<()> {
294 self.failure.lock().await.on_read.maybe_fail(action)
295 }
296
297 async fn maybe_fail_write(&self, action: FailableAction) -> QueryResult<()> {
298 self.failure.lock().await.on_write.maybe_fail(action)
299 }
300
301 async fn maybe_fail_commit(&self, action: FailableAction) -> QueryResult<()> {
302 self.failure.lock().await.on_commit.maybe_fail(action)
303 }
304}
305
306impl<T> update::Transaction for Transaction<T>
307where
308 T: update::Transaction,
309{
310 async fn commit(self) -> anyhow::Result<()> {
311 self.maybe_fail_commit(FailableAction::Any).await?;
312 self.inner.commit().await
313 }
314
315 fn revert(self) -> impl Future + Send {
316 self.inner.revert()
317 }
318}
319
320#[async_trait]
321impl<Types, T> AvailabilityStorage<Types> for Transaction<T>
322where
323 Types: NodeType,
324 Header<Types>: QueryableHeader<Types>,
325 Payload<Types>: QueryablePayload<Types>,
326 T: AvailabilityStorage<Types>,
327{
328 async fn get_leaf(&mut self, id: LeafId<Types>) -> QueryResult<LeafQueryData<Types>> {
329 self.maybe_fail_read(FailableAction::GetLeaf).await?;
330 self.inner.get_leaf(id).await
331 }
332
333 async fn get_block(&mut self, id: BlockId<Types>) -> QueryResult<BlockQueryData<Types>> {
334 self.maybe_fail_read(FailableAction::GetBlock).await?;
335 self.inner.get_block(id).await
336 }
337
338 async fn get_header(&mut self, id: BlockId<Types>) -> QueryResult<Header<Types>> {
339 self.maybe_fail_read(FailableAction::GetHeader).await?;
340 self.inner.get_header(id).await
341 }
342
343 async fn get_payload(&mut self, id: BlockId<Types>) -> QueryResult<PayloadQueryData<Types>> {
344 self.maybe_fail_read(FailableAction::GetPayload).await?;
345 self.inner.get_payload(id).await
346 }
347
348 async fn get_payload_metadata(
349 &mut self,
350 id: BlockId<Types>,
351 ) -> QueryResult<PayloadMetadata<Types>> {
352 self.maybe_fail_read(FailableAction::GetPayloadMetadata)
353 .await?;
354 self.inner.get_payload_metadata(id).await
355 }
356
357 async fn get_vid_common(
358 &mut self,
359 id: BlockId<Types>,
360 ) -> QueryResult<VidCommonQueryData<Types>> {
361 self.maybe_fail_read(FailableAction::GetVidCommon).await?;
362 self.inner.get_vid_common(id).await
363 }
364
365 async fn get_vid_common_metadata(
366 &mut self,
367 id: BlockId<Types>,
368 ) -> QueryResult<VidCommonMetadata<Types>> {
369 self.maybe_fail_read(FailableAction::GetVidCommonMetadata)
370 .await?;
371 self.inner.get_vid_common_metadata(id).await
372 }
373
374 async fn get_leaf_range<R>(
375 &mut self,
376 range: R,
377 ) -> QueryResult<Vec<QueryResult<LeafQueryData<Types>>>>
378 where
379 R: RangeBounds<usize> + Send + 'static,
380 {
381 self.maybe_fail_read(FailableAction::GetLeafRange).await?;
382 self.inner.get_leaf_range(range).await
383 }
384
385 async fn get_block_range<R>(
386 &mut self,
387 range: R,
388 ) -> QueryResult<Vec<QueryResult<BlockQueryData<Types>>>>
389 where
390 R: RangeBounds<usize> + Send + 'static,
391 {
392 self.maybe_fail_read(FailableAction::GetBlockRange).await?;
393 self.inner.get_block_range(range).await
394 }
395
396 async fn get_payload_range<R>(
397 &mut self,
398 range: R,
399 ) -> QueryResult<Vec<QueryResult<PayloadQueryData<Types>>>>
400 where
401 R: RangeBounds<usize> + Send + 'static,
402 {
403 self.maybe_fail_read(FailableAction::GetPayloadRange)
404 .await?;
405 self.inner.get_payload_range(range).await
406 }
407
408 async fn get_payload_metadata_range<R>(
409 &mut self,
410 range: R,
411 ) -> QueryResult<Vec<QueryResult<PayloadMetadata<Types>>>>
412 where
413 R: RangeBounds<usize> + Send + 'static,
414 {
415 self.maybe_fail_read(FailableAction::GetPayloadMetadataRange)
416 .await?;
417 self.inner.get_payload_metadata_range(range).await
418 }
419
420 async fn get_vid_common_range<R>(
421 &mut self,
422 range: R,
423 ) -> QueryResult<Vec<QueryResult<VidCommonQueryData<Types>>>>
424 where
425 R: RangeBounds<usize> + Send + 'static,
426 {
427 self.maybe_fail_read(FailableAction::GetVidCommonRange)
428 .await?;
429 self.inner.get_vid_common_range(range).await
430 }
431
432 async fn get_vid_common_metadata_range<R>(
433 &mut self,
434 range: R,
435 ) -> QueryResult<Vec<QueryResult<VidCommonMetadata<Types>>>>
436 where
437 R: RangeBounds<usize> + Send + 'static,
438 {
439 self.maybe_fail_read(FailableAction::GetVidCommonMetadataRange)
440 .await?;
441 self.inner.get_vid_common_metadata_range(range).await
442 }
443
444 async fn get_block_with_transaction(
445 &mut self,
446 hash: TransactionHash<Types>,
447 ) -> QueryResult<BlockQueryData<Types>> {
448 self.maybe_fail_read(FailableAction::GetTransaction).await?;
449 self.inner.get_block_with_transaction(hash).await
450 }
451
452 async fn first_available_leaf(&mut self, from: u64) -> QueryResult<LeafQueryData<Types>> {
453 self.maybe_fail_read(FailableAction::FirstAvailableLeaf)
454 .await?;
455 self.inner.first_available_leaf(from).await
456 }
457}
458
459impl<Types, T> UpdateAvailabilityStorage<Types> for Transaction<T>
460where
461 Types: NodeType,
462 Header<Types>: QueryableHeader<Types>,
463 Payload<Types>: QueryablePayload<Types>,
464 T: UpdateAvailabilityStorage<Types> + Send + Sync,
465{
466 async fn insert_qc_chain(
467 &mut self,
468 height: u64,
469 qc_chain: Option<[CertificatePair<Types>; 2]>,
470 ) -> anyhow::Result<()> {
471 self.maybe_fail_write(FailableAction::Any).await?;
472 self.inner.insert_qc_chain(height, qc_chain).await
473 }
474
475 async fn insert_leaf_range<'a>(
476 &mut self,
477 leaves: impl Send + IntoIterator<IntoIter: Send, Item = &'a LeafQueryData<Types>>,
478 ) -> anyhow::Result<()> {
479 self.maybe_fail_write(FailableAction::Any).await?;
480 self.inner.insert_leaf_range(leaves).await
481 }
482
483 async fn insert_block_range<'a>(
484 &mut self,
485 blocks: impl Send + IntoIterator<IntoIter: Send, Item = &'a BlockQueryData<Types>>,
486 ) -> anyhow::Result<()> {
487 self.maybe_fail_write(FailableAction::Any).await?;
488 self.inner.insert_block_range(blocks).await
489 }
490
491 async fn insert_vid_range<'a>(
492 &mut self,
493 vid: impl Send
494 + IntoIterator<
495 IntoIter: Send,
496 Item = (&'a VidCommonQueryData<Types>, Option<&'a VidShare>),
497 >,
498 ) -> anyhow::Result<()> {
499 self.maybe_fail_write(FailableAction::Any).await?;
500 self.inner.insert_vid_range(vid).await
501 }
502}
503
504#[async_trait]
505impl<T> PrunedHeightStorage for Transaction<T>
506where
507 T: PrunedHeightStorage + Send + Sync,
508{
509 async fn load_pruned_height(&mut self) -> anyhow::Result<Option<u64>> {
510 self.maybe_fail_read(FailableAction::Any).await?;
511 self.inner.load_pruned_height().await
512 }
513}
514
515#[async_trait]
516impl<Types, T> NodeStorage<Types> for Transaction<T>
517where
518 Types: NodeType,
519 Header<Types>: QueryableHeader<Types>,
520 T: NodeStorage<Types> + Send + Sync,
521{
522 async fn block_height(&mut self) -> QueryResult<usize> {
523 self.maybe_fail_read(FailableAction::Any).await?;
524 self.inner.block_height().await
525 }
526
527 async fn count_transactions_in_range(
528 &mut self,
529 range: impl RangeBounds<usize> + Send,
530 namespace: Option<NamespaceId<Types>>,
531 ) -> QueryResult<usize> {
532 self.maybe_fail_read(FailableAction::Any).await?;
533 self.inner
534 .count_transactions_in_range(range, namespace)
535 .await
536 }
537
538 async fn payload_size_in_range(
539 &mut self,
540 range: impl RangeBounds<usize> + Send,
541 namespace: Option<NamespaceId<Types>>,
542 ) -> QueryResult<usize> {
543 self.maybe_fail_read(FailableAction::Any).await?;
544 self.inner.payload_size_in_range(range, namespace).await
545 }
546
547 async fn vid_share<ID>(&mut self, id: ID) -> QueryResult<VidShare>
548 where
549 ID: Into<BlockId<Types>> + Send + Sync,
550 {
551 self.maybe_fail_read(FailableAction::Any).await?;
552 self.inner.vid_share(id).await
553 }
554
555 async fn sync_status_for_range(
556 &mut self,
557 start: usize,
558 end: usize,
559 ) -> QueryResult<SyncStatusQueryData> {
560 self.maybe_fail_read(FailableAction::Any).await?;
561 self.inner.sync_status_for_range(start, end).await
562 }
563
564 async fn get_header_window(
565 &mut self,
566 start: impl Into<WindowStart<Types>> + Send + Sync,
567 end: u64,
568 limit: usize,
569 ) -> QueryResult<TimeWindowQueryData<Header<Types>>> {
570 self.maybe_fail_read(FailableAction::Any).await?;
571 self.inner.get_header_window(start, end, limit).await
572 }
573
574 async fn latest_qc_chain(&mut self) -> QueryResult<Option<[CertificatePair<Types>; 2]>> {
575 self.maybe_fail_read(FailableAction::Any).await?;
576 self.inner.latest_qc_chain().await
577 }
578}
579
580impl<Types, T> AggregatesStorage<Types> for Transaction<T>
581where
582 Types: NodeType,
583 Header<Types>: QueryableHeader<Types>,
584 T: AggregatesStorage<Types> + Send + Sync,
585{
586 async fn aggregates_height(&mut self) -> anyhow::Result<usize> {
587 self.maybe_fail_read(FailableAction::Any).await?;
588 self.inner.aggregates_height().await
589 }
590
591 async fn load_prev_aggregate(&mut self) -> anyhow::Result<Option<Aggregate<Types>>> {
592 self.maybe_fail_read(FailableAction::Any).await?;
593 self.inner.load_prev_aggregate().await
594 }
595}
596
597impl<T, Types> UpdateAggregatesStorage<Types> for Transaction<T>
598where
599 Types: NodeType,
600 Header<Types>: QueryableHeader<Types>,
601 T: UpdateAggregatesStorage<Types> + Send + Sync,
602{
603 async fn update_aggregates(
604 &mut self,
605 prev: Aggregate<Types>,
606 blocks: &[PayloadMetadata<Types>],
607 ) -> anyhow::Result<Aggregate<Types>> {
608 self.maybe_fail_write(FailableAction::Any).await?;
609 self.inner.update_aggregates(prev, blocks).await
610 }
611}