hotshot_query_service/fetching/provider/
any.rs1use std::{fmt::Debug, sync::Arc};
14
15use async_trait::async_trait;
16use derivative::Derivative;
17use hotshot_types::{data::VidCommon, traits::node_implementation::NodeType};
18
19use super::{Provider, Request};
20use crate::{
21 Payload,
22 availability::{BlockQueryData, LeafQueryData, VidCommonQueryData},
23 data_source::AvailabilityProvider,
24 fetching::{
25 NonEmptyRange,
26 request::{
27 BlockRangeRequest, LeafRangeRequest, LeafRequest, PayloadRequest,
28 VidCommonRangeRequest, VidCommonRequest,
29 },
30 },
31};
32
33trait DebugProvider<Types, T>: Provider<Types, T> + Debug
39where
40 Types: NodeType,
41 T: Request<Types>,
42{
43}
44
45impl<Types, T, P> DebugProvider<Types, T> for P
46where
47 Types: NodeType,
48 T: Request<Types>,
49 P: Provider<Types, T> + Debug,
50{
51}
52
53type PayloadProvider<Types> = Arc<dyn DebugProvider<Types, PayloadRequest>>;
54type PayloadRangeProvider<Types> = Arc<dyn DebugProvider<Types, BlockRangeRequest>>;
55type LeafProvider<Types> = Arc<dyn DebugProvider<Types, LeafRequest<Types>>>;
56type LeafRangeProvider<Types> = Arc<dyn DebugProvider<Types, LeafRangeRequest<Types>>>;
57type VidCommonProvider<Types> = Arc<dyn DebugProvider<Types, VidCommonRequest>>;
58type VidCommonRangeProvider<Types> = Arc<dyn DebugProvider<Types, VidCommonRangeRequest>>;
59
60#[derive(Derivative)]
99#[derivative(Clone(bound = ""), Debug(bound = ""), Default(bound = ""))]
100pub struct AnyProvider<Types>
101where
102 Types: NodeType,
103{
104 payload_providers: Vec<PayloadProvider<Types>>,
105 payload_range_providers: Vec<PayloadRangeProvider<Types>>,
106 leaf_providers: Vec<LeafProvider<Types>>,
107 leaf_range_providers: Vec<LeafRangeProvider<Types>>,
108 vid_common_providers: Vec<VidCommonProvider<Types>>,
109 vid_common_range_providers: Vec<VidCommonRangeProvider<Types>>,
110}
111
112#[async_trait]
113impl<Types> Provider<Types, PayloadRequest> for AnyProvider<Types>
114where
115 Types: NodeType,
116{
117 async fn fetch(&self, req: PayloadRequest) -> Option<Payload<Types>> {
118 any_fetch(&self.payload_providers, req).await
119 }
120}
121
122#[async_trait]
123impl<Types> Provider<Types, BlockRangeRequest> for AnyProvider<Types>
124where
125 Types: NodeType,
126{
127 async fn fetch(&self, req: BlockRangeRequest) -> Option<NonEmptyRange<BlockQueryData<Types>>> {
128 any_fetch(&self.payload_range_providers, req).await
129 }
130}
131
132#[async_trait]
133impl<Types> Provider<Types, LeafRequest<Types>> for AnyProvider<Types>
134where
135 Types: NodeType,
136{
137 async fn fetch(&self, req: LeafRequest<Types>) -> Option<LeafQueryData<Types>> {
138 any_fetch(&self.leaf_providers, req).await
139 }
140}
141
142#[async_trait]
143impl<Types> Provider<Types, LeafRangeRequest<Types>> for AnyProvider<Types>
144where
145 Types: NodeType,
146{
147 async fn fetch(
148 &self,
149 req: LeafRangeRequest<Types>,
150 ) -> Option<NonEmptyRange<LeafQueryData<Types>>> {
151 any_fetch(&self.leaf_range_providers, req).await
152 }
153}
154
155#[async_trait]
156impl<Types> Provider<Types, VidCommonRequest> for AnyProvider<Types>
157where
158 Types: NodeType,
159{
160 async fn fetch(&self, req: VidCommonRequest) -> Option<VidCommon> {
161 any_fetch(&self.vid_common_providers, req).await
162 }
163}
164
165#[async_trait]
166impl<Types> Provider<Types, VidCommonRangeRequest> for AnyProvider<Types>
167where
168 Types: NodeType,
169{
170 async fn fetch(
171 &self,
172 req: VidCommonRangeRequest,
173 ) -> Option<NonEmptyRange<VidCommonQueryData<Types>>> {
174 any_fetch(&self.vid_common_range_providers, req).await
175 }
176}
177
178impl<Types> AnyProvider<Types>
179where
180 Types: NodeType,
181{
182 pub fn with_provider<P>(mut self, provider: P) -> Self
184 where
185 P: AvailabilityProvider<Types> + Debug + 'static,
186 {
187 let provider = Arc::new(provider);
188 self.payload_providers.push(provider.clone());
189 self.payload_range_providers.push(provider.clone());
190 self.leaf_providers.push(provider.clone());
191 self.leaf_range_providers.push(provider.clone());
192 self.vid_common_providers.push(provider.clone());
193 self.vid_common_range_providers.push(provider);
194 self
195 }
196
197 pub fn with_block_provider<P>(mut self, provider: P) -> Self
199 where
200 P: Provider<Types, PayloadRequest> + Debug + 'static,
201 {
202 self.payload_providers.push(Arc::new(provider));
203 self
204 }
205
206 pub fn with_block_range_provider<P>(mut self, provider: P) -> Self
208 where
209 P: Provider<Types, BlockRangeRequest> + Debug + 'static,
210 {
211 self.payload_range_providers.push(Arc::new(provider));
212 self
213 }
214
215 pub fn with_leaf_provider<P>(mut self, provider: P) -> Self
217 where
218 P: Provider<Types, LeafRequest<Types>> + Debug + 'static,
219 {
220 self.leaf_providers.push(Arc::new(provider));
221 self
222 }
223
224 pub fn with_leaf_range_provider<P>(mut self, provider: P) -> Self
226 where
227 P: Provider<Types, LeafRangeRequest<Types>> + Debug + 'static,
228 {
229 self.leaf_range_providers.push(Arc::new(provider));
230 self
231 }
232
233 pub fn with_vid_common_provider<P>(mut self, provider: P) -> Self
235 where
236 P: Provider<Types, VidCommonRequest> + Debug + 'static,
237 {
238 self.vid_common_providers.push(Arc::new(provider));
239 self
240 }
241
242 pub fn with_vid_common_range_provider<P>(mut self, provider: P) -> Self
244 where
245 P: Provider<Types, VidCommonRangeRequest> + Debug + 'static,
246 {
247 self.vid_common_range_providers.push(Arc::new(provider));
248 self
249 }
250}
251
252async fn any_fetch<Types, P, T>(providers: &[Arc<P>], req: T) -> Option<T::Response>
253where
254 Types: NodeType,
255 P: Provider<Types, T> + Debug + ?Sized,
256 T: Request<Types>,
257{
258 for (i, p) in providers.iter().enumerate() {
265 match p.fetch(req).await {
266 Some(obj) => return Some(obj),
267 None => {
268 tracing::debug!(
269 "failed to fetch {req:?} from provider {i}/{}: {p:?}",
270 providers.len()
271 );
272 continue;
273 },
274 }
275 }
276
277 tracing::warn!(
278 "failed to fetch {req:?} from all {} providers",
279 providers.len()
280 );
281
282 None
283}
284
// NOTE(review): excluded on Windows, presumably because the SQL test database (`TmpDb`) is not
// available there. Confirm.
#[cfg(all(test, not(target_os = "windows")))]
mod test {
    use futures::stream::StreamExt;
    use test_utils::reserve_tcp_port;
    use tide_disco::App;
    use vbs::version::StaticVersionType;

