hotshot_query_service/fetching/provider/
any.rs1use std::{fmt::Debug, sync::Arc};
14
15use async_trait::async_trait;
16use derivative::Derivative;
17use hotshot_types::{data::VidCommon, traits::node_implementation::NodeType};
18
19use super::{Provider, Request};
20use crate::{
21 Payload,
22 availability::{BlockQueryData, Certificate2, LeafQueryData, VidCommonQueryData},
23 data_source::AvailabilityProvider,
24 fetching::{
25 NonEmptyRange,
26 request::{
27 BlockRangeRequest, Certificate2Request, LeafRangeRequest, LeafRequest, PayloadRequest,
28 VidCommonRangeRequest, VidCommonRequest,
29 },
30 },
31};
32
33trait DebugProvider<Types, T>: Provider<Types, T> + Debug
39where
40 Types: NodeType,
41 T: Request<Types>,
42{
43}
44
45impl<Types, T, P> DebugProvider<Types, T> for P
46where
47 Types: NodeType,
48 T: Request<Types>,
49 P: Provider<Types, T> + Debug,
50{
51}
52
// Shared, type-erased sub-provider handles: one alias per request type routed by `AnyProvider`.
// `Arc` lets a single provider instance sit in several of these lists at once.

/// Sub-provider handle for single-payload requests ([`PayloadRequest`]).
type PayloadProvider<Types> = Arc<dyn DebugProvider<Types, PayloadRequest>>;
/// Sub-provider handle for block range requests ([`BlockRangeRequest`]).
type PayloadRangeProvider<Types> = Arc<dyn DebugProvider<Types, BlockRangeRequest>>;
/// Sub-provider handle for single-leaf requests ([`LeafRequest`]).
type LeafProvider<Types> = Arc<dyn DebugProvider<Types, LeafRequest>>;
/// Sub-provider handle for leaf range requests ([`LeafRangeRequest`]).
type LeafRangeProvider<Types> = Arc<dyn DebugProvider<Types, LeafRangeRequest>>;
/// Sub-provider handle for single VID common requests ([`VidCommonRequest`]).
type VidCommonProvider<Types> = Arc<dyn DebugProvider<Types, VidCommonRequest>>;
/// Sub-provider handle for VID common range requests ([`VidCommonRangeRequest`]).
type VidCommonRangeProvider<Types> = Arc<dyn DebugProvider<Types, VidCommonRangeRequest>>;
/// Sub-provider handle for certificate requests ([`Certificate2Request`]).
type Cert2Provider<Types> = Arc<dyn DebugProvider<Types, Certificate2Request>>;
60
/// A [`Provider`] that delegates each request to a list of sub-providers.
///
/// A separate list is kept for every request type. When a request arrives, the matching list is
/// tried in insertion order and the first sub-provider that returns an object wins. If none does,
/// the request yields `None`. Sub-providers are registered with the `with_*` builder methods,
/// typically starting from the empty provider returned by [`Default::default`].
#[derive(Derivative)]
// `bound = ""` removes the implicit `Types: Clone`/`Debug`/`Default` bounds these derives would
// otherwise add; only the fields (vectors of `Arc` handles) need to satisfy them.
#[derivative(Clone(bound = ""), Debug(bound = ""), Default(bound = ""))]
pub struct AnyProvider<Types>
where
    Types: NodeType,
{
    /// Consulted, in order, for [`PayloadRequest`]s.
    payload_providers: Vec<PayloadProvider<Types>>,
    /// Consulted, in order, for [`BlockRangeRequest`]s.
    payload_range_providers: Vec<PayloadRangeProvider<Types>>,
    /// Consulted, in order, for [`LeafRequest`]s.
    leaf_providers: Vec<LeafProvider<Types>>,
    /// Consulted, in order, for [`LeafRangeRequest`]s.
    leaf_range_providers: Vec<LeafRangeProvider<Types>>,
    /// Consulted, in order, for [`VidCommonRequest`]s.
    vid_common_providers: Vec<VidCommonProvider<Types>>,
    /// Consulted, in order, for [`VidCommonRangeRequest`]s.
    vid_common_range_providers: Vec<VidCommonRangeProvider<Types>>,
    /// Consulted, in order, for [`Certificate2Request`]s.
    cert2_providers: Vec<Cert2Provider<Types>>,
}
113
114#[async_trait]
115impl<Types> Provider<Types, PayloadRequest> for AnyProvider<Types>
116where
117 Types: NodeType,
118{
119 async fn fetch(&self, req: PayloadRequest) -> Option<Payload<Types>> {
120 any_fetch(&self.payload_providers, req).await
121 }
122}
123
124#[async_trait]
125impl<Types> Provider<Types, BlockRangeRequest> for AnyProvider<Types>
126where
127 Types: NodeType,
128{
129 async fn fetch(&self, req: BlockRangeRequest) -> Option<NonEmptyRange<BlockQueryData<Types>>> {
130 any_fetch(&self.payload_range_providers, req).await
131 }
132}
133
134#[async_trait]
135impl<Types> Provider<Types, LeafRequest> for AnyProvider<Types>
136where
137 Types: NodeType,
138{
139 async fn fetch(&self, req: LeafRequest) -> Option<LeafQueryData<Types>> {
140 any_fetch(&self.leaf_providers, req).await
141 }
142}
143
144#[async_trait]
145impl<Types> Provider<Types, LeafRangeRequest> for AnyProvider<Types>
146where
147 Types: NodeType,
148{
149 async fn fetch(&self, req: LeafRangeRequest) -> Option<NonEmptyRange<LeafQueryData<Types>>> {
150 any_fetch(&self.leaf_range_providers, req).await
151 }
152}
153
154#[async_trait]
155impl<Types> Provider<Types, VidCommonRequest> for AnyProvider<Types>
156where
157 Types: NodeType,
158{
159 async fn fetch(&self, req: VidCommonRequest) -> Option<VidCommon> {
160 any_fetch(&self.vid_common_providers, req).await
161 }
162}
163
164#[async_trait]
165impl<Types> Provider<Types, VidCommonRangeRequest> for AnyProvider<Types>
166where
167 Types: NodeType,
168{
169 async fn fetch(
170 &self,
171 req: VidCommonRangeRequest,
172 ) -> Option<NonEmptyRange<VidCommonQueryData<Types>>> {
173 any_fetch(&self.vid_common_range_providers, req).await
174 }
175}
176
177#[async_trait]
178impl<Types> Provider<Types, Certificate2Request> for AnyProvider<Types>
179where
180 Types: NodeType,
181{
182 async fn fetch(&self, req: Certificate2Request) -> Option<Option<Certificate2<Types>>> {
183 any_fetch(&self.cert2_providers, req).await
184 }
185}
186
187impl<Types> AnyProvider<Types>
188where
189 Types: NodeType,
190{
191 pub fn with_provider<P>(mut self, provider: P) -> Self
193 where
194 P: AvailabilityProvider<Types> + Debug + 'static,
195 {
196 let provider = Arc::new(provider);
197 self.payload_providers.push(provider.clone());
198 self.payload_range_providers.push(provider.clone());
199 self.leaf_providers.push(provider.clone());
200 self.leaf_range_providers.push(provider.clone());
201 self.vid_common_providers.push(provider.clone());
202 self.vid_common_range_providers.push(provider.clone());
203 self.cert2_providers.push(provider);
204 self
205 }
206
207 pub fn with_block_provider<P>(mut self, provider: P) -> Self
209 where
210 P: Provider<Types, PayloadRequest> + Debug + 'static,
211 {
212 self.payload_providers.push(Arc::new(provider));
213 self
214 }
215
216 pub fn with_block_range_provider<P>(mut self, provider: P) -> Self
218 where
219 P: Provider<Types, BlockRangeRequest> + Debug + 'static,
220 {
221 self.payload_range_providers.push(Arc::new(provider));
222 self
223 }
224
225 pub fn with_leaf_provider<P>(mut self, provider: P) -> Self
227 where
228 P: Provider<Types, LeafRequest> + Debug + 'static,
229 {
230 self.leaf_providers.push(Arc::new(provider));
231 self
232 }
233
234 pub fn with_leaf_range_provider<P>(mut self, provider: P) -> Self
236 where
237 P: Provider<Types, LeafRangeRequest> + Debug + 'static,
238 {
239 self.leaf_range_providers.push(Arc::new(provider));
240 self
241 }
242
243 pub fn with_vid_common_provider<P>(mut self, provider: P) -> Self
245 where
246 P: Provider<Types, VidCommonRequest> + Debug + 'static,
247 {
248 self.vid_common_providers.push(Arc::new(provider));
249 self
250 }
251
252 pub fn with_vid_common_range_provider<P>(mut self, provider: P) -> Self
254 where
255 P: Provider<Types, VidCommonRangeRequest> + Debug + 'static,
256 {
257 self.vid_common_range_providers.push(Arc::new(provider));
258 self
259 }
260}
261
262async fn any_fetch<Types, P, T>(providers: &[Arc<P>], req: T) -> Option<T::Response>
263where
264 Types: NodeType,
265 P: Provider<Types, T> + Debug + ?Sized,
266 T: Request<Types>,
267{
268 for (i, p) in providers.iter().enumerate() {
275 match p.fetch(req).await {
276 Some(obj) => return Some(obj),
277 None => {
278 tracing::debug!(
279 "failed to fetch {req:?} from provider {i}/{}: {p:?}",
280 providers.len()
281 );
282 continue;
283 },
284 }
285 }
286
287 tracing::warn!(
288 "failed to fetch {req:?} from all {} providers",
289 providers.len()
290 );
291
292 None
293}
294
#[cfg(all(test, not(target_os = "windows")))]
mod test {
    use futures::stream::StreamExt;
    use test_utils::reserve_tcp_port;
    use tide_disco::App;
    use vbs::version::StaticVersionType;

