1use std::ops::RangeBounds;
16
17use async_trait::async_trait;
18use futures::stream::{StreamExt, TryStreamExt};
19use hotshot_types::traits::node_implementation::NodeType;
20use snafu::OptionExt;
21use sqlx::FromRow;
22
23use super::{
24 super::transaction::{Transaction, TransactionMode, query},
25 BLOCK_COLUMNS, LEAF_COLUMNS, PAYLOAD_COLUMNS, PAYLOAD_METADATA_COLUMNS, QueryBuilder,
26 VID_COMMON_COLUMNS, VID_COMMON_METADATA_COLUMNS,
27};
28use crate::{
29 Header, MissingSnafu, Payload, QueryError, QueryResult,
30 availability::{
31 BlockId, BlockQueryData, LeafId, LeafQueryData, NamespaceInfo, NamespaceMap,
32 PayloadQueryData, QueryableHeader, QueryablePayload, TransactionHash, VidCommonQueryData,
33 },
34 data_source::storage::{
35 AvailabilityStorage, PayloadMetadata, VidCommonMetadata, sql::sqlx::Row,
36 },
37 types::HeightIndexed,
38};
39
40#[async_trait]
41impl<Mode, Types> AvailabilityStorage<Types> for Transaction<Mode>
42where
43 Types: NodeType,
44 Mode: TransactionMode,
45 Payload<Types>: QueryablePayload<Types>,
46 Header<Types>: QueryableHeader<Types>,
47{
48 async fn get_leaf(&mut self, id: LeafId<Types>) -> QueryResult<LeafQueryData<Types>> {
49 let mut query = QueryBuilder::default();
50 let where_clause = match id {
51 LeafId::Number(n) => format!("height = {}", query.bind(n as i64)?),
52 LeafId::Hash(h) => format!("hash = {}", query.bind(h.to_string())?),
53 };
54 let row = query
55 .query(&format!(
56 "SELECT {LEAF_COLUMNS} FROM leaf2 WHERE {where_clause} LIMIT 1"
57 ))
58 .fetch_one(self.as_mut())
59 .await?;
60 let leaf = LeafQueryData::from_row(&row)?;
61 Ok(leaf)
62 }
63
64 async fn get_block(&mut self, id: BlockId<Types>) -> QueryResult<BlockQueryData<Types>> {
65 let mut query = QueryBuilder::default();
66 let where_clause = query.header_where_clause(id)?;
67 let sql = format!(
68 "SELECT {BLOCK_COLUMNS}
69 FROM header AS h
70 JOIN payload AS p ON (h.payload_hash, h.ns_table) = (p.hash, p.ns_table)
71 WHERE {where_clause}
72 LIMIT 1"
73 );
74 let row = query.query(&sql).fetch_one(self.as_mut()).await?;
75 let block = BlockQueryData::from_row(&row)?;
76 Ok(block)
77 }
78
79 async fn get_header(&mut self, id: BlockId<Types>) -> QueryResult<Header<Types>> {
80 self.load_header(id).await
81 }
82
83 async fn get_payload(&mut self, id: BlockId<Types>) -> QueryResult<PayloadQueryData<Types>> {
84 let mut query = QueryBuilder::default();
85 let where_clause = query.header_where_clause(id)?;
86 let sql = format!(
87 "SELECT {PAYLOAD_COLUMNS}
88 FROM header AS h
89 JOIN payload AS p ON (h.payload_hash, h.ns_table) = (p.hash, p.ns_table)
90 WHERE {where_clause}
91 LIMIT 1"
92 );
93 let row = query.query(&sql).fetch_one(self.as_mut()).await?;
94 let payload = PayloadQueryData::from_row(&row)?;
95 Ok(payload)
96 }
97
98 async fn get_payload_metadata(
99 &mut self,
100 id: BlockId<Types>,
101 ) -> QueryResult<PayloadMetadata<Types>> {
102 let mut query = QueryBuilder::default();
103 let where_clause = query.header_where_clause(id)?;
104 let sql = format!(
105 "SELECT {PAYLOAD_METADATA_COLUMNS}
106 FROM header AS h
107 JOIN payload AS p ON (h.payload_hash, h.ns_table) = (p.hash, p.ns_table)
108 WHERE {where_clause}
109 LIMIT 1"
110 );
111 let row = query
112 .query(&sql)
113 .fetch_optional(self.as_mut())
114 .await?
115 .context(MissingSnafu)?;
116 let mut payload = PayloadMetadata::from_row(&row)?;
117 payload.namespaces = self
118 .load_namespaces::<Types>(payload.height(), payload.size)
119 .await?;
120 Ok(payload)
121 }
122
123 async fn get_vid_common(
124 &mut self,
125 id: BlockId<Types>,
126 ) -> QueryResult<VidCommonQueryData<Types>> {
127 let mut query = QueryBuilder::default();
128 let where_clause = query.header_where_clause(id)?;
129 let sql = format!(
130 "SELECT {VID_COMMON_COLUMNS}
131 FROM header AS h
132 JOIN vid_common AS v ON h.payload_hash = v.hash
133 WHERE {where_clause}
134 LIMIT 1"
135 );
136 let row = query.query(&sql).fetch_one(self.as_mut()).await?;
137 let common = VidCommonQueryData::from_row(&row)?;
138 Ok(common)
139 }
140
141 async fn get_vid_common_metadata(
142 &mut self,
143 id: BlockId<Types>,
144 ) -> QueryResult<VidCommonMetadata<Types>> {
145 let mut query = QueryBuilder::default();
146 let where_clause = query.header_where_clause(id)?;
147 let sql = format!(
148 "SELECT {VID_COMMON_METADATA_COLUMNS}
149 FROM header AS h
150 JOIN vid_common AS v ON h.payload_hash = v.hash
151 WHERE {where_clause}
152 LIMIT 1"
153 );
154 let row = query.query(&sql).fetch_one(self.as_mut()).await?;
155 let common = VidCommonMetadata::from_row(&row)?;
156 Ok(common)
157 }
158
159 async fn get_leaf_range<R>(
160 &mut self,
161 range: R,
162 ) -> QueryResult<Vec<QueryResult<LeafQueryData<Types>>>>
163 where
164 R: RangeBounds<usize> + Send,
165 {
166 let mut query = QueryBuilder::default();
167 let where_clause = query.bounds_to_where_clause(range, "height")?;
168 let sql = format!("SELECT {LEAF_COLUMNS} FROM leaf2 {where_clause} ORDER BY height ASC");
169 Ok(query
170 .query(&sql)
171 .fetch(self.as_mut())
172 .map(|res| LeafQueryData::from_row(&res?))
173 .map_err(QueryError::from)
174 .collect()
175 .await)
176 }
177
178 async fn get_block_range<R>(
179 &mut self,
180 range: R,
181 ) -> QueryResult<Vec<QueryResult<BlockQueryData<Types>>>>
182 where
183 R: RangeBounds<usize> + Send,
184 {
185 let mut query = QueryBuilder::default();
186 let where_clause = query.bounds_to_where_clause(range, "h.height")?;
187 let sql = format!(
188 "SELECT {BLOCK_COLUMNS}
189 FROM header AS h
190 JOIN payload AS p ON (h.payload_hash, h.ns_table) = (p.hash, p.ns_table)
191 {where_clause}
192 ORDER BY h.height"
193 );
194 Ok(query
195 .query(&sql)
196 .fetch(self.as_mut())
197 .map(|res| BlockQueryData::from_row(&res?))
198 .map_err(QueryError::from)
199 .collect()
200 .await)
201 }
202
203 async fn get_header_range<R>(
204 &mut self,
205 range: R,
206 ) -> QueryResult<Vec<QueryResult<Header<Types>>>>
207 where
208 R: RangeBounds<usize> + Send,
209 {
210 let mut query = QueryBuilder::default();
211 let where_clause = query.bounds_to_where_clause(range, "h.height")?;
212
213 let headers = query
214 .query(&format!(
215 "SELECT data
216 FROM header AS h
217 {where_clause}
218 ORDER BY h.height"
219 ))
220 .fetch(self.as_mut())
221 .map(|res| serde_json::from_value(res?.get("data")).unwrap())
222 .collect()
223 .await;
224
225 Ok(headers)
226 }
227
228 async fn get_payload_range<R>(
229 &mut self,
230 range: R,
231 ) -> QueryResult<Vec<QueryResult<PayloadQueryData<Types>>>>
232 where
233 R: RangeBounds<usize> + Send,
234 {
235 let mut query = QueryBuilder::default();
236 let where_clause = query.bounds_to_where_clause(range, "h.height")?;
237 let sql = format!(
238 "SELECT {PAYLOAD_COLUMNS}
239 FROM header AS h
240 JOIN payload AS p ON (h.payload_hash, h.ns_table) = (p.hash, p.ns_table)
241 {where_clause}
242 ORDER BY h.height"
243 );
244 Ok(query
245 .query(&sql)
246 .fetch(self.as_mut())
247 .map(|res| PayloadQueryData::from_row(&res?))
248 .map_err(QueryError::from)
249 .collect()
250 .await)
251 }
252
253 async fn get_payload_metadata_range<R>(
254 &mut self,
255 range: R,
256 ) -> QueryResult<Vec<QueryResult<PayloadMetadata<Types>>>>
257 where
258 R: RangeBounds<usize> + Send + 'static,
259 {
260 let mut query = QueryBuilder::default();
261 let where_clause = query.bounds_to_where_clause(range, "h.height")?;
262 let sql = format!(
263 "SELECT {PAYLOAD_METADATA_COLUMNS}
264 FROM header AS h
265 JOIN payload AS p ON (h.payload_hash, h.ns_table) = (p.hash, p.ns_table)
266 {where_clause}
267 ORDER BY h.height ASC"
268 );
269 let rows = query
270 .query(&sql)
271 .fetch(self.as_mut())
272 .collect::<Vec<_>>()
273 .await;
274 let mut payloads = vec![];
275 for row in rows {
276 let res = async {
277 let mut meta = PayloadMetadata::from_row(&row?)?;
278 meta.namespaces = self
279 .load_namespaces::<Types>(meta.height(), meta.size)
280 .await?;
281 Ok(meta)
282 }
283 .await;
284 payloads.push(res);
285 }
286 Ok(payloads)
287 }
288
289 async fn get_vid_common_range<R>(
290 &mut self,
291 range: R,
292 ) -> QueryResult<Vec<QueryResult<VidCommonQueryData<Types>>>>
293 where
294 R: RangeBounds<usize> + Send,
295 {
296 let mut query = QueryBuilder::default();
297 let where_clause = query.bounds_to_where_clause(range, "h.height")?;
298 let sql = format!(
299 "SELECT {VID_COMMON_COLUMNS}
300 FROM header AS h
301 JOIN vid_common AS v ON h.payload_hash = v.hash
302 {where_clause}
303 ORDER BY h.height"
304 );
305 Ok(query
306 .query(&sql)
307 .fetch(self.as_mut())
308 .map(|res| VidCommonQueryData::from_row(&res?))
309 .map_err(QueryError::from)
310 .collect()
311 .await)
312 }
313
314 async fn get_vid_common_metadata_range<R>(
315 &mut self,
316 range: R,
317 ) -> QueryResult<Vec<QueryResult<VidCommonMetadata<Types>>>>
318 where
319 R: RangeBounds<usize> + Send,
320 {
321 let mut query = QueryBuilder::default();
322 let where_clause = query.bounds_to_where_clause(range, "h.height")?;
323 let sql = format!(
324 "SELECT {VID_COMMON_METADATA_COLUMNS}
325 FROM header AS h
326 JOIN vid_common AS v ON h.payload_hash = v.hash
327 {where_clause}
328 ORDER BY h.height ASC"
329 );
330 Ok(query
331 .query(&sql)
332 .fetch(self.as_mut())
333 .map(|res| VidCommonMetadata::from_row(&res?))
334 .map_err(QueryError::from)
335 .collect()
336 .await)
337 }
338
339 async fn get_block_with_transaction(
340 &mut self,
341 hash: TransactionHash<Types>,
342 ) -> QueryResult<BlockQueryData<Types>> {
343 let mut query = QueryBuilder::default();
344 let hash_param = query.bind(hash.to_string())?;
345
346 let sql = format!(
349 "SELECT {BLOCK_COLUMNS}
350 FROM header AS h
351 JOIN payload AS p ON (h.payload_hash, h.ns_table) = (p.hash, p.ns_table)
352 JOIN transactions AS t ON t.block_height = h.height
353 WHERE t.hash = {hash_param}
354 ORDER BY t.block_height, t.ns_id, t.position
355 LIMIT 1"
356 );
357 let row = query.query(&sql).fetch_one(self.as_mut()).await?;
358 Ok(BlockQueryData::from_row(&row)?)
359 }
360}
361
362impl<Mode> Transaction<Mode>
363where
364 Mode: TransactionMode,
365{
366 async fn load_namespaces<Types>(
367 &mut self,
368 height: u64,
369 payload_size: u64,
370 ) -> QueryResult<NamespaceMap<Types>>
371 where
372 Types: NodeType,
373 Header<Types>: QueryableHeader<Types>,
374 Payload<Types>: QueryablePayload<Types>,
375 {
376 let header = self
377 .get_header(BlockId::<Types>::from(height as usize))
378 .await?;
379 let map = query(
380 "SELECT ns_id, ns_index, max(position) + 1 AS count
381 FROM transactions
382 WHERE block_height = $1
383 GROUP BY ns_id, ns_index",
384 )
385 .bind(height as i64)
386 .fetch(self.as_mut())
387 .map_ok(|row| {
388 let ns = row.get::<i64, _>("ns_index").into();
389 let id = row.get::<i64, _>("ns_id").into();
390 let num_transactions = row.get::<i64, _>("count") as u64;
391 let size = header.namespace_size(&ns, payload_size as usize);
392 (
393 id,
394 NamespaceInfo {
395 num_transactions,
396 size,
397 },
398 )
399 })
400 .try_collect()
401 .await?;
402 Ok(map)
403 }
404}
405
#[cfg(test)]
mod test {
    use hotshot_example_types::node_types::TEST_VERSIONS;
    use hotshot_types::{data::VidCommon, vid::advz::advz_scheme};
    use jf_advz::VidScheme;
    use pretty_assertions::assert_eq;

