Skip to main content

hotshot_types/
epoch_membership.rs

1use std::{
2    collections::{HashMap, HashSet, hash_map::Entry},
3    sync::Arc,
4};
5
6use alloy::primitives::U256;
7use async_broadcast::{InactiveReceiver, Sender, broadcast};
8use committable::Commitment;
9use either::Either;
10use hotshot_utils::{anytrace::*, *};
11use parking_lot::{Mutex, RwLock};
12use sha2::{Digest, Sha256};
13use tokio_util::sync::CancellationToken;
14use versions::DRB_FIX_VERSION;
15
16use crate::{
17    PeerConfig, PeerConnectInfo,
18    data::{BlockNumber, EpochNumber, Leaf2, ViewNumber},
19    drb::{DrbDifficultySelectorFn, DrbInput, DrbResult, compute_drb_result},
20    traits::{
21        block_contents::BlockHeader,
22        election::{Membership, MembershipSnapshot, NonEpochMembershipSnapshot},
23        node_implementation::NodeType,
24        signature_key::StakeTableEntryType,
25        storage::{
26            LoadDrbProgressFn, Storage, StoreDrbProgressFn, StoreDrbResultFn, load_drb_progress_fn,
27            store_drb_progress_fn, store_drb_result_fn,
28        },
29    },
30};
31
/// Catchups currently in flight, keyed by the epoch being fetched. Each value
/// is an inactive receiver that waiters activate (see `wait_for_catchup`) to
/// learn how that epoch's catchup ended.
32type EpochMap<TYPES> = HashMap<EpochNumber, InactiveReceiver<Result<EpochMembership<TYPES>>>>;
33
/// Epochs for which a local DRB calculation is currently registered;
/// `compute_drb_result` consults it to refuse a second calculation for the
/// same epoch.
34type DrbMap = HashSet<EpochNumber>;
35
36/// Cancellation tokens for in-flight DRB computations. When an
37/// external source supplies the DRB result for `epoch` (e.g. a decided leaf
38/// carrying `next_drb_result`), `supply_drb` fires the token so the local
39/// computation can stop early instead of grinding to completion.
40type DrbCancelMap = HashMap<EpochNumber, CancellationToken>;
41
/// An epoch paired with the sender used to broadcast the outcome of that
/// epoch's catchup to any waiting tasks.
42type EpochSender<TYPES> = (EpochNumber, Sender<Result<EpochMembership<TYPES>>>);
43
44/// The per-epoch snapshot type associated with `T::Membership`.
45type Snapshot<T> = <<T as NodeType>::Membership as Membership<T>>::Snapshot;
46
47/// The stake-table hash type associated with `T::Membership`'s per-epoch
48/// snapshot.
49type SnapshotStakeTableHash<T> = <Snapshot<T> as MembershipSnapshot<T>>::StakeTableHash;
50
51/// Struct to Coordinate membership catchup
///
/// Wraps the shared membership and tracks which epochs currently have a
/// stake-table catchup or a DRB calculation in flight, so that concurrent
/// requests for the same epoch do not start the work twice.
52pub struct EpochMembershipCoordinator<TYPES: NodeType> {
    /// Shared handle to the underlying membership implementation.
53    membership: Arc<TYPES::Membership>,
    /// Catchups in flight, keyed by epoch; entries are removed when the
    /// catchup for that epoch succeeds or is cleaned up after a failure.
54    catchup_map: Arc<Mutex<EpochMap<TYPES>>>,
    /// Epochs with a local DRB calculation registered by `compute_drb_result`.
55    drb_calculation_map: Arc<Mutex<DrbMap>>,
    /// Cancellation tokens for those calculations, fired by `supply_drb`.
56    drb_cancel_map: Arc<Mutex<DrbCancelMap>>,
    /// Value returned by `epoch_height()`.
    /// NOTE(review): presumably the number of blocks per epoch — confirm.
57    epoch_height: BlockNumber,
    /// Storage callback handed to the DRB computation to persist progress.
58    store_drb_progress_fn: StoreDrbProgressFn,
    /// Storage callback handed to the DRB computation to load saved progress.
59    load_drb_progress_fn: LoadDrbProgressFn,
    /// Storage callback used to persist a finished DRB result.
60    store_drb_result_fn: StoreDrbResultFn,
    /// Selector for the DRB difficulty. Starts as `None`; until it is set via
    /// `set_drb_difficulty_selector`, `compute_drb_result` returns an error.
61    drb_difficulty_selector: Arc<RwLock<Option<DrbDifficultySelectorFn>>>,
62}
63
64impl<TYPES: NodeType> Clone for EpochMembershipCoordinator<TYPES> {
65    fn clone(&self) -> Self {
66        Self {
67            membership: Arc::clone(&self.membership),
68            catchup_map: Arc::clone(&self.catchup_map),
69            drb_calculation_map: Arc::clone(&self.drb_calculation_map),
70            drb_cancel_map: Arc::clone(&self.drb_cancel_map),
71            epoch_height: self.epoch_height,
72            store_drb_progress_fn: Arc::clone(&self.store_drb_progress_fn),
73            load_drb_progress_fn: Arc::clone(&self.load_drb_progress_fn),
74            store_drb_result_fn: self.store_drb_result_fn.clone(),
75            drb_difficulty_selector: Arc::clone(&self.drb_difficulty_selector),
76        }
77    }
78}
79
80impl<TYPES: NodeType> EpochMembershipCoordinator<TYPES> {
81    pub fn new<M, S, B>(membership: M, epoch_height: B, storage: &S) -> Self
82    where
83        M: Into<Arc<TYPES::Membership>>,
84        B: Into<BlockNumber>,
85        S: Storage<TYPES>,
86    {
87        Self {
88            membership: membership.into(),
89            catchup_map: Arc::default(),
90            drb_calculation_map: Arc::default(),
91            drb_cancel_map: Arc::default(),
92            epoch_height: epoch_height.into(),
93            store_drb_progress_fn: store_drb_progress_fn(storage.clone()),
94            load_drb_progress_fn: load_drb_progress_fn(storage.clone()),
95            store_drb_result_fn: store_drb_result_fn(storage.clone()),
96            drb_difficulty_selector: Arc::new(RwLock::new(None)),
97        }
98    }
99
100    pub fn epoch_height(&self) -> BlockNumber {
101        self.epoch_height
102    }
103
104    /// Get a reference to the membership
105    pub fn membership(&self) -> &TYPES::Membership {
106        &self.membership
107    }
108
109    /// Set the DRB difficulty selector
110    pub fn set_drb_difficulty_selector(&self, f: DrbDifficultySelectorFn) {
111        let mut drb_difficulty_selector_writer = self.drb_difficulty_selector.write();
112        *drb_difficulty_selector_writer = Some(f);
113    }
114
115    /// Get a Membership for a given Epoch, which is guaranteed to have a randomized stake
116    /// table for the given Epoch
117    pub fn membership_for_epoch(
118        &self,
119        maybe_epoch: Option<EpochNumber>,
120    ) -> Result<EpochMembership<TYPES>> {
121        let Some(epoch) = maybe_epoch else {
122            return Ok(EpochMembership {
123                coordinator: self.clone(),
124                snapshot: EpochMembershipSnapshot::NonEpoch(self.membership.non_epoch_snapshot()),
125            });
126        };
127        let Some(first_epoch) = self.membership.first_epoch() else {
128            return Err(error!(
129                "membership_for_epoch called with epoch {epoch:?} but first_epoch is unset"
130            ));
131        };
132        if epoch < first_epoch {
133            return Err(error!(
134                "membership_for_epoch called with epoch {epoch:?} before first_epoch {first_epoch}"
135            ));
136        }
137        if let Some(snapshot) = self.membership.snapshot(epoch)
138            && snapshot.has_drb()
139        {
140            return Ok(EpochMembership {
141                coordinator: self.clone(),
142                snapshot: EpochMembershipSnapshot::Epoch { epoch, snapshot },
143            });
144        }
145        let mut catchup_map = self.catchup_map.lock();
146        match catchup_map.entry(epoch) {
147            Entry::Occupied(_) => Err(warn!(
148                "Randomized stake table for epoch {epoch:?} unavailable. Catchup already in \
149                 progress"
150            )),
151            Entry::Vacant(e) => {
152                let coordinator = self.clone();
153                let (tx, rx) = broadcast(1);
154                e.insert(rx.deactivate());
155                drop(catchup_map);
156                spawn_catchup(coordinator, epoch, tx);
157                Err(warn!(
158                    "Randomized stake table for epoch {epoch:?} unavailable. Starting catchup"
159                ))
160            },
161        }
162    }
163
164    /// Get a Membership for a given Epoch, which is guaranteed to have a stake
165    /// table for the given Epoch
166    pub fn stake_table_for_epoch(&self, e: Option<EpochNumber>) -> Result<EpochMembership<TYPES>> {
167        let Some(epoch) = e else {
168            return Ok(EpochMembership {
169                coordinator: self.clone(),
170                snapshot: EpochMembershipSnapshot::NonEpoch(self.membership.non_epoch_snapshot()),
171            });
172        };
173        let Some(first_epoch) = self.membership.first_epoch() else {
174            return Err(error!(
175                "stake_table_for_epoch called with epoch {epoch:?} but first_epoch is unset"
176            ));
177        };
178        if epoch < first_epoch {
179            return Err(error!(
180                "stake_table_for_epoch called with epoch {epoch:?} before first_epoch \
181                 {first_epoch}"
182            ));
183        }
184        if let Some(snapshot) = self.membership.snapshot(epoch) {
185            return Ok(EpochMembership {
186                coordinator: self.clone(),
187                snapshot: EpochMembershipSnapshot::Epoch { epoch, snapshot },
188            });
189        }
190        let mut catchup_map = self.catchup_map.lock();
191        match catchup_map.entry(epoch) {
192            Entry::Occupied(_) => Err(warn!(
193                "Stake table for epoch {epoch:?} unavailable. Catchup already in progress"
194            )),
195            Entry::Vacant(e) => {
196                let coordinator = self.clone();
197                let (tx, rx) = broadcast(1);
198                e.insert(rx.deactivate());
199                drop(catchup_map);
200                spawn_catchup(coordinator, epoch, tx);
201
202                Err(warn!(
203                    "Stake table for epoch {epoch:?} unavailable. Starting catchup"
204                ))
205            },
206        }
207    }
208
209    /// Return the union of the stake table and DA committee for `epoch`,
210    /// keyed by signature key. Each entry's `Option<PeerConnectInfo>`
211    /// reflects whether the peer has connection info registered.
212    ///
213    /// Returns `None` if the stake table for `epoch` is unavailable
214    /// (e.g. catchup is still in progress).
215    pub fn epoch_peers(
216        &self,
217        e: Option<EpochNumber>,
218    ) -> Option<HashMap<TYPES::SignatureKey, Option<PeerConnectInfo>>> {
219        let membership = self.stake_table_for_epoch(e).ok()?;
220        Some(if let Some(snap) = membership.snapshot() {
221            snap.stake_table()
222                .chain(snap.da_stake_table())
223                .map(|m| (m.stake_table_entry.public_key(), m.connect_info.clone()))
224                .collect()
225        } else {
226            let snap = membership.non_epoch_snapshot()?;
227            snap.stake_table()
228                .chain(snap.da_stake_table())
229                .map(|m| (m.stake_table_entry.public_key(), m.connect_info.clone()))
230                .collect()
231        })
232    }
233
234    /// Collect the union of `epoch-1`, `epoch`, and `epoch+1` stake tables
235    /// (each merged with its DA committee) as a flat map of peers to dial.
236    ///
237    /// Newest-wins ordering for `connect_info`: next overrides curr overrides
238    /// prev. Entries with no `connect_info` are filtered out.
239    ///
240    /// Used to seed networks like cliquenet with the same window
241    /// `on_epoch_change` would build for `epoch`.
242    pub fn window_peers(&self, e: EpochNumber) -> HashMap<TYPES::SignatureKey, PeerConnectInfo> {
243        let curr = self.epoch_peers(Some(e)).unwrap_or_default();
244        let prev = if *e > 0 {
245            self.epoch_peers(Some(e - 1)).unwrap_or_default()
246        } else {
247            HashMap::new()
248        };
249        let next = self.epoch_peers(Some(e + 1)).unwrap_or_default();
250
251        // Newest-wins merge: start from prev, overlay curr and next.
252        let mut merged: HashMap<TYPES::SignatureKey, Option<PeerConnectInfo>> = prev;
253        for (k, v) in curr.into_iter().chain(next) {
254            merged.insert(k, v);
255        }
256
257        merged
258            .into_iter()
259            .filter_map(|(k, v)| v.map(|info| (k, info)))
260            .collect()
261    }
262
263    /// Catches the membership up to the epoch passed as an argument.
264    /// To do this, try to get the stake table for the epoch containing this
265    /// epoch's root and the stake table for the epoch containing this epoch's
266    /// drb result. If they do not exist, then go one by one back until we
267    /// find a stake table.
268    ///
269    /// If there is another catchup in progress this will not duplicate efforts
270    /// e.g. if we start with only the first epoch stake table and call catchup
271    /// for epoch 10, then call catchup for epoch 20 the first caller will
272    /// actually do the work for to catchup to epoch 10 then the second caller
273    /// will continue catching up to epoch 20
274    //
275    // Clippy claims "this `MutexGuard` is held across an await point", however
276    // the guard is explicitly dropped before. See also:
277    // https://github.com/rust-lang/rust-clippy/issues/6446
278    //
279    // Even more annoying is that the warning can only be disabled on function
280    // level, instead of putting this attribute on the expression, see
281    // https://github.com/rust-lang/rust-clippy/issues/9047.
282    #[allow(clippy::await_holding_lock)]
283    async fn catchup(self, epoch: EpochNumber, epoch_tx: Sender<Result<EpochMembership<TYPES>>>) {
284        // We need to fetch the requested epoch, that's for sure
285        let mut fetch_epochs = vec![];
286
287        let mut try_epoch = EpochNumber::new(epoch.saturating_sub(1));
288        let maybe_first_epoch = self.membership.first_epoch();
289        let Some(first_epoch) = maybe_first_epoch else {
290            let err = anytrace::error!(
291                "We got a catchup request for epoch {epoch:?} but the first epoch is not set"
292            );
293            self.catchup_cleanup(epoch, epoch_tx.clone(), fetch_epochs, err);
294            return;
295        };
296
297        // First figure out which epochs we need to fetch
298        loop {
299            let has_stake_table = self.membership.snapshot(try_epoch).is_some();
300            if has_stake_table {
301                // We have this stake table but we need to make sure we have the
302                // epoch root of the requested epoch
303                if try_epoch <= EpochNumber::new(epoch.saturating_sub(2)) {
304                    break;
305                }
306                try_epoch = EpochNumber::new(try_epoch.saturating_sub(1));
307            } else {
308                if try_epoch <= first_epoch + 1 {
309                    let err = anytrace::error!(
310                        "We are trying to catchup to an epoch lower than the second epoch! This \
311                         means the initial stake table is missing!"
312                    );
313                    self.catchup_cleanup(epoch, epoch_tx.clone(), fetch_epochs, err);
314                    return;
315                }
316                // Lock the catchup map
317                let mut map_lock = self.catchup_map.lock();
318                match map_lock
319                    .get(&try_epoch)
320                    .map(InactiveReceiver::activate_cloned)
321                {
322                    Some(mut rx) => {
323                        // Somebody else is already fetching this epoch, drop
324                        // the lock and wait for them to finish
325                        drop(map_lock);
326                        if let Ok(Ok(_)) = rx.recv_direct().await {
327                            break;
328                        };
329                        // If we didn't receive the epoch then we need to try again
330                    },
331                    _ => {
332                        // Nobody else is fetching this epoch. We need to do it.
333                        // Put it in the map and move on to the next epoch
334                        let (mut tx, rx) = broadcast(1);
335                        tx.set_overflow(true);
336                        map_lock.insert(try_epoch, rx.deactivate());
337                        drop(map_lock);
338                        fetch_epochs.push((try_epoch, tx));
339                        try_epoch = EpochNumber::new(try_epoch.saturating_sub(1));
340                    },
341                }
342            };
343        }
344
345        let epochs = fetch_epochs.iter().map(|(e, _)| e).collect::<Vec<_>>();
346        tracing::warn!("Fetching stake tables for epochs: {epochs:?}");
347
348        // Iterate through the epochs we need to fetch in reverse, i.e. from the oldest to the newest
349        while let Some((current_fetch_epoch, tx)) = fetch_epochs.pop() {
350            match self.fetch_stake_table(current_fetch_epoch).await {
351                Ok(_) => {},
352                Err(err) => {
353                    fetch_epochs.push((current_fetch_epoch, tx));
354                    self.catchup_cleanup(epoch, epoch_tx, fetch_epochs, err);
355                    return;
356                },
357            };
358
359            // Signal the other tasks about the success. `fetch_stake_table`
360            // returned `Ok`, so a snapshot must be present. If it isn't,
361            // treat that as a catchup failure: push the in-flight epoch
362            // back and run cleanup so waiters get notified.
363            let Some(snapshot) = self.membership.snapshot(current_fetch_epoch) else {
364                let err = anytrace::error!(
365                    "snapshot for epoch {current_fetch_epoch} unavailable after fetch_stake_table"
366                );
367                fetch_epochs.push((current_fetch_epoch, tx));
368                self.catchup_cleanup(epoch, epoch_tx, fetch_epochs, err);
369                return;
370            };
371            let mem = EpochMembership {
372                coordinator: self.clone(),
373                snapshot: EpochMembershipSnapshot::Epoch {
374                    epoch: current_fetch_epoch,
375                    snapshot,
376                },
377            };
378            if let Ok(Some(res)) = tx.try_broadcast(Ok(mem)) {
379                tracing::warn!(
380                    "The catchup channel for epoch {} was overflown, dropped message {:?}",
381                    current_fetch_epoch,
382                    res.map(|em| em.epoch())
383                );
384            }
385
386            // Remove the epoch from the catchup map to indicate that the catchup is complete
387            self.catchup_map.lock().remove(&current_fetch_epoch);
388        }
389
390        let root_leaf = match self.fetch_stake_table(epoch).await {
391            Ok(root_leaf) => root_leaf,
392            Err(err) => {
393                tracing::error!("Failed to fetch stake table for epoch {epoch:?}: {err:?}");
394                self.catchup_cleanup(epoch, epoch_tx.clone(), fetch_epochs, err);
395                return;
396            },
397        };
398
399        match self.membership.get_epoch_drb(epoch).await {
400            Ok(drb_result) => {
401                tracing::warn!(
402                    ?drb_result,
403                    "DRB result for epoch {epoch:?} retrieved from peers. Updating membership."
404                );
405                self.membership.add_drb_result(epoch, drb_result);
406            },
407            Err(err) => {
408                tracing::warn!(
409                    "Recalculating missing DRB result for epoch {}. Catchup failed with error: {}",
410                    epoch,
411                    err
412                );
413
414                let result = self.compute_drb_result(epoch, root_leaf).await;
415
416                log!(result);
417
418                if let Err(err) = result {
419                    self.catchup_cleanup(epoch, epoch_tx.clone(), fetch_epochs, err);
420                    return;
421                }
422            },
423        };
424
425        // Signal the other tasks about the success. As above, the snapshot
426        // must be present at this point — if not, treat as a catchup failure.
427        let Some(snapshot) = self.membership.snapshot(epoch) else {
428            let err = anytrace::error!(
429                "snapshot for epoch {epoch} unavailable after fetch_stake_table + DRB"
430            );
431            self.catchup_cleanup(epoch, epoch_tx.clone(), fetch_epochs, err);
432            return;
433        };
434        let mem = EpochMembership {
435            coordinator: self.clone(),
436            snapshot: EpochMembershipSnapshot::Epoch { epoch, snapshot },
437        };
438        if let Ok(Some(res)) = epoch_tx.try_broadcast(Ok(mem)) {
439            tracing::warn!(
440                "The catchup channel for epoch {} was overflown, dropped message {:?}",
441                epoch,
442                res.map(|em| em.epoch())
443            );
444        }
445
446        // Remove the epoch from the catchup map to indicate that the catchup is complete
447        self.catchup_map.lock().remove(&epoch);
448    }
449
450    /// Get the stake table for `epoch`, blocking on catchup if necessary.
451    ///
452    /// Unlike `stake_table_for_epoch`, this returns the result rather than
453    /// kicking off catchup and immediately returning an error. Used at startup
454    /// to drive the existing catchup chain synchronously before consensus is
455    /// running.
456    pub async fn wait_for_stake_table(&self, epoch: EpochNumber) -> Result<EpochMembership<TYPES>> {
457        match self.stake_table_for_epoch(Some(epoch)) {
458            Ok(mem) => Ok(mem),
459            Err(_) => self.wait_for_catchup(epoch).await,
460        }
461    }
462
463    /// Call this method if you think catchup is in progress for a given epoch
464    /// and you want to wait for it to finish and get the stake table.
465    /// If it's not, it will try to return the stake table if already available.
466    /// Returns an error if the catchup failed or the catchup is not in progress
467    /// and the stake table is not available.
468    pub async fn wait_for_catchup(&self, epoch: EpochNumber) -> Result<EpochMembership<TYPES>> {
469        let maybe_receiver = self
470            .catchup_map
471            .lock()
472            .get(&epoch)
473            .map(InactiveReceiver::activate_cloned);
474        let Some(mut rx) = maybe_receiver else {
475            // There is no catchup in progress, maybe the epoch is already finalized
476            if let Some(snapshot) = self.membership.snapshot(epoch) {
477                return Ok(EpochMembership {
478                    coordinator: self.clone(),
479                    snapshot: EpochMembershipSnapshot::Epoch { epoch, snapshot },
480                });
481            }
482            return Err(anytrace::error!(
483                "No catchup in progress for epoch {epoch} and we don't have a stake table for it"
484            ));
485        };
486        let Ok(Ok(mem)) = rx.recv_direct().await else {
487            return Err(anytrace::error!("Catchup for epoch {epoch} failed"));
488        };
489        Ok(mem)
490    }
491
492    /// Clean up after a failed catchup attempt.
493    ///
494    /// This method is called when a catchup attempt fails. It cleans up the state of the
495    /// `EpochMembershipCoordinator` by removing the failed epochs from the
496    /// `catchup_map` and broadcasting the error to any tasks that are waiting for the
497    /// catchup to complete.
498    fn catchup_cleanup(
499        &self,
500        req_epoch: EpochNumber,
501        epoch_tx: Sender<Result<EpochMembership<TYPES>>>,
502        mut cancel_epochs: Vec<EpochSender<TYPES>>,
503        err: Error,
504    ) {
505        // Cleanup in case of error
506        cancel_epochs.push((req_epoch, epoch_tx));
507
508        tracing::error!(
509            "catchup for epoch {req_epoch:?} failed: {err:?}. Canceling catchup for epochs: {:?}",
510            cancel_epochs.iter().map(|(e, _)| e).collect::<Vec<_>>()
511        );
512
513        {
514            let mut map_lock = self.catchup_map.lock();
515            for (epoch, _) in cancel_epochs.iter() {
516                // Remove the failed epochs from the catchup map
517                map_lock.remove(epoch);
518            }
519        }
520
521        for (cancel_epoch, tx) in cancel_epochs {
522            // Signal the other tasks about the failures
523            if let Ok(Some(res)) = tx.try_broadcast(Err(err.clone())) {
524                tracing::warn!(
525                    "The catchup channel for epoch {} was overflown during cleanup, dropped \
526                     message {:?}",
527                    cancel_epoch,
528                    res.map(|em| em.epoch())
529                );
530            }
531        }
532    }
533
534    /// A helper method to the `catchup` method.
535    ///
536    /// It tries to fetch the requested stake table from the root epoch,
537    /// and updates the membership accordingly.
538    ///
539    /// # Arguments
540    ///
541    /// * `epoch` - The epoch for which to fetch the stake table.
542    ///
543    /// # Returns
544    ///
545    /// * `Ok(Leaf2<TYPES>)` containing the epoch root leaf if successful.
546    /// * `Err(Error)` if the root membership or root leaf cannot be found, or if
547    ///   updating the membership fails.
548    async fn fetch_stake_table(&self, epoch: EpochNumber) -> Result<Leaf2<TYPES>> {
549        let root_epoch = EpochNumber::new(epoch.saturating_sub(2));
550        let Ok(root_membership) = self.stake_table_for_epoch(Some(root_epoch)) else {
551            return Err(anytrace::error!(
552                "We tried to fetch stake table for epoch {epoch:?} but we don't have its root \
553                 epoch {root_epoch:?}. This should not happen"
554            ));
555        };
556
557        // Get the epoch root headers and update our membership with them, finally sync them
558        // Verification of the root is handled in get_epoch_root_and_drb
559        let Ok(root_leaf) = root_membership.get_epoch_root().await else {
560            return Err(anytrace::error!(
561                "get epoch root leaf failed for epoch {root_epoch:?}"
562            ));
563        };
564
565        self.membership
566            .add_epoch_root(root_leaf.block_header().clone())
567            .await
568            .map_err(|e| {
569                anytrace::error!("Failed to add epoch root for epoch {epoch:?} to membership: {e}")
570            })?;
571
572        Ok(root_leaf)
573    }
574
575    pub async fn compute_drb_result(
576        &self,
577        epoch: EpochNumber,
578        root_leaf: Leaf2<TYPES>,
579    ) -> Result<DrbResult> {
580        let cancel_token = {
581            let mut drb_calculation_map_lock = self.drb_calculation_map.lock();
582
583            if drb_calculation_map_lock.contains(&epoch) {
584                return Err(anytrace::debug!(
585                    "DRB calculation for epoch {} already in progress",
586                    epoch
587                ));
588            }
589            drb_calculation_map_lock.insert(epoch);
590
591            let token = CancellationToken::new();
592            self.drb_cancel_map.lock().insert(epoch, token.clone());
593            token
594        };
595
596        let Ok(drb_seed_input_vec) = bincode::serialize(&root_leaf.justify_qc().signatures) else {
597            self.clear_drb_state(epoch);
598            return Err(anytrace::error!(
599                "Failed to serialize the QC signature for leaf {root_leaf:?}"
600            ));
601        };
602
603        let Some(drb_difficulty_selector) = self.drb_difficulty_selector.read().clone() else {
604            self.clear_drb_state(epoch);
605            return Err(anytrace::error!(
606                "The DRB difficulty selector is missing from the epoch membership coordinator. \
607                 This node will not be able to spawn any DRB calculation tasks from catchup."
608            ));
609        };
610
611        let drb_difficulty = drb_difficulty_selector(root_leaf.block_header().version()).await;
612
613        let mut drb_seed_input = [0u8; 32];
614
615        if root_leaf.block_header().version() >= DRB_FIX_VERSION {
616            drb_seed_input = Sha256::digest(&drb_seed_input_vec).into();
617        } else {
618            let len = drb_seed_input_vec.len().min(32);
619            drb_seed_input[..len].copy_from_slice(&drb_seed_input_vec[..len]);
620        }
621
622        let drb_input = DrbInput {
623            epoch: *epoch,
624            iteration: 0,
625            value: drb_seed_input,
626            difficulty_level: drb_difficulty,
627        };
628
629        let store_drb_progress_fn = self.store_drb_progress_fn.clone();
630        let load_drb_progress_fn = self.load_drb_progress_fn.clone();
631
632        // Race the local computation against the cancellation token. If the
633        // token fires, an external source has already added the DRB to
634        // membership, so read it back rather than waiting for the local hash
635        // loop to finish.
636        let drb = tokio::select! {
637            drb = compute_drb_result(drb_input, store_drb_progress_fn, load_drb_progress_fn) => {
638                drb
639            },
640            () = cancel_token.cancelled() => {
641                tracing::info!(
642                    "DRB calculation for epoch {epoch} cancelled by external supplier"
643                );
644                self.clear_drb_state(epoch);
645                return self.membership.get_epoch_drb(epoch).await.map_err(|e| {
646                    anytrace::error!(
647                        "DRB calculation for epoch {epoch} was cancelled but the externally \
648                         supplied result is no longer available: {e}"
649                    )
650                });
651            },
652        };
653
654        self.clear_drb_state(epoch);
655
656        tracing::info!("Writing drb result from catchup to storage for epoch {epoch}: {drb:?}");
657        if let Err(e) = (self.store_drb_result_fn)(epoch, drb).await {
658            tracing::warn!("Failed to add drb result to storage: {e}");
659        }
660        self.membership.add_drb_result(epoch, drb);
661
662        Ok(drb)
663    }
664
665    /// Supply a DRB result obtained from an external source (e.g. a decided
666    /// leaf carrying `next_drb_result`). Adds the result to membership,
667    /// persists it to storage, and cancels any in-flight local computation
668    /// for `epoch`.
669    ///
670    /// If the stake table for `epoch` has not yet been loaded (e.g. the async
671    /// catchup that registers it is still in flight), this logs an error and
672    /// returns; the in-flight catchup will compute the DRB itself once it
673    /// completes.
674    pub fn supply_drb(&self, epoch: EpochNumber, drb: DrbResult) {
675        if self.membership.snapshot(epoch).is_none() {
676            tracing::error!(
677                "supply_drb called for epoch {epoch} but stake table not yet loaded; dropping \
678                 externally-supplied DRB and relying on in-flight catchup"
679            );
680            return;
681        }
682        self.membership.add_drb_result(epoch, drb);
683        let maybe_token = self.drb_cancel_map.lock().remove(&epoch);
684        if let Some(token) = maybe_token {
685            token.cancel();
686        }
687        let store_drb_result_fn = self.store_drb_result_fn.clone();
688        tokio::spawn(async move {
689            tracing::info!(
690                "Writing externally supplied drb result to storage for epoch {epoch}: {drb:?}"
691            );
692            if let Err(e) = store_drb_result_fn(epoch, drb).await {
693                tracing::warn!("Failed to add externally supplied drb result to storage: {e}");
694            }
695        });
696    }
697
698    /// Remove per-epoch DRB bookkeeping after a computation finishes or is
699    /// cancelled. Safe to call multiple times.
700    fn clear_drb_state(&self, epoch: EpochNumber) {
701        self.drb_calculation_map.lock().remove(&epoch);
702        self.drb_cancel_map.lock().remove(&epoch);
703    }
704}
705
706fn spawn_catchup<T: NodeType>(
707    coordinator: EpochMembershipCoordinator<T>,
708    epoch: EpochNumber,
709    epoch_tx: Sender<Result<EpochMembership<T>>>,
710) {
711    tokio::spawn(async move {
712        coordinator.clone().catchup(epoch, epoch_tx).await;
713    });
714}
715
716/// Wrapper around a membership that holds a captured snapshot for a given
717/// epoch (or the pre-epoch state). All accessors observe one consistent
718/// view because the snapshot is held inline.
///
/// Values are built by `EpochMembershipCoordinator` (for example by
/// `membership_for_epoch` and `stake_table_for_epoch`) and are also what a
/// finished catchup broadcasts to its waiters.
719pub struct EpochMembership<TYPES: NodeType> {
720    /// The captured snapshot, either per-epoch or pre-epoch.
721    snapshot: EpochMembershipSnapshot<TYPES>,
722    /// Underlying coordinator, retained so navigation methods like
723    /// `next_epoch` can construct fresh snapshots.
724    pub coordinator: EpochMembershipCoordinator<TYPES>,
725}
726
727enum EpochMembershipSnapshot<TYPES: NodeType> {
728    Epoch {
729        epoch: EpochNumber,
730        snapshot: <TYPES::Membership as Membership<TYPES>>::Snapshot,
731    },
732    NonEpoch(<TYPES::Membership as Membership<TYPES>>::NonEpochSnapshot),
733}
734
735impl<TYPES: NodeType> Clone for EpochMembershipSnapshot<TYPES> {
736    fn clone(&self) -> Self {
737        match self {
738            Self::Epoch { epoch, snapshot } => Self::Epoch {
739                epoch: *epoch,
740                snapshot: snapshot.clone(),
741            },
742            Self::NonEpoch(s) => Self::NonEpoch(s.clone()),
743        }
744    }
745}
746
747impl<TYPES: NodeType> Clone for EpochMembership<TYPES> {
748    fn clone(&self) -> Self {
749        Self {
750            coordinator: self.coordinator.clone(),
751            snapshot: self.snapshot.clone(),
752        }
753    }
754}
755
756impl<TYPES: NodeType> EpochMembership<TYPES> {
757    pub fn epoch(&self) -> Option<EpochNumber> {
758        match &self.snapshot {
759            EpochMembershipSnapshot::Epoch { epoch, .. } => Some(*epoch),
760            EpochMembershipSnapshot::NonEpoch(_) => None,
761        }
762    }
763
764    pub fn next_epoch(&self) -> Result<Self> {
765        let epoch = self
766            .epoch()
767            .ok_or_else(|| anytrace::error!("No next epoch because epoch is None"))?;
768        self.coordinator.membership_for_epoch(Some(epoch + 1))
769    }
770
771    pub fn next_epoch_stake_table(&self) -> Result<Self> {
772        let epoch = self
773            .epoch()
774            .ok_or_else(|| anytrace::error!("No next epoch because epoch is None"))?;
775        self.coordinator.stake_table_for_epoch(Some(epoch + 1))
776    }
777
778    pub fn get_new_epoch(&self, epoch: Option<EpochNumber>) -> Result<Self> {
779        self.coordinator.membership_for_epoch(epoch)
780    }
781
782    async fn get_epoch_root(&self) -> anyhow::Result<Leaf2<TYPES>> {
783        let Some(epoch) = self.epoch() else {
784            anyhow::bail!("Cannot get root for None epoch");
785        };
786        let leaf = self.coordinator.membership.get_epoch_root(epoch).await?;
787        Ok(leaf)
788    }
789
790    pub async fn get_epoch_drb(&self) -> Result<DrbResult> {
791        let Some(epoch) = self.epoch() else {
792            return Err(anytrace::warn!("Cannot get drb for None epoch"));
793        };
794        self.coordinator
795            .membership
796            .get_epoch_drb(epoch)
797            .await
798            .wrap()
799    }
800
801    /// Borrow the per-epoch snapshot, or `None` for the pre-epoch case.
802    pub fn snapshot(&self) -> Option<&<TYPES::Membership as Membership<TYPES>>::Snapshot> {
803        match &self.snapshot {
804            EpochMembershipSnapshot::Epoch { snapshot, .. } => Some(snapshot),
805            EpochMembershipSnapshot::NonEpoch(_) => None,
806        }
807    }
808
809    /// Borrow the pre-epoch snapshot, or `None` if this is a per-epoch
810    /// membership.
811    pub fn non_epoch_snapshot(
812        &self,
813    ) -> Option<&<TYPES::Membership as Membership<TYPES>>::NonEpochSnapshot> {
814        match &self.snapshot {
815            EpochMembershipSnapshot::NonEpoch(s) => Some(s),
816            EpochMembershipSnapshot::Epoch { .. } => None,
817        }
818    }
819
820    /// Add the DRB result for this epoch to the membership.
821    pub fn add_drb_result(&self, drb_result: DrbResult) {
822        if let Some(epoch) = self.epoch() {
823            self.coordinator
824                .membership
825                .add_drb_result(epoch, drb_result);
826        }
827    }
828
829    // ---------------------------------------------------------------------
830    // Single-call convenience accessors. Each delegates to whichever
831    // snapshot was captured at construction time, so a single accessor
832    // call observes one consistent view. For *sequences* of related reads
833    // that must observe the same view, take a snapshot via
834    // [`Self::snapshot`] / [`Self::non_epoch_snapshot`] and call methods
835    // on it directly.
836    // ---------------------------------------------------------------------
837
838    pub fn stake_table(&self) -> impl ExactSizeIterator<Item = &PeerConfig<TYPES>> + Send {
839        match &self.snapshot {
840            EpochMembershipSnapshot::Epoch { snapshot, .. } => Either::Left(snapshot.stake_table()),
841            EpochMembershipSnapshot::NonEpoch(s) => Either::Right(s.stake_table()),
842        }
843    }
844
845    pub fn da_stake_table(&self) -> impl ExactSizeIterator<Item = &PeerConfig<TYPES>> + Send {
846        match &self.snapshot {
847            EpochMembershipSnapshot::Epoch { snapshot, .. } => {
848                Either::Left(snapshot.da_stake_table())
849            },
850            EpochMembershipSnapshot::NonEpoch(s) => Either::Right(s.da_stake_table()),
851        }
852    }
853
854    pub fn committee_members(
855        &self,
856        view: ViewNumber,
857    ) -> impl ExactSizeIterator<Item = &TYPES::SignatureKey> + Send {
858        match &self.snapshot {
859            EpochMembershipSnapshot::Epoch { snapshot, .. } => {
860                Either::Left(snapshot.committee_members(view))
861            },
862            EpochMembershipSnapshot::NonEpoch(s) => Either::Right(s.committee_members(view)),
863        }
864    }
865
866    pub fn da_committee_members(
867        &self,
868        view: ViewNumber,
869    ) -> impl ExactSizeIterator<Item = &TYPES::SignatureKey> + Send {
870        match &self.snapshot {
871            EpochMembershipSnapshot::Epoch { snapshot, .. } => {
872                Either::Left(snapshot.da_committee_members(view))
873            },
874            EpochMembershipSnapshot::NonEpoch(s) => Either::Right(s.da_committee_members(view)),
875        }
876    }
877
878    pub fn stake(&self, key: &TYPES::SignatureKey) -> Option<PeerConfig<TYPES>> {
879        match &self.snapshot {
880            EpochMembershipSnapshot::Epoch { snapshot, .. } => snapshot.stake(key),
881            EpochMembershipSnapshot::NonEpoch(s) => s.stake(key),
882        }
883    }
884
885    pub fn da_stake(&self, key: &TYPES::SignatureKey) -> Option<PeerConfig<TYPES>> {
886        match &self.snapshot {
887            EpochMembershipSnapshot::Epoch { snapshot, .. } => snapshot.da_stake(key),
888            EpochMembershipSnapshot::NonEpoch(s) => s.da_stake(key),
889        }
890    }
891
892    pub fn has_stake(&self, key: &TYPES::SignatureKey) -> bool {
893        match &self.snapshot {
894            EpochMembershipSnapshot::Epoch { snapshot, .. } => snapshot.has_stake(key),
895            EpochMembershipSnapshot::NonEpoch(s) => s.has_stake(key),
896        }
897    }
898
899    pub fn has_da_stake(&self, key: &TYPES::SignatureKey) -> bool {
900        match &self.snapshot {
901            EpochMembershipSnapshot::Epoch { snapshot, .. } => snapshot.has_da_stake(key),
902            EpochMembershipSnapshot::NonEpoch(s) => s.has_da_stake(key),
903        }
904    }
905
906    /// The leader for `view`, returning a HotShot-internal error type.
907    ///
908    /// # Errors
909    ///
910    /// Returns an error if the leader cannot be calculated.
911    pub fn leader(&self, view: ViewNumber) -> Result<TYPES::SignatureKey> {
912        match &self.snapshot {
913            EpochMembershipSnapshot::Epoch { snapshot, .. } => snapshot.leader(view),
914            EpochMembershipSnapshot::NonEpoch(s) => s.leader(view),
915        }
916    }
917
918    /// The leader for `view`, returning the membership-impl error type.
919    ///
920    /// # Errors
921    ///
922    /// Returns the membership-impl error if the leader cannot be calculated.
923    pub fn lookup_leader(
924        &self,
925        view: ViewNumber,
926    ) -> std::result::Result<
927        TYPES::SignatureKey,
928        <<TYPES as NodeType>::Membership as Membership<TYPES>>::Error,
929    > {
930        match &self.snapshot {
931            EpochMembershipSnapshot::Epoch { snapshot, .. } => snapshot.lookup_leader(view),
932            EpochMembershipSnapshot::NonEpoch(s) => s.lookup_leader(view),
933        }
934    }
935
936    pub fn total_nodes(&self) -> usize {
937        match &self.snapshot {
938            EpochMembershipSnapshot::Epoch { snapshot, .. } => snapshot.total_nodes(),
939            EpochMembershipSnapshot::NonEpoch(s) => s.total_nodes(),
940        }
941    }
942
943    pub fn da_total_nodes(&self) -> usize {
944        match &self.snapshot {
945            EpochMembershipSnapshot::Epoch { snapshot, .. } => snapshot.da_total_nodes(),
946            EpochMembershipSnapshot::NonEpoch(s) => s.da_total_nodes(),
947        }
948    }
949
950    pub fn success_threshold(&self) -> U256 {
951        match &self.snapshot {
952            EpochMembershipSnapshot::Epoch { snapshot, .. } => snapshot.success_threshold(),
953            EpochMembershipSnapshot::NonEpoch(s) => s.success_threshold(),
954        }
955    }
956
957    pub fn da_success_threshold(&self) -> U256 {
958        match &self.snapshot {
959            EpochMembershipSnapshot::Epoch { snapshot, .. } => snapshot.da_success_threshold(),
960            EpochMembershipSnapshot::NonEpoch(s) => s.da_success_threshold(),
961        }
962    }
963
964    pub fn failure_threshold(&self) -> U256 {
965        match &self.snapshot {
966            EpochMembershipSnapshot::Epoch { snapshot, .. } => snapshot.failure_threshold(),
967            EpochMembershipSnapshot::NonEpoch(s) => s.failure_threshold(),
968        }
969    }
970
971    pub fn upgrade_threshold(&self) -> U256 {
972        match &self.snapshot {
973            EpochMembershipSnapshot::Epoch { snapshot, .. } => snapshot.upgrade_threshold(),
974            EpochMembershipSnapshot::NonEpoch(s) => s.upgrade_threshold(),
975        }
976    }
977
978    pub fn stake_table_hash(&self) -> Option<Commitment<SnapshotStakeTableHash<TYPES>>> {
979        match &self.snapshot {
980            EpochMembershipSnapshot::Epoch { snapshot, .. } => snapshot.stake_table_hash(),
981            EpochMembershipSnapshot::NonEpoch(_) => None,
982        }
983    }
984}