    use super::*;
    use crate::{
        ApiState, Error,
        availability::{AvailabilityDataSource, UpdateAvailabilityData, define_api},
        data_source::storage::sql::testing::TmpDb,
        fetching::provider::{NoFetching, TrustedQueryServiceProvider},
        task::BackgroundTask,
        testing::{
            consensus::{MockDataSource, MockNetwork},
            mocks::{MockBase, MockTypes},
        },
        types::HeightIndexed,
    };

    type Provider = AnyProvider<MockTypes>;

    /// Fetching through [`AnyProvider`] must succeed even when the first registered provider
    /// ([`NoFetching`]) cannot serve anything. The request must fall through to the second
    /// provider, a [`TrustedQueryServiceProvider`] pointed at a peer serving the availability
    /// API.
    #[test_log::test(tokio::test(flavor = "multi_thread"))]
    async fn test_fetch_first_provider_fails() {
        // Mock consensus network; its data source backs the peer web server below.
        let mut network = MockNetwork::<MockDataSource>::init().await;

        // Serve the availability API from the network's data source, so the node under test has
        // a peer to fetch from.
        let port = reserve_tcp_port().unwrap();
        let mut app = App::<_, Error>::with_state(ApiState::from(network.data_source()));
        app.register_module(
            "availability",
            define_api(
                &Default::default(),
                MockBase::instance(),
                "1.0.0".parse().unwrap(),
            )
            .unwrap(),
        )
        .unwrap();
        let bind_addr = format!("0.0.0.0:{port}");
        let _server = BackgroundTask::spawn("server", app.serve(bind_addr, MockBase::instance()));

        // Node under test: a SQL-backed data source whose provider chain starts with a provider
        // that never returns anything, followed by one that queries the peer above.
        let db = TmpDb::init().await;
        let peer_url = format!("http://localhost:{port}");
        let fallback =
            TrustedQueryServiceProvider::new(peer_url.parse().unwrap(), MockBase::instance());
        let chained = Provider::default()
            .with_provider(NoFetching)
            .with_provider(fallback);
        let data_source = db.config().connect(chained.clone()).await.unwrap();

        // Start consensus.
        network.start().await;

        // Collect the first three leaves, starting at height 1.
        let leaf_stream = network.data_source().subscribe_leaves(1).await;
        let leaves: Vec<_> = leaf_stream.take(3).collect().await;
        let (test_leaf, test_payload) = (&leaves[0], &leaves[1]);

        // Hand the node the newest leaf only. The earlier objects must then be fetched through
        // the provider chain (presumably this also tells the node the current block height).
        let newest = leaves.last().cloned().unwrap();
        data_source.append(newest.into()).await.unwrap();

        tracing::info!("requesting leaf from multiple providers");
        let leaf_height = test_leaf.height() as usize;
        let fetched_leaf = data_source.get_leaf(leaf_height).await.await;
        assert_eq!(&fetched_leaf, test_leaf);

        tracing::info!("requesting payload from multiple providers");
        let payload_height = test_payload.height() as usize;
        let fetched_payload = data_source.get_payload(payload_height).await.await;
        assert_eq!(fetched_payload.height(), test_payload.height());
        assert_eq!(fetched_payload.block_hash(), test_payload.block_hash());
        assert_eq!(fetched_payload.hash(), test_payload.payload_hash());
    }
}