    use super::*;
    use crate::{
        data_source::{
            Transaction, VersionedDataSource,
            sql::testing::TmpDb,
            storage::{SqlStorage, StorageConnectionType, UpdateAvailabilityStorage},
        },
        testing::mocks::MockTypes,
    };

    /// Leaves, blocks and VID common data for heights 0 and 1; each vector is indexed by height.
    struct Fixture {
        leaves: Vec<LeafQueryData<MockTypes>>,
        blocks: Vec<BlockQueryData<MockTypes>>,
        vid: Vec<VidCommonQueryData<MockTypes>>,
    }

    /// Open a query-mode connection to the temporary database.
    async fn connect(tmp_db: &TmpDb) -> SqlStorage {
        SqlStorage::connect(tmp_db.config(), StorageConnectionType::Query)
            .await
            .unwrap()
    }

    /// Build the genesis leaf, block and VID common data, plus a height-1 copy.
    ///
    /// The height-1 header is the genesis header with `block_number` incremented. When
    /// `bump_num_transactions` is set, its metadata `num_transactions` is incremented as well.
    /// Height 1 reuses the genesis payload and the same empty-payload VID dispersal.
    async fn two_heights_sharing_payload(bump_num_transactions: bool) -> Fixture {
        let mut scheme = advz_scheme(2);

        let genesis_leaf = LeafQueryData::<MockTypes>::genesis(
            &Default::default(),
            &Default::default(),
            TEST_VERSIONS.test,
        )
        .await;
        let genesis_block = BlockQueryData::<MockTypes>::genesis(
            &Default::default(),
            &Default::default(),
            TEST_VERSIONS.test.base,
        )
        .await;
        let dispersal = scheme.disperse([]).unwrap();
        let genesis_vid = VidCommonQueryData::<MockTypes>::new(
            genesis_leaf.header().clone(),
            VidCommon::V0(dispersal.common.clone()),
        );

        let mut next_leaf = genesis_leaf.clone();
        let next_header = next_leaf.leaf.block_header_mut();
        next_header.block_number += 1;
        if bump_num_transactions {
            next_header.metadata.num_transactions += 1;
        }
        let next_block =
            BlockQueryData::new(next_leaf.header().clone(), genesis_block.payload().clone());
        let next_vid =
            VidCommonQueryData::new(next_leaf.header().clone(), VidCommon::V0(dispersal.common));

        Fixture {
            leaves: vec![genesis_leaf, next_leaf],
            blocks: vec![genesis_block, next_block],
            vid: vec![genesis_vid, next_vid],
        }
    }

