Skip to main content

espresso_api/
axum.rs

1//! Axum HTTP/JSON API handlers
2
3pub mod routes;
4
5use aide::{
6    axum::{ApiRouter, routing::get_with},
7    openapi::{Info, OpenApi},
8    operation::OperationOutput,
9    redoc::Redoc,
10    scalar::Scalar,
11};
12use axum::{
13    Extension, Json, Router,
14    extract::{Path, Request, State, ws::WebSocketUpgrade},
15    http::{StatusCode, Uri},
16    middleware::{self, Next},
17    response::{Html, IntoResponse, Response},
18    routing::get,
19};
20use futures::stream::BoxStream;
21use schemars::transform::Transform;
22use serde::Serialize;
23use serialization_api::v2::{
24    GetIncorrectEncodingProofRequest, GetNamespaceProofRequest, GetRewardAccountProofRequest,
25    GetRewardBalanceRequest, GetRewardBalancesRequest, GetRewardClaimInputRequest,
26    GetRewardMerkleTreeRequest, GetStakeTableRequest, GetStateCertificateRequest,
27};
28
29use crate::{
30    error::{ApiError, AvailabilityError},
31    handlers, v1, v2,
32};
33
/// API error response
///
/// JSON body sent to the client when a handler fails. It is built by the `IntoResponse`
/// implementation for [`ApiError`] and paired there with the HTTP status code.
#[derive(Debug, Serialize, schemars::JsonSchema)]
struct ErrorResponse {
    /// Error message: the `to_string()` rendering of the originating [`ApiError`].
    error: String,
}
39
40impl IntoResponse for ApiError {
41    fn into_response(self) -> Response {
42        let status = match &self {
43            ApiError::BadRequest(_) => StatusCode::BAD_REQUEST,
44            ApiError::NotFound(_) => StatusCode::NOT_FOUND,
45            ApiError::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
46        };
47
48        let body = Json(ErrorResponse {
49            error: self.to_string(),
50        });
51
52        (status, body).into_response()
53    }
54}
55
56/// Classify an `anyhow::Error` from an availability handler into the appropriate `ApiError`
57/// variant. Errors produced via [`AvailabilityError`] in the state implementation carry semantic
58/// meaning; everything else falls back to a 500 Internal Server Error.
59fn classify_availability_error(err: anyhow::Error) -> ApiError {
60    let is_not_found = err
61        .downcast_ref::<AvailabilityError>()
62        .map(|e| matches!(e, AvailabilityError::NotFound(_)));
63    match is_not_found {
64        Some(true) => ApiError::NotFound(err),
65        Some(false) => ApiError::BadRequest(err),
66        None => ApiError::Internal(err),
67    }
68}
69
/// Implements aide's `OperationOutput` for [`ApiError`].
///
/// Only the associated `Inner` type is set here; every other item of the trait keeps its
/// default.
impl OperationOutput for ApiError {
    type Inner = Self;
}
73
74/// Serve the OpenAPI spec (extracted from Extension)
75async fn serve_openapi_spec(Extension(api): Extension<OpenApi>) -> Json<OpenApi> {
76    Json(api)
77}
78
79/// Serve custom Swagger UI with collapsed defaults
80async fn serve_swagger_ui() -> Html<&'static str> {
81    Html(include_str!("../templates/swagger.html"))
82}
83
84/// Middleware to rewrite root paths to /v2 paths
85///
86/// Requests to `/rewards/...` get rewritten to `/v2/rewards/...`
87/// Paths already prefixed with `/v2` are left unchanged
88///
89/// Note: This middleware is only applied to the v2 router, so v1 routes never pass through it
90async fn rewrite_root_to_v2(mut req: Request, next: Next) -> Response {
91    let uri = req.uri().clone();
92    let path = uri.path();
93
94    // Only rewrite unversioned paths (not starting with /v2)
95    if !path.starts_with("/v2") && path != "/" {
96        let new_path = format!("/v2{}", path);
97        let pq = if let Some(q) = uri.query() {
98            format!("{}?{}", new_path, q)
99        } else {
100            new_path
101        };
102        if let Ok(new_uri) = Uri::builder().path_and_query(pq).build() {
103            *req.uri_mut() = new_uri;
104        }
105    }
106
107    next.run(req).await
108}
109
110/// Redirect handler for root path
111async fn redirect_to_docs() -> axum::response::Redirect {
112    axum::response::Redirect::permanent("/v2")
113}
114
/// Query-string extractor wrapping a deserialized `T`.
///
/// Its `FromRequestParts` implementation delegates to [`axum::extract::Query`], and its
/// `OperationInput` implementation documents `T`'s schema as query parameters.
///
/// NOTE(review): the name suggests this wrapper exists to satisfy a `Send` requirement that the
/// plain `Query` extractor did not meet in this setup — confirm and record the reason here.
struct SendQuery<T>(T);
116
117impl<T, S> axum::extract::FromRequestParts<S> for SendQuery<T>
118where
119    T: serde::de::DeserializeOwned + Send,
120    S: Send + Sync,
121{
122    type Rejection = axum::extract::rejection::QueryRejection;
123
124    async fn from_request_parts(
125        parts: &mut axum::http::request::Parts,
126        state: &S,
127    ) -> Result<Self, Self::Rejection> {
128        axum::extract::Query::<T>::from_request_parts(parts, state)
129            .await
130            .map(|axum::extract::Query(inner)| SendQuery(inner))
131    }
132}
133
134impl<T: schemars::JsonSchema> aide::operation::OperationInput for SendQuery<T> {
135    fn operation_input(
136        ctx: &mut aide::generate::GenContext,
137        operation: &mut aide::openapi::Operation,
138    ) {
139        let schema = ctx.schema.subschema_for::<T>();
140        let params = aide::operation::parameters_from_schema(
141            ctx,
142            schema,
143            aide::operation::ParamLocation::Query,
144        );
145        aide::operation::add_parameters(ctx, operation, params);
146    }
147}
148
149async fn drive_ws_stream<T: Serialize>(
150    mut socket: axum::extract::ws::WebSocket,
151    stream: BoxStream<'static, T>,
152) {
153    use futures::StreamExt as _;
154    futures::pin_mut!(stream);
155    while let Some(item) = stream.next().await {
156        let Ok(json) = serde_json::to_string(&item) else {
157            break;
158        };
159        if socket
160            .send(axum::extract::ws::Message::Text(json.into()))
161            .await
162            .is_err()
163        {
164            break;
165        }
166    }
167}
168
169/// Create a combined router serving both v1 and v2 APIs
170pub fn create_combined_router<S>(state: S) -> Router
171where
172    S: v1::RewardApi
173        + v1::AvailabilityApi
174        + v1::HotShotAvailabilityApi
175        + v2::RewardApi
176        + v2::DataApi
177        + v2::ConsensusApi
178        + Clone
179        + Send
180        + Sync
181        + 'static,
182{
183    let router_v1 = create_router_v1(state.clone());
184    let router_v2 = create_router_v2(state).layer(middleware::from_fn(rewrite_root_to_v2));
185
186    router_v2.merge(router_v1).route("/", get(redirect_to_docs))
187}
188
/// Create v1 router without OpenAPI documentation (internal types)
///
/// Every route is a `GET` handler written as a closure over the generic state type `S`; the
/// route paths come from the constants in `routes::v1`. Successful results are returned as JSON.
///
/// Error mapping:
/// * The reward handlers, `get_limits` and `get_cert2` turn every error into
///   [`ApiError::Internal`].
/// * All other request/response handlers pass the error through
///   [`classify_availability_error`].
/// * The WebSocket handlers upgrade the connection first; if opening the stream then fails, the
///   error is only logged with `tracing::warn!` and the socket is dropped without sending
///   anything.
///
/// The returned router already has `state` applied via `with_state`.
pub fn create_router_v1<S>(state: S) -> Router
where
    S: v1::RewardApi
        + v1::AvailabilityApi
        + v1::HotShotAvailabilityApi
        + Clone
        + Send
        + Sync
        + 'static,
{
    // Create handler closures that capture the generic state type
    //
    // Reward API handlers: every failure is reported as an internal error.
    let get_reward_claim_input =
        |State(state): State<S>, Path((height, address)): Path<(u64, String)>| async move {
            state
                .get_reward_claim_input(height, address)
                .await
                .map(Json)
                .map_err(ApiError::Internal)
        };

    let get_reward_balance =
        |State(state): State<S>, Path((height, address)): Path<(u64, String)>| async move {
            state
                .get_reward_balance(height, address)
                .await
                .map(Json)
                .map_err(ApiError::Internal)
        };

    let get_latest_reward_balance = |State(state): State<S>, Path(address): Path<String>| async move {
        state
            .get_latest_reward_balance(address)
            .await
            .map(Json)
            .map_err(ApiError::Internal)
    };

    let get_reward_account_proof =
        |State(state): State<S>, Path((height, address)): Path<(u64, String)>| async move {
            state
                .get_reward_account_proof(height, address)
                .await
                .map(Json)
                .map_err(ApiError::Internal)
        };

    let get_latest_reward_account_proof = |State(state): State<S>, Path(address): Path<String>| async move {
        state
            .get_latest_reward_account_proof(address)
            .await
            .map(Json)
            .map_err(ApiError::Internal)
    };

    let get_reward_amounts =
        |State(state): State<S>, Path((height, offset, limit)): Path<(u64, u64, u64)>| async move {
            state
                .get_reward_amounts(height, offset, limit)
                .await
                .map(Json)
                .map_err(ApiError::Internal)
        };

    let get_reward_merkle_tree_v2 = |State(state): State<S>, Path(height): Path<u64>| async move {
        state
            .get_reward_merkle_tree_v2(height)
            .await
            .map(Json)
            .map_err(ApiError::Internal)
    };

    // Availability API handlers
    //
    // From here on, errors are mapped with `classify_availability_error` (404 / 400 / 500).
    // Route: /v1/availability/block/{height}/namespace/{namespace}
    let get_namespace_proof_by_height =
        |State(state): State<S>, Path((height, namespace)): Path<(u64, u32)>| async move {
            state
                .get_namespace_proof(v1::availability::BlockId::Height(height), namespace)
                .await
                .map(Json)
                .map_err(classify_availability_error)
        };

    // Route: /v1/availability/block/hash/{hash}/namespace/{namespace}
    let get_namespace_proof_by_hash =
        |State(state): State<S>, Path((hash, namespace)): Path<(String, u32)>| async move {
            state
                .get_namespace_proof(v1::availability::BlockId::Hash(hash), namespace)
                .await
                .map(Json)
                .map_err(classify_availability_error)
        };

    // Route: /v1/availability/block/payload-hash/{payload-hash}/namespace/{namespace}
    let get_namespace_proof_by_payload_hash =
        |State(state): State<S>, Path((payload_hash, namespace)): Path<(String, u32)>| async move {
            state
                .get_namespace_proof(
                    v1::availability::BlockId::PayloadHash(payload_hash),
                    namespace,
                )
                .await
                .map(Json)
                .map_err(classify_availability_error)
        };

    // Route: /v1/availability/block/{from}/{until}/namespace/{namespace}
    let get_namespace_proof_range =
        |State(state): State<S>, Path((from, until, namespace)): Path<(u64, u64, u32)>| async move {
            state
                .get_namespace_proof_range(from, until, namespace)
                .await
                .map(Json)
                .map_err(classify_availability_error)
        };

    // The block is always addressed by height for this handler.
    let get_incorrect_encoding_proof =
        |State(state): State<S>, Path((block_number, namespace)): Path<(u64, u32)>| async move {
            state
                .get_incorrect_encoding_proof(
                    v1::availability::BlockId::Height(block_number),
                    namespace,
                )
                .await
                .map(Json)
                .map_err(classify_availability_error)
        };

    let get_state_cert_v1 = |State(state): State<S>, Path(epoch): Path<u64>| async move {
        state
            .get_state_cert(epoch)
            .await
            .map(Json)
            .map_err(classify_availability_error)
    };

    let get_state_cert_v2 = |State(state): State<S>, Path(epoch): Path<u64>| async move {
        state
            .get_state_cert_v2(epoch)
            .await
            .map(Json)
            .map_err(classify_availability_error)
    };

    // HotShot availability API handlers
    //
    // Each resource (leaf, header, block, payload, VID common) has lookups by the identifiers
    // shown below plus a `(from, until)` range variant.

    let get_leaf_by_height = |State(state): State<S>, Path(height): Path<u64>| async move {
        state
            .get_leaf(v1::LeafId::Height(height))
            .await
            .map(Json)
            .map_err(classify_availability_error)
    };
    let get_leaf_by_hash = |State(state): State<S>, Path(hash): Path<String>| async move {
        state
            .get_leaf(v1::LeafId::Hash(hash))
            .await
            .map(Json)
            .map_err(classify_availability_error)
    };
    let get_leaf_range = |State(state): State<S>, Path((from, until)): Path<(usize, usize)>| async move {
        state
            .get_leaf_range(from, until)
            .await
            .map(Json)
            .map_err(classify_availability_error)
    };

    let get_header_by_height = |State(state): State<S>, Path(height): Path<u64>| async move {
        state
            .get_header(v1::BlockId::Height(height))
            .await
            .map(Json)
            .map_err(classify_availability_error)
    };
    let get_header_by_hash = |State(state): State<S>, Path(hash): Path<String>| async move {
        state
            .get_header(v1::BlockId::Hash(hash))
            .await
            .map(Json)
            .map_err(classify_availability_error)
    };
    let get_header_by_payload_hash = |State(state): State<S>, Path(payload_hash): Path<String>| async move {
        state
            .get_header(v1::BlockId::PayloadHash(payload_hash))
            .await
            .map(Json)
            .map_err(classify_availability_error)
    };
    let get_header_range = |State(state): State<S>, Path((from, until)): Path<(usize, usize)>| async move {
        state
            .get_header_range(from, until)
            .await
            .map(Json)
            .map_err(classify_availability_error)
    };

    let get_block_by_height = |State(state): State<S>, Path(height): Path<u64>| async move {
        state
            .get_block(v1::BlockId::Height(height))
            .await
            .map(Json)
            .map_err(classify_availability_error)
    };
    let get_block_by_hash = |State(state): State<S>, Path(hash): Path<String>| async move {
        state
            .get_block(v1::BlockId::Hash(hash))
            .await
            .map(Json)
            .map_err(classify_availability_error)
    };
    let get_block_by_payload_hash = |State(state): State<S>, Path(payload_hash): Path<String>| async move {
        state
            .get_block(v1::BlockId::PayloadHash(payload_hash))
            .await
            .map(Json)
            .map_err(classify_availability_error)
    };
    let get_block_range = |State(state): State<S>, Path((from, until)): Path<(usize, usize)>| async move {
        state
            .get_block_range(from, until)
            .await
            .map(Json)
            .map_err(classify_availability_error)
    };

    let get_payload_by_height = |State(state): State<S>, Path(height): Path<u64>| async move {
        state
            .get_payload(v1::PayloadId::Height(height))
            .await
            .map(Json)
            .map_err(classify_availability_error)
    };
    let get_payload_by_hash = |State(state): State<S>, Path(hash): Path<String>| async move {
        state
            .get_payload(v1::PayloadId::Hash(hash))
            .await
            .map(Json)
            .map_err(classify_availability_error)
    };
    let get_payload_by_block_hash = |State(state): State<S>, Path(block_hash): Path<String>| async move {
        state
            .get_payload(v1::PayloadId::BlockHash(block_hash))
            .await
            .map(Json)
            .map_err(classify_availability_error)
    };
    let get_payload_range = |State(state): State<S>, Path((from, until)): Path<(usize, usize)>| async move {
        state
            .get_payload_range(from, until)
            .await
            .map(Json)
            .map_err(classify_availability_error)
    };

    let get_vid_common_by_height = |State(state): State<S>, Path(height): Path<u64>| async move {
        state
            .get_vid_common(v1::BlockId::Height(height))
            .await
            .map(Json)
            .map_err(classify_availability_error)
    };
    let get_vid_common_by_hash = |State(state): State<S>, Path(hash): Path<String>| async move {
        state
            .get_vid_common(v1::BlockId::Hash(hash))
            .await
            .map(Json)
            .map_err(classify_availability_error)
    };
    let get_vid_common_by_payload_hash =
        |State(state): State<S>, Path(payload_hash): Path<String>| async move {
            state
                .get_vid_common(v1::BlockId::PayloadHash(payload_hash))
                .await
                .map(Json)
                .map_err(classify_availability_error)
        };
    let get_vid_common_range =
        |State(state): State<S>, Path((from, until)): Path<(usize, usize)>| async move {
            state
                .get_vid_common_range(from, until)
                .await
                .map(Json)
                .map_err(classify_availability_error)
        };

    // Transaction lookups, addressed either by `(height, index)` position or by hash, each
    // with a plain variant and a proof variant.
    let get_transaction_by_position =
        |State(state): State<S>, Path((height, index)): Path<(u64, u64)>| async move {
            state
                .get_transaction_by_position(height, index)
                .await
                .map(Json)
                .map_err(classify_availability_error)
        };
    let get_transaction_by_hash = |State(state): State<S>, Path(hash): Path<String>| async move {
        state
            .get_transaction_by_hash(hash)
            .await
            .map(Json)
            .map_err(classify_availability_error)
    };
    let get_transaction_proof_by_position =
        |State(state): State<S>, Path((height, index)): Path<(u64, u64)>| async move {
            state
                .get_transaction_proof_by_position(height, index)
                .await
                .map(Json)
                .map_err(classify_availability_error)
        };
    let get_transaction_proof_by_hash = |State(state): State<S>, Path(hash): Path<String>| async move {
        state
            .get_transaction_proof_by_hash(hash)
            .await
            .map(Json)
            .map_err(classify_availability_error)
    };

    let get_block_summary_by_height = |State(state): State<S>, Path(height): Path<usize>| async move {
        state
            .get_block_summary(height)
            .await
            .map(Json)
            .map_err(classify_availability_error)
    };
    let get_block_summary_range =
        |State(state): State<S>, Path((from, until)): Path<(usize, usize)>| async move {
            state
                .get_block_summary_range(from, until)
                .await
                .map(Json)
                .map_err(classify_availability_error)
        };

    // Unlike the handlers above, `get_limits` and `get_cert2` report every failure as an
    // internal error.
    let get_limits = |State(state): State<S>| async move {
        state
            .get_limits()
            .await
            .map(Json)
            .map_err(ApiError::Internal)
    };

    let get_cert2 = |State(state): State<S>, Path(height): Path<u64>| async move {
        state
            .get_cert2(height)
            .await
            .map(Json)
            .map_err(ApiError::Internal)
    };

    // WebSocket streaming handlers
    //
    // Each handler upgrades the connection, then opens the stream starting at `height` and
    // hands it to `drive_ws_stream`. The stream is opened only after the upgrade, so a failure
    // to open it cannot be turned into an HTTP error status; it is logged and the socket is
    // dropped.
    let stream_leaves =
        |ws: WebSocketUpgrade, State(state): State<S>, Path(height): Path<usize>| async move {
            ws.on_upgrade(move |socket| async move {
                match state.stream_leaves(height).await {
                    Ok(stream) => drive_ws_stream(socket, stream).await,
                    Err(e) => tracing::warn!("stream_leaves: {e}"),
                }
            })
        };
    let stream_headers =
        |ws: WebSocketUpgrade, State(state): State<S>, Path(height): Path<usize>| async move {
            ws.on_upgrade(move |socket| async move {
                match state.stream_headers(height).await {
                    Ok(stream) => drive_ws_stream(socket, stream).await,
                    Err(e) => tracing::warn!("stream_headers: {e}"),
                }
            })
        };
    let stream_blocks =
        |ws: WebSocketUpgrade, State(state): State<S>, Path(height): Path<usize>| async move {
            ws.on_upgrade(move |socket| async move {
                match state.stream_blocks(height).await {
                    Ok(stream) => drive_ws_stream(socket, stream).await,
                    Err(e) => tracing::warn!("stream_blocks: {e}"),
                }
            })
        };
    let stream_payloads =
        |ws: WebSocketUpgrade, State(state): State<S>, Path(height): Path<usize>| async move {
            ws.on_upgrade(move |socket| async move {
                match state.stream_payloads(height).await {
                    Ok(stream) => drive_ws_stream(socket, stream).await,
                    Err(e) => tracing::warn!("stream_payloads: {e}"),
                }
            })
        };
    let stream_vid_common =
        |ws: WebSocketUpgrade, State(state): State<S>, Path(height): Path<usize>| async move {
            ws.on_upgrade(move |socket| async move {
                match state.stream_vid_common(height).await {
                    Ok(stream) => drive_ws_stream(socket, stream).await,
                    Err(e) => tracing::warn!("stream_vid_common: {e}"),
                }
            })
        };
    // Transactions without a namespace filter (`None`).
    let stream_transactions =
        |ws: WebSocketUpgrade, State(state): State<S>, Path(height): Path<usize>| async move {
            ws.on_upgrade(move |socket| async move {
                match state.stream_transactions(height, None).await {
                    Ok(stream) => drive_ws_stream(socket, stream).await,
                    Err(e) => tracing::warn!("stream_transactions: {e}"),
                }
            })
        };
    // Same underlying call as above, with the namespace taken from the path.
    let stream_transactions_ns =
        |ws: WebSocketUpgrade,
         State(state): State<S>,
         Path((height, namespace)): Path<(usize, u32)>| async move {
            ws.on_upgrade(move |socket| async move {
                match state.stream_transactions(height, Some(namespace)).await {
                    Ok(stream) => drive_ws_stream(socket, stream).await,
                    Err(e) => tracing::warn!("stream_transactions_ns: {e}"),
                }
            })
        };
    let stream_namespace_proofs =
        |ws: WebSocketUpgrade,
         State(state): State<S>,
         Path((height, namespace)): Path<(usize, u32)>| async move {
            ws.on_upgrade(move |socket| async move {
                match state.stream_namespace_proofs(height, namespace).await {
                    Ok(stream) => drive_ws_stream(socket, stream).await,
                    Err(e) => tracing::warn!("stream_namespace_proofs: {e}"),
                }
            })
        };

    // Build plain Axum router without OpenAPI (for v1 - internal types)
    Router::new()
        .route(
            routes::v1::REWARD_CLAIM_INPUT_ROUTE,
            get(get_reward_claim_input),
        )
        .route(routes::v1::REWARD_BALANCE_ROUTE, get(get_reward_balance))
        .route(
            routes::v1::LATEST_REWARD_BALANCE_ROUTE,
            get(get_latest_reward_balance),
        )
        .route(
            routes::v1::REWARD_ACCOUNT_PROOF_ROUTE,
            get(get_reward_account_proof),
        )
        .route(
            routes::v1::LATEST_REWARD_ACCOUNT_PROOF_ROUTE,
            get(get_latest_reward_account_proof),
        )
        .route(routes::v1::REWARD_AMOUNTS_ROUTE, get(get_reward_amounts))
        .route(
            routes::v1::REWARD_MERKLE_TREE_V2_ROUTE,
            get(get_reward_merkle_tree_v2),
        )
        // Availability API routes
        .route(
            routes::v1::NAMESPACE_PROOF_BY_HEIGHT_ROUTE,
            get(get_namespace_proof_by_height),
        )
        .route(
            routes::v1::NAMESPACE_PROOF_BY_HASH_ROUTE,
            get(get_namespace_proof_by_hash),
        )
        .route(
            routes::v1::NAMESPACE_PROOF_BY_PAYLOAD_HASH_ROUTE,
            get(get_namespace_proof_by_payload_hash),
        )
        .route(
            routes::v1::NAMESPACE_PROOF_RANGE_ROUTE,
            get(get_namespace_proof_range),
        )
        .route(
            routes::v1::INCORRECT_ENCODING_PROOF_ROUTE,
            get(get_incorrect_encoding_proof),
        )
        .route(routes::v1::STATE_CERT_V1_ROUTE, get(get_state_cert_v1))
        .route(routes::v1::STATE_CERT_V2_ROUTE, get(get_state_cert_v2))
        // HotShot availability API routes
        .route(routes::v1::LEAF_BY_HEIGHT_ROUTE, get(get_leaf_by_height))
        .route(routes::v1::LEAF_BY_HASH_ROUTE, get(get_leaf_by_hash))
        .route(routes::v1::LEAF_RANGE_ROUTE, get(get_leaf_range))
        .route(
            routes::v1::HEADER_BY_HEIGHT_ROUTE,
            get(get_header_by_height),
        )
        .route(routes::v1::HEADER_BY_HASH_ROUTE, get(get_header_by_hash))
        .route(
            routes::v1::HEADER_BY_PAYLOAD_HASH_ROUTE,
            get(get_header_by_payload_hash),
        )
        .route(routes::v1::HEADER_RANGE_ROUTE, get(get_header_range))
        .route(routes::v1::BLOCK_BY_HEIGHT_ROUTE, get(get_block_by_height))
        .route(routes::v1::BLOCK_BY_HASH_ROUTE, get(get_block_by_hash))
        .route(
            routes::v1::BLOCK_BY_PAYLOAD_HASH_ROUTE,
            get(get_block_by_payload_hash),
        )
        .route(routes::v1::BLOCK_RANGE_ROUTE, get(get_block_range))
        .route(
            routes::v1::PAYLOAD_BY_HEIGHT_ROUTE,
            get(get_payload_by_height),
        )
        .route(
            routes::v1::PAYLOAD_BY_HASH_ROUTE,
            get(get_payload_by_hash),
        )
        .route(
            routes::v1::PAYLOAD_BY_BLOCK_HASH_ROUTE,
            get(get_payload_by_block_hash),
        )
        .route(routes::v1::PAYLOAD_RANGE_ROUTE, get(get_payload_range))
        .route(
            routes::v1::VID_COMMON_BY_HEIGHT_ROUTE,
            get(get_vid_common_by_height),
        )
        .route(
            routes::v1::VID_COMMON_BY_HASH_ROUTE,
            get(get_vid_common_by_hash),
        )
        .route(
            routes::v1::VID_COMMON_BY_PAYLOAD_HASH_ROUTE,
            get(get_vid_common_by_payload_hash),
        )
        .route(
            routes::v1::VID_COMMON_RANGE_ROUTE,
            get(get_vid_common_range),
        )
        // Transaction routes: the `*_NOPROOF` routes use the plain transaction handlers, while
        // both the `TRANSACTION_PROOF_BY_*` and the `TRANSACTION_BY_*` routes are served by the
        // proof handlers.
        // NOTE(review): presumably intentional (proof included by default) — confirm.
        .route(
            routes::v1::TRANSACTION_BY_POSITION_NOPROOF_ROUTE,
            get(get_transaction_by_position),
        )
        .route(
            routes::v1::TRANSACTION_BY_HASH_NOPROOF_ROUTE,
            get(get_transaction_by_hash),
        )
        .route(
            routes::v1::TRANSACTION_PROOF_BY_POSITION_ROUTE,
            get(get_transaction_proof_by_position),
        )
        .route(
            routes::v1::TRANSACTION_PROOF_BY_HASH_ROUTE,
            get(get_transaction_proof_by_hash),
        )
        .route(
            routes::v1::TRANSACTION_BY_POSITION_ROUTE,
            get(get_transaction_proof_by_position),
        )
        .route(
            routes::v1::TRANSACTION_BY_HASH_ROUTE,
            get(get_transaction_proof_by_hash),
        )
        .route(
            routes::v1::BLOCK_SUMMARY_BY_HEIGHT_ROUTE,
            get(get_block_summary_by_height),
        )
        .route(
            routes::v1::BLOCK_SUMMARY_RANGE_ROUTE,
            get(get_block_summary_range),
        )
        .route(routes::v1::LIMITS_ROUTE, get(get_limits))
        .route(routes::v1::CERT2_BY_HEIGHT_ROUTE, get(get_cert2))
        // WebSocket streaming routes
        .route(routes::v1::STREAM_LEAVES_ROUTE, get(stream_leaves))
        .route(routes::v1::STREAM_HEADERS_ROUTE, get(stream_headers))
        .route(routes::v1::STREAM_BLOCKS_ROUTE, get(stream_blocks))
        .route(routes::v1::STREAM_PAYLOADS_ROUTE, get(stream_payloads))
        .route(routes::v1::STREAM_VID_COMMON_ROUTE, get(stream_vid_common))
        .route(
            routes::v1::STREAM_TRANSACTIONS_ROUTE,
            get(stream_transactions),
        )
        .route(
            routes::v1::STREAM_TRANSACTIONS_NS_ROUTE,
            get(stream_transactions_ns),
        )
        .route(
            routes::v1::STREAM_NAMESPACE_PROOFS_ROUTE,
            get(stream_namespace_proofs),
        )
        .with_state(state)
}
767
768/// Create v2 router with OpenAPI documentation (proto types)
769pub fn create_router_v2<S>(state: S) -> Router
770where
771    S: v2::RewardApi + v2::DataApi + v2::ConsensusApi + Clone + Send + Sync + 'static,
772{
773    let mut api = OpenApi {
774        info: Info {
775            title: "Espresso Node API v2".to_string(),
776            description: None,
777            version: "1.0.0".to_string(),
778            ..Default::default()
779        },
780        ..Default::default()
781    };
782
783    let get_reward_claim_input =
784        |State(state): State<S>, SendQuery(request): SendQuery<GetRewardClaimInputRequest>| async move {
785            handlers::get_reward_claim_input(&state, request)
786                .await
787                .map(Json)
788        };
789
790    let get_reward_balance =
791        |State(state): State<S>, SendQuery(request): SendQuery<GetRewardBalanceRequest>| async move {
792            handlers::get_reward_balance(&state, request)
793                .await
794                .map(Json)
795        };
796
797    let get_reward_account_proof =
798        |State(state): State<S>, SendQuery(request): SendQuery<GetRewardAccountProofRequest>| async move {
799            handlers::get_reward_account_proof(&state, request)
800                .await
801                .map(Json)
802        };
803
804    let get_reward_balances =
805        |State(state): State<S>, SendQuery(request): SendQuery<GetRewardBalancesRequest>| async move {
806            handlers::get_reward_balances(&state, request)
807                .await
808                .map(Json)
809        };
810
811    let get_reward_merkle_tree_v2 =
812        |State(state): State<S>, SendQuery(request): SendQuery<GetRewardMerkleTreeRequest>| async move {
813            handlers::get_reward_merkle_tree_v2(&state, request)
814                .await
815                .map(Json)
816        };
817
818    let get_state_certificate =
819        |State(state): State<S>, SendQuery(request): SendQuery<GetStateCertificateRequest>| async move {
820            handlers::get_state_certificate(&state, request)
821                .await
822                .map(Json)
823        };
824
825    let get_stake_table =
826        |State(state): State<S>, SendQuery(request): SendQuery<GetStakeTableRequest>| async move {
827            handlers::get_stake_table(&state, request).await.map(Json)
828        };
829
830    let get_namespace_proof =
831        |State(state): State<S>, SendQuery(query): SendQuery<GetNamespaceProofRequest>| async move {
832            handlers::get_namespace_proof(&state, query).await.map(Json)
833        };
834
835    let get_incorrect_encoding_proof = |State(state): State<S>,
836                                        SendQuery(query): SendQuery<
837        GetIncorrectEncodingProofRequest,
838    >| async move {
839        handlers::get_incorrect_encoding_proof(&state, query)
840            .await
841            .map(Json)
842    };
843
844    let router = ApiRouter::new()
845        .api_route(
846            routes::v2::REWARD_CLAIM_INPUT_ROUTE.http,
847            get_with(get_reward_claim_input, |op| {
848                op.description(routes::v2::REWARD_CLAIM_INPUT_ROUTE.description)
849                    .tag(routes::v2::REWARD_CLAIM_INPUT_ROUTE.tag)
850            }),
851        )
852        .api_route(
853            routes::v2::REWARD_BALANCE_ROUTE.http,
854            get_with(get_reward_balance, |op| {
855                op.description(routes::v2::REWARD_BALANCE_ROUTE.description)
856                    .tag(routes::v2::REWARD_BALANCE_ROUTE.tag)
857            }),
858        )
859        .api_route(
860            routes::v2::REWARD_ACCOUNT_PROOF_ROUTE.http,
861            get_with(get_reward_account_proof, |op| {
862                op.description(routes::v2::REWARD_ACCOUNT_PROOF_ROUTE.description)
863                    .tag(routes::v2::REWARD_ACCOUNT_PROOF_ROUTE.tag)
864            }),
865        )
866        .api_route(
867            routes::v2::REWARD_BALANCES_ROUTE.http,
868            get_with(get_reward_balances, |op| {
869                op.description(routes::v2::REWARD_BALANCES_ROUTE.description)
870                    .tag(routes::v2::REWARD_BALANCES_ROUTE.tag)
871            }),
872        )
873        .api_route(
874            routes::v2::REWARD_MERKLE_TREE_V2_ROUTE.http,
875            get_with(get_reward_merkle_tree_v2, |op| {
876                op.description(routes::v2::REWARD_MERKLE_TREE_V2_ROUTE.description)
877                    .tag(routes::v2::REWARD_MERKLE_TREE_V2_ROUTE.tag)
878            }),
879        )
880        .api_route(
881            routes::v2::NAMESPACE_PROOF_ROUTE.http,
882            get_with(get_namespace_proof, |op| {
883                op.description(routes::v2::NAMESPACE_PROOF_ROUTE.description)
884                    .tag(routes::v2::NAMESPACE_PROOF_ROUTE.tag)
885            }),
886        )
887        .api_route(
888            routes::v2::INCORRECT_ENCODING_PROOF_ROUTE.http,
889            get_with(get_incorrect_encoding_proof, |op| {
890                op.description(routes::v2::INCORRECT_ENCODING_PROOF_ROUTE.description)
891                    .tag(routes::v2::INCORRECT_ENCODING_PROOF_ROUTE.tag)
892            }),
893        )
894        .api_route(
895            routes::v2::STATE_CERTIFICATE_ROUTE.http,
896            get_with(get_state_certificate, |op| {
897                op.description(routes::v2::STATE_CERTIFICATE_ROUTE.description)
898                    .tag(routes::v2::STATE_CERTIFICATE_ROUTE.tag)
899            }),
900        )
901        .api_route(
902            routes::v2::STAKE_TABLE_ROUTE.http,
903            get_with(get_stake_table, |op| {
904                op.description(routes::v2::STAKE_TABLE_ROUTE.description)
905                    .tag(routes::v2::STAKE_TABLE_ROUTE.tag)
906            }),
907        )
908        .finish_api(&mut api);
909
910    // Transform examples (array) to example (singular) for OpenAPI 3.0/Swagger compatibility
911    if let Some(ref mut components) = api.components {
912        let mut transform = schemars::transform::SetSingleExample::default();
913        for schema in components.schemas.values_mut() {
914            transform.transform(&mut schema.json_schema);
915        }
916    }
917
918    // Also transform the schemas of every operation parameter (query, header, path, cookie)
919    if let Some(ref mut paths) = api.paths {
920        let mut transform = schemars::transform::SetSingleExample::default();
921        for path_item_ref in paths.paths.values_mut() {
922            if let aide::openapi::ReferenceOr::Item(path_item) = path_item_ref {
923                for operation in [
924                    &mut path_item.get,
925                    &mut path_item.post,
926                    &mut path_item.put,
927                    &mut path_item.delete,
928                    &mut path_item.patch,
929                ]
930                .into_iter()
931                .flatten()
932                {
933                    for param in &mut operation.parameters {
934                        if let aide::openapi::ReferenceOr::Item(param_item) = param {
935                            let parameter_data = match param_item {
936                                aide::openapi::Parameter::Query { parameter_data, .. } => {
937                                    parameter_data
938                                },
939                                aide::openapi::Parameter::Header { parameter_data, .. } => {
940                                    parameter_data
941                                },
942                                aide::openapi::Parameter::Path { parameter_data, .. } => {
943                                    parameter_data
944                                },
945                                aide::openapi::Parameter::Cookie { parameter_data, .. } => {
946                                    parameter_data
947                                },
948                            };
949                            if let aide::openapi::ParameterSchemaOrContent::Schema(ref mut schema) =
950                                parameter_data.format
951                            {
952                                transform.transform(&mut schema.json_schema);
953                            }
954                        }
955                    }
956                }
957            }
958        }
959    }
960
961    router
962        .route(routes::v2::OPENAPI_SPEC_ROUTE, get(serve_openapi_spec))
963        .route(routes::v2::SWAGGER_ROUTE, get(serve_swagger_ui))
964        .route("/v2/", get(serve_swagger_ui))
965        .route(
966            routes::v2::SCALAR_ROUTE,
967            get(Scalar::new(routes::v2::OPENAPI_SPEC_ROUTE)
968                .with_title("Espresso Node API v2")
969                .axum_handler()),
970        )
971        .route(
972            routes::v2::REDOC_ROUTE,
973            get(Redoc::new(routes::v2::OPENAPI_SPEC_ROUTE)
974                .with_title("Espresso Node API v2")
975                .axum_handler()),
976        )
977        .layer(Extension(api))
978        .with_state(state)
979}