Skip to main content

hotshot_example_types/membership/
strict_membership.rs

1use std::{
2    collections::HashSet,
3    fmt::{self, Debug},
4    sync::Arc,
5};
6
7use alloy::primitives::U256;
8use anyhow::anyhow;
9use async_broadcast::Receiver;
10use async_lock::RwLock as AsyncRwLock;
11use hotshot_types::{
12    PeerConfig,
13    data::{BlockNumber, EpochNumber, Leaf2, ViewNumber},
14    drb::DrbResult,
15    event::Event,
16    traits::{
17        block_contents::BlockHeader,
18        election::{Membership, MembershipSnapshot, NoStakeTableHash, NonEpochMembershipSnapshot},
19        leaf_fetcher_network::LeafFetcherNetwork,
20        node_implementation::NodeType,
21        signature_key::StakeTableEntryType,
22    },
23    utils::{epoch_from_block_number, root_block_in_epoch, transition_block_for_epoch},
24};
25use parking_lot::RwLock;
26
27use crate::{
28    membership::{TestableMembership, fetcher::Leaf2Fetcher, stake_table::TestStakeTable},
29    storage_types::TestStorage,
30};
31
32#[derive(Clone)]
33pub struct StrictMembership<T, S>
34where
35    T: NodeType,
36    S: TestStakeTable<T::SignatureKey, T::StateSignatureKey>,
37{
38    inner: Arc<RwLock<Inner<T, S>>>,
39    epoch_height: BlockNumber,
40}
41
42struct Inner<T: NodeType, S> {
43    table: S,
44    epochs: HashSet<EpochNumber>,
45    drbs: HashSet<EpochNumber>,
46    fetcher: Option<Arc<AsyncRwLock<Leaf2Fetcher<T>>>>,
47}
48
49impl<T, S> Debug for StrictMembership<T, S>
50where
51    T: NodeType,
52    S: TestStakeTable<T::SignatureKey, T::StateSignatureKey>,
53{
54    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
55        let inner = self.inner.read();
56        f.debug_struct("StrictMembership")
57            .field("table", &inner.table)
58            .field("epochs", &inner.epochs)
59            .field("drbs", &inner.drbs)
60            .finish()
61    }
62}
63
64impl<T, S> TestableMembership<T> for StrictMembership<T, S>
65where
66    T: NodeType,
67    S: TestStakeTable<T::SignatureKey, T::StateSignatureKey>,
68{
69    fn new(
70        quorum_members: Vec<PeerConfig<T>>,
71        da_members: Vec<PeerConfig<T>>,
72        _public_key: T::SignatureKey,
73        epoch_height: u64,
74    ) -> Self {
75        Self {
76            inner: Arc::new(RwLock::new(Inner {
77                table: TestStakeTable::new(
78                    quorum_members.into_iter().map(Into::into).collect(),
79                    da_members.into_iter().map(Into::into).collect(),
80                ),
81                epochs: HashSet::new(),
82                drbs: HashSet::new(),
83                fetcher: None,
84            })),
85            epoch_height: epoch_height.into(),
86        }
87    }
88
89    fn set_leaf_fetcher(
90        &self,
91        network: Arc<dyn LeafFetcherNetwork<T>>,
92        storage: TestStorage<T>,
93        public_key: T::SignatureKey,
94        channel: Receiver<Event<T>>,
95    ) {
96        let mut fetcher = Leaf2Fetcher::new(network, storage, public_key);
97        fetcher.set_external_channel(channel);
98        self.inner.write().fetcher = Some(Arc::new(AsyncRwLock::new(fetcher)));
99    }
100}
101
102impl<T: NodeType, S> Inner<T, S> {
103    fn assert_has_stake_table(&self, epoch: Option<EpochNumber>) {
104        let Some(epoch) = epoch else {
105            return;
106        };
107        assert!(
108            self.epochs.contains(&epoch),
109            "Failed stake table check for epoch {epoch}"
110        );
111    }
112}
113
114impl<T, S> Membership<T> for StrictMembership<T, S>
115where
116    T: NodeType,
117    S: TestStakeTable<T::SignatureKey, T::StateSignatureKey>,
118{
119    type Error = StrictMembershipError;
120    type Snapshot = StrictEpochSnapshot<T, S>;
121    type NonEpochSnapshot = StrictNonEpochSnapshot<T, S>;
122
123    fn snapshot(&self, epoch: EpochNumber) -> Option<Self::Snapshot> {
124        let inner = self.inner.read();
125        if !inner.epochs.contains(&epoch) {
126            return None;
127        }
128        let has_drb = inner.drbs.contains(&epoch);
129        let first_epoch = inner.table.first_epoch().map(EpochNumber::new);
130        Some(StrictEpochSnapshot::build(
131            epoch,
132            first_epoch,
133            has_drb,
134            inner.table.clone(),
135        ))
136    }
137
138    fn non_epoch_snapshot(&self) -> Self::NonEpochSnapshot {
139        StrictNonEpochSnapshot::build(self.inner.read().table.clone())
140    }
141
142    fn add_drb_result(&self, e: EpochNumber, drb: DrbResult) {
143        let mut inner = self.inner.write();
144        inner.assert_has_stake_table(Some(e));
145        inner.drbs.insert(e);
146        inner.table.add_drb_result(*e, drb);
147    }
148
149    fn first_epoch(&self) -> Option<EpochNumber> {
150        self.inner.read().table.first_epoch().map(EpochNumber::new)
151    }
152
153    fn set_first_epoch(&self, e: EpochNumber, initial_drb_result: DrbResult) {
154        let mut inner = self.inner.write();
155        inner.epochs.insert(e);
156        inner.epochs.insert(e + 1);
157
158        inner.drbs.insert(e);
159        inner.drbs.insert(e + 1);
160
161        inner.table.set_first_epoch(*e, initial_drb_result);
162    }
163
164    async fn add_epoch_root(&self, hdr: T::BlockHeader) -> Result<(), Self::Error> {
165        let epoch = epoch_from_block_number(hdr.block_number(), *self.epoch_height) + 2;
166
167        let mut inner = self.inner.write();
168        inner.epochs.insert(EpochNumber::new(epoch));
169        inner.table.add_epoch_root(epoch);
170
171        Ok(())
172    }
173
174    async fn get_epoch_root(&self, e: EpochNumber) -> Result<Leaf2<T>, Self::Error> {
175        let block_height = root_block_in_epoch(*e, *self.epoch_height);
176
177        let (stake_table, fetcher) = {
178            let inner = self.inner.read();
179            let table = inner.table.stake_table(Some(*e));
180            let fetcher = inner
181                .fetcher
182                .clone()
183                .expect("get_epoch_root called before set_leaf_fetcher_network");
184            (table, fetcher)
185        };
186
187        for node in stake_table {
188            if let Ok(leaf) = fetcher
189                .read()
190                .await
191                .fetch_leaf(block_height, node.signature_key)
192                .await
193            {
194                return Ok(leaf);
195            }
196        }
197
198        Err(anyhow!("Failed to fetch epoch root from any peer").into())
199    }
200
201    async fn get_epoch_drb(&self, e: EpochNumber) -> Result<DrbResult, Self::Error> {
202        let epoch_height = self.epoch_height;
203
204        let (epoch_drb, fetcher) = {
205            let state = self.inner.read();
206            let drb = state.table.get_epoch_drb(*e);
207            let fetcher = state.fetcher.clone();
208            (drb, fetcher)
209        };
210
211        if let Ok(drb_result) = epoch_drb {
212            Ok(drb_result)
213        } else {
214            let previous_epoch = match e.checked_sub(1) {
215                Some(epoch) => epoch,
216                None => {
217                    return Err(anyhow!("Missing initial DRB result for epoch {e:?}").into());
218                },
219            };
220
221            let drb_block_height = transition_block_for_epoch(previous_epoch, *epoch_height);
222            let stake_table = self.inner.read().table.stake_table(Some(previous_epoch));
223            let fetcher = fetcher.expect("get_epoch_drb called before set_leaf_fetcher_network");
224
225            let mut drb_leaf = None;
226
227            for node in stake_table {
228                if let Ok(leaf) = fetcher
229                    .read()
230                    .await
231                    .fetch_leaf(drb_block_height, node.signature_key)
232                    .await
233                {
234                    drb_leaf = Some(leaf);
235                    break;
236                }
237            }
238
239            match drb_leaf {
240                Some(leaf) => Ok(leaf.next_drb_result.expect(
241                    "We fetched a leaf that is missing a DRB result. This should be impossible.",
242                )),
243                None => Err(anyhow!(
244                    "Failed to fetch leaf from all nodes. Height: {drb_block_height}"
245                )
246                .into()),
247            }
248        }
249    }
250
251    fn add_da_committee(&self, first_epoch: EpochNumber, committee: Vec<PeerConfig<T>>) {
252        self.inner.write().table.add_da_committee(
253            *first_epoch,
254            committee.into_iter().map(Into::into).collect(),
255        );
256    }
257}
258
259#[derive(Debug, thiserror::Error)]
260#[error("strict membership error: {0}")]
261pub struct StrictMembershipError(#[from] anyhow::Error);
262
263/// Per-epoch snapshot for `StrictMembership`.
264///
265/// Materializes the stake-table views at construction time so accessors can
266/// return borrowed iterators.
267pub struct StrictEpochSnapshot<T, S>
268where
269    T: NodeType,
270    S: TestStakeTable<T::SignatureKey, T::StateSignatureKey>,
271{
272    epoch: EpochNumber,
273    first_epoch: Option<EpochNumber>,
274    has_drb: bool,
275    stake_table: Vec<PeerConfig<T>>,
276    da_stake_table: Vec<PeerConfig<T>>,
277    committee_keys: Vec<T::SignatureKey>,
278    da_committee_keys: Vec<T::SignatureKey>,
279    table: S,
280    _phantom: std::marker::PhantomData<T>,
281}
282
283impl<T, S> StrictEpochSnapshot<T, S>
284where
285    T: NodeType,
286    S: TestStakeTable<T::SignatureKey, T::StateSignatureKey>,
287{
288    fn build(
289        epoch: EpochNumber,
290        first_epoch: Option<EpochNumber>,
291        has_drb: bool,
292        table: S,
293    ) -> Self {
294        let stake_entries = table.stake_table(Some(*epoch));
295        let da_entries = table.da_stake_table(Some(*epoch));
296        let committee_keys = stake_entries
297            .iter()
298            .map(|e| e.signature_key.clone())
299            .collect();
300        let da_committee_keys = da_entries.iter().map(|e| e.signature_key.clone()).collect();
301        let stake_table = stake_entries.into_iter().map(Into::into).collect();
302        let da_stake_table = da_entries.into_iter().map(Into::into).collect();
303        Self {
304            epoch,
305            first_epoch,
306            has_drb,
307            stake_table,
308            da_stake_table,
309            committee_keys,
310            da_committee_keys,
311            table,
312            _phantom: std::marker::PhantomData,
313        }
314    }
315}
316
317impl<T, S> Clone for StrictEpochSnapshot<T, S>
318where
319    T: NodeType,
320    S: TestStakeTable<T::SignatureKey, T::StateSignatureKey>,
321{
322    fn clone(&self) -> Self {
323        Self {
324            epoch: self.epoch,
325            first_epoch: self.first_epoch,
326            has_drb: self.has_drb,
327            stake_table: self.stake_table.clone(),
328            da_stake_table: self.da_stake_table.clone(),
329            committee_keys: self.committee_keys.clone(),
330            da_committee_keys: self.da_committee_keys.clone(),
331            table: self.table.clone(),
332            _phantom: std::marker::PhantomData,
333        }
334    }
335}
336
337impl<T, S> Debug for StrictEpochSnapshot<T, S>
338where
339    T: NodeType,
340    S: TestStakeTable<T::SignatureKey, T::StateSignatureKey>,
341{
342    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
343        f.debug_struct("StrictEpochSnapshot")
344            .field("epoch", &self.epoch)
345            .field("first_epoch", &self.first_epoch)
346            .field("has_drb", &self.has_drb)
347            .field("table", &self.table)
348            .finish()
349    }
350}
351
352impl<T, S> MembershipSnapshot<T> for StrictEpochSnapshot<T, S>
353where
354    T: NodeType,
355    S: TestStakeTable<T::SignatureKey, T::StateSignatureKey>,
356{
357    type Error = StrictMembershipError;
358    type StakeTableHash = NoStakeTableHash;
359
360    fn epoch(&self) -> EpochNumber {
361        self.epoch
362    }
363
364    fn first_epoch(&self) -> Option<EpochNumber> {
365        self.first_epoch
366    }
367
368    fn has_drb(&self) -> bool {
369        self.has_drb
370    }
371
372    fn stake_table(&self) -> impl ExactSizeIterator<Item = &PeerConfig<T>> + Send {
373        self.stake_table.iter()
374    }
375
376    fn da_stake_table(&self) -> impl ExactSizeIterator<Item = &PeerConfig<T>> + Send {
377        self.da_stake_table.iter()
378    }
379
380    fn committee_members(
381        &self,
382        _: ViewNumber,
383    ) -> impl ExactSizeIterator<Item = &T::SignatureKey> + Send {
384        self.committee_keys.iter()
385    }
386
387    fn da_committee_members(
388        &self,
389        _: ViewNumber,
390    ) -> impl ExactSizeIterator<Item = &T::SignatureKey> + Send {
391        self.da_committee_keys.iter()
392    }
393
394    fn stake(&self, key: &T::SignatureKey) -> Option<PeerConfig<T>> {
395        self.table
396            .stake(key.clone(), Some(*self.epoch))
397            .map(Into::into)
398    }
399
400    fn da_stake(&self, key: &T::SignatureKey) -> Option<PeerConfig<T>> {
401        self.table
402            .da_stake(key.clone(), Some(*self.epoch))
403            .map(Into::into)
404    }
405
406    fn has_stake(&self, key: &T::SignatureKey) -> bool {
407        self.stake(key)
408            .is_some_and(|x| x.stake_table_entry.stake() > U256::ZERO)
409    }
410
411    fn has_da_stake(&self, key: &T::SignatureKey) -> bool {
412        self.da_stake(key)
413            .is_some_and(|x| x.stake_table_entry.stake() > U256::ZERO)
414    }
415
416    fn lookup_leader(&self, view: ViewNumber) -> Result<T::SignatureKey, Self::Error> {
417        Ok(self.table.lookup_leader(*view, Some(*self.epoch))?)
418    }
419}
420
421/// Pre-epoch snapshot for `StrictMembership`. Materializes views at
422/// construction so accessors can return borrowed iterators.
423pub struct StrictNonEpochSnapshot<T, S>
424where
425    T: NodeType,
426    S: TestStakeTable<T::SignatureKey, T::StateSignatureKey>,
427{
428    stake_table: Vec<PeerConfig<T>>,
429    da_stake_table: Vec<PeerConfig<T>>,
430    committee_keys: Vec<T::SignatureKey>,
431    da_committee_keys: Vec<T::SignatureKey>,
432    table: S,
433    _phantom: std::marker::PhantomData<T>,
434}
435
436impl<T, S> StrictNonEpochSnapshot<T, S>
437where
438    T: NodeType,
439    S: TestStakeTable<T::SignatureKey, T::StateSignatureKey>,
440{
441    fn build(table: S) -> Self {
442        let stake_entries = table.stake_table(None);
443        let da_entries = table.da_stake_table(None);
444        let committee_keys = stake_entries
445            .iter()
446            .map(|e| e.signature_key.clone())
447            .collect();
448        let da_committee_keys = da_entries.iter().map(|e| e.signature_key.clone()).collect();
449        let stake_table = stake_entries.into_iter().map(Into::into).collect();
450        let da_stake_table = da_entries.into_iter().map(Into::into).collect();
451        Self {
452            stake_table,
453            da_stake_table,
454            committee_keys,
455            da_committee_keys,
456            table,
457            _phantom: std::marker::PhantomData,
458        }
459    }
460}
461
462impl<T, S> Clone for StrictNonEpochSnapshot<T, S>
463where
464    T: NodeType,
465    S: TestStakeTable<T::SignatureKey, T::StateSignatureKey>,
466{
467    fn clone(&self) -> Self {
468        Self {
469            stake_table: self.stake_table.clone(),
470            da_stake_table: self.da_stake_table.clone(),
471            committee_keys: self.committee_keys.clone(),
472            da_committee_keys: self.da_committee_keys.clone(),
473            table: self.table.clone(),
474            _phantom: std::marker::PhantomData,
475        }
476    }
477}
478
479impl<T, S> Debug for StrictNonEpochSnapshot<T, S>
480where
481    T: NodeType,
482    S: TestStakeTable<T::SignatureKey, T::StateSignatureKey>,
483{
484    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
485        f.debug_struct("StrictNonEpochSnapshot")
486            .field("table", &self.table)
487            .finish()
488    }
489}
490
491impl<T, S> NonEpochMembershipSnapshot<T> for StrictNonEpochSnapshot<T, S>
492where
493    T: NodeType,
494    S: TestStakeTable<T::SignatureKey, T::StateSignatureKey>,
495{
496    type Error = StrictMembershipError;
497
498    fn stake_table(&self) -> impl ExactSizeIterator<Item = &PeerConfig<T>> + Send + '_ {
499        self.stake_table.iter()
500    }
501
502    fn da_stake_table(&self) -> impl ExactSizeIterator<Item = &PeerConfig<T>> + Send + '_ {
503        self.da_stake_table.iter()
504    }
505
506    fn committee_members(
507        &self,
508        _: ViewNumber,
509    ) -> impl ExactSizeIterator<Item = &T::SignatureKey> + Send + '_ {
510        self.committee_keys.iter()
511    }
512
513    fn da_committee_members(
514        &self,
515        _: ViewNumber,
516    ) -> impl ExactSizeIterator<Item = &T::SignatureKey> + Send + '_ {
517        self.da_committee_keys.iter()
518    }
519
520    fn stake(&self, key: &T::SignatureKey) -> Option<PeerConfig<T>> {
521        self.table.stake(key.clone(), None).map(Into::into)
522    }
523
524    fn da_stake(&self, key: &T::SignatureKey) -> Option<PeerConfig<T>> {
525        self.table.da_stake(key.clone(), None).map(Into::into)
526    }
527
528    fn has_stake(&self, key: &T::SignatureKey) -> bool {
529        self.stake(key)
530            .is_some_and(|x| x.stake_table_entry.stake() > U256::ZERO)
531    }
532
533    fn has_da_stake(&self, key: &T::SignatureKey) -> bool {
534        self.da_stake(key)
535            .is_some_and(|x| x.stake_table_entry.stake() > U256::ZERO)
536    }
537
538    fn lookup_leader(&self, view: ViewNumber) -> Result<T::SignatureKey, Self::Error> {
539        Ok(self.table.lookup_leader(*view, None)?)
540    }
541}