    /// Store only the genesis leaf, then check that it reads back while the block and the VID
    /// common data at height 0 are reported absent.
    async fn store_genesis_leaf_only(db: &SqlStorage, fixture: &Fixture) {
        {
            let mut writer = db.write().await.unwrap();
            writer.insert_leaf(&fixture.leaves[0]).await.unwrap();
            writer.commit().await.unwrap();
        }

        {
            let mut reader = db.read().await.unwrap();
            assert_eq!(
                reader.get_leaf(LeafId::Number(0)).await.unwrap(),
                fixture.leaves[0]
            );
            assert_absent(
                reader
                    .get_block(BlockId::<MockTypes>::Number(0))
                    .await
                    .unwrap_err(),
            );
            assert_absent(
                reader
                    .get_vid_common(BlockId::<MockTypes>::Number(0))
                    .await
                    .unwrap_err(),
            );
        }
    }

    /// Store the height-1 leaf, block and VID common data in one write transaction.
    async fn store_second_height(db: &SqlStorage, fixture: &Fixture) {
        let mut writer = db.write().await.unwrap();
        writer.insert_leaf(&fixture.leaves[1]).await.unwrap();
        writer.insert_block(&fixture.blocks[1]).await.unwrap();
        writer.insert_vid(&fixture.vid[1], None).await.unwrap();
        writer.commit().await.unwrap();
    }