    use super::*;
    use crate::{
        ApiState, Error,
        availability::{AvailabilityDataSource, UpdateAvailabilityData, define_api},
        data_source::storage::sql::testing::TmpDb,
        fetching::provider::{NoFetching, QueryServiceProvider},
        task::BackgroundTask,
        testing::{
            consensus::{MockDataSource, MockNetwork},
            mocks::{MockBase, MockTypes},
        },
        types::HeightIndexed,
    };

    /// `AnyProvider` specialised to the mock node types used by these tests.
    type TestAnyProvider = AnyProvider<MockTypes>;

    /// `AnyProvider` must fall through to a later provider when an earlier one cannot serve the
    /// request. `NoFetching` is registered first, so the leaf and payload requested below have to
    /// be served by the `QueryServiceProvider` registered after it.
    #[test_log::test(tokio::test(flavor = "multi_thread"))]
    async fn test_fetch_first_provider_fails() {
        let mut network = MockNetwork::<MockDataSource>::init().await;

        // Expose the mock network's availability API over HTTP, so the query-service provider
        // has a server to fetch from.
        let port = reserve_tcp_port().unwrap();
        let mut app = App::<_, Error>::with_state(ApiState::from(network.data_source()));
        let availability_api = define_api(
            &Default::default(),
            MockBase::instance(),
            "1.0.0".parse().unwrap(),
        )
        .unwrap();
        app.register_module("availability", availability_api)
            .unwrap();
        let _server = BackgroundTask::spawn(
            "server",
            app.serve(format!("0.0.0.0:{port}"), MockBase::instance()),
        );

        // A fresh SQL data source whose only way to obtain missing objects is `AnyProvider`.
        let db = TmpDb::init().await;
        let any_provider = TestAnyProvider::default()
            .with_provider(NoFetching)
            .with_provider(QueryServiceProvider::new(
                format!("http://localhost:{port}").parse().unwrap(),
                MockBase::instance(),
            ));
        let data_source = db.config().connect(any_provider).await.unwrap();

        network.start().await;

        // Collect the first three leaves starting at height 1. The first two are the fetch
        // targets; the third is handed to the data source directly.
        let leaf_stream = network.data_source().subscribe_leaves(1).await;
        let leaves: Vec<_> = leaf_stream.take(3).collect().await;
        let target_leaf = &leaves[0];
        let target_block = &leaves[1];

        // Give the data source only the newest leaf. The earlier leaf and payload it is asked
        // for next are not in its database and must come through the providers.
        data_source
            .append(leaves.last().cloned().unwrap().into())
            .await
            .unwrap();

        tracing::info!("requesting leaf from multiple providers");
        let fetched_leaf = data_source
            .get_leaf(target_leaf.height() as usize)
            .await
            .await;
        assert_eq!(fetched_leaf, *target_leaf);

        tracing::info!("requesting payload from multiple providers");
        let fetched_payload = data_source
            .get_payload(target_block.height() as usize)
            .await
            .await;
        assert_eq!(fetched_payload.height(), target_block.height());
        assert_eq!(fetched_payload.block_hash(), target_block.block_hash());
        assert_eq!(fetched_payload.hash(), target_block.payload_hash());
    }
}