    #[tokio::test]
    #[test_log::test]
    async fn test_duplicate_payload() {
        let tmp_db = TmpDb::init().await;
        let db = connect(&tmp_db).await;
        let fixture = two_heights_sharing_payload(false).await;

        store_genesis_leaf_only(&db, &fixture).await;
        store_second_height(&db, &fixture).await;

        // Once height 1 is stored, both heights read back their leaf, block and VID common data.
        for height in 0..2 {
            let mut reader = db.read().await.unwrap();
            assert_eq!(
                reader.get_leaf(LeafId::Number(height)).await.unwrap(),
                fixture.leaves[height]
            );
            assert_eq!(
                reader.get_block(BlockId::Number(height)).await.unwrap(),
                fixture.blocks[height]
            );
            assert_eq!(
                reader.get_vid_common(BlockId::Number(height)).await.unwrap(),
                fixture.vid[height]
            );
        }
    }

    #[tokio::test]
    #[test_log::test]
    async fn test_same_payload_different_ns_table() {
        let tmp_db = TmpDb::init().await;
        let db = connect(&tmp_db).await;
        let fixture = two_heights_sharing_payload(true).await;

        store_genesis_leaf_only(&db, &fixture).await;
        store_second_height(&db, &fixture).await;

        let mut reader = db.read().await.unwrap();
        for height in 0..2 {
            assert_eq!(
                reader.get_leaf(LeafId::Number(height)).await.unwrap(),
                fixture.leaves[height]
            );
            assert_eq!(
                reader.get_vid_common(BlockId::Number(height)).await.unwrap(),
                fixture.vid[height]
            );
        }

        // The height-1 header was modified (per the test name, so that its namespace table
        // differs). Its stored payload must therefore not be served for height 0.
        assert_absent(
            reader
                .get_block(BlockId::<MockTypes>::Number(0))
                .await
                .unwrap_err(),
        );
        assert_eq!(
            reader.get_block(BlockId::Number(1)).await.unwrap(),
            fixture.blocks[1]
        );
    }

    /// Assert that `err` is one of the "not available" query errors.
    fn assert_absent(err: QueryError) {
        assert!(
            matches!(err, QueryError::Missing | QueryError::NotFound),
            "{err:#}"
        );
    }
}