1use std::{
2 collections::HashSet,
3 fmt::{self, Debug},
4 sync::Arc,
5};
6
7use alloy::primitives::U256;
8use anyhow::anyhow;
9use async_broadcast::Receiver;
10use async_lock::RwLock as AsyncRwLock;
11use hotshot_types::{
12 PeerConfig,
13 data::{BlockNumber, EpochNumber, Leaf2, ViewNumber},
14 drb::DrbResult,
15 event::Event,
16 traits::{
17 block_contents::BlockHeader,
18 election::{Membership, MembershipSnapshot, NoStakeTableHash, NonEpochMembershipSnapshot},
19 leaf_fetcher_network::LeafFetcherNetwork,
20 node_implementation::NodeType,
21 signature_key::StakeTableEntryType,
22 },
23 utils::{epoch_from_block_number, root_block_in_epoch, transition_block_for_epoch},
24};
25use parking_lot::RwLock;
26
27use crate::{
28 membership::{TestableMembership, fetcher::Leaf2Fetcher, stake_table::TestStakeTable},
29 storage_types::TestStorage,
30};
31
32#[derive(Clone)]
33pub struct StrictMembership<T, S>
34where
35 T: NodeType,
36 S: TestStakeTable<T::SignatureKey, T::StateSignatureKey>,
37{
38 inner: Arc<RwLock<Inner<T, S>>>,
39 epoch_height: BlockNumber,
40}
41
42struct Inner<T: NodeType, S> {
43 table: S,
44 epochs: HashSet<EpochNumber>,
45 drbs: HashSet<EpochNumber>,
46 fetcher: Option<Arc<AsyncRwLock<Leaf2Fetcher<T>>>>,
47}
48
49impl<T, S> Debug for StrictMembership<T, S>
50where
51 T: NodeType,
52 S: TestStakeTable<T::SignatureKey, T::StateSignatureKey>,
53{
54 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
55 let inner = self.inner.read();
56 f.debug_struct("StrictMembership")
57 .field("table", &inner.table)
58 .field("epochs", &inner.epochs)
59 .field("drbs", &inner.drbs)
60 .finish()
61 }
62}
63
64impl<T, S> TestableMembership<T> for StrictMembership<T, S>
65where
66 T: NodeType,
67 S: TestStakeTable<T::SignatureKey, T::StateSignatureKey>,
68{
69 fn new(
70 quorum_members: Vec<PeerConfig<T>>,
71 da_members: Vec<PeerConfig<T>>,
72 _public_key: T::SignatureKey,
73 epoch_height: u64,
74 ) -> Self {
75 Self {
76 inner: Arc::new(RwLock::new(Inner {
77 table: TestStakeTable::new(
78 quorum_members.into_iter().map(Into::into).collect(),
79 da_members.into_iter().map(Into::into).collect(),
80 ),
81 epochs: HashSet::new(),
82 drbs: HashSet::new(),
83 fetcher: None,
84 })),
85 epoch_height: epoch_height.into(),
86 }
87 }
88
89 fn set_leaf_fetcher(
90 &self,
91 network: Arc<dyn LeafFetcherNetwork<T>>,
92 storage: TestStorage<T>,
93 public_key: T::SignatureKey,
94 channel: Receiver<Event<T>>,
95 ) {
96 let mut fetcher = Leaf2Fetcher::new(network, storage, public_key);
97 fetcher.set_external_channel(channel);
98 self.inner.write().fetcher = Some(Arc::new(AsyncRwLock::new(fetcher)));
99 }
100}
101
102impl<T: NodeType, S> Inner<T, S> {
103 fn assert_has_stake_table(&self, epoch: Option<EpochNumber>) {
104 let Some(epoch) = epoch else {
105 return;
106 };
107 assert!(
108 self.epochs.contains(&epoch),
109 "Failed stake table check for epoch {epoch}"
110 );
111 }
112}
113
114impl<T, S> Membership<T> for StrictMembership<T, S>
115where
116 T: NodeType,
117 S: TestStakeTable<T::SignatureKey, T::StateSignatureKey>,
118{
119 type Error = StrictMembershipError;
120 type Snapshot = StrictEpochSnapshot<T, S>;
121 type NonEpochSnapshot = StrictNonEpochSnapshot<T, S>;
122
123 fn snapshot(&self, epoch: EpochNumber) -> Option<Self::Snapshot> {
124 let inner = self.inner.read();
125 if !inner.epochs.contains(&epoch) {
126 return None;
127 }
128 let has_drb = inner.drbs.contains(&epoch);
129 let first_epoch = inner.table.first_epoch().map(EpochNumber::new);
130 Some(StrictEpochSnapshot::build(
131 epoch,
132 first_epoch,
133 has_drb,
134 inner.table.clone(),
135 ))
136 }
137
138 fn non_epoch_snapshot(&self) -> Self::NonEpochSnapshot {
139 StrictNonEpochSnapshot::build(self.inner.read().table.clone())
140 }
141
142 fn add_drb_result(&self, e: EpochNumber, drb: DrbResult) {
143 let mut inner = self.inner.write();
144 inner.assert_has_stake_table(Some(e));
145 inner.drbs.insert(e);
146 inner.table.add_drb_result(*e, drb);
147 }
148
149 fn first_epoch(&self) -> Option<EpochNumber> {
150 self.inner.read().table.first_epoch().map(EpochNumber::new)
151 }
152
153 fn set_first_epoch(&self, e: EpochNumber, initial_drb_result: DrbResult) {
154 let mut inner = self.inner.write();
155 inner.epochs.insert(e);
156 inner.epochs.insert(e + 1);
157
158 inner.drbs.insert(e);
159 inner.drbs.insert(e + 1);
160
161 inner.table.set_first_epoch(*e, initial_drb_result);
162 }
163
164 async fn add_epoch_root(&self, hdr: T::BlockHeader) -> Result<(), Self::Error> {
165 let epoch = epoch_from_block_number(hdr.block_number(), *self.epoch_height) + 2;
166
167 let mut inner = self.inner.write();
168 inner.epochs.insert(EpochNumber::new(epoch));
169 inner.table.add_epoch_root(epoch);
170
171 Ok(())
172 }
173
174 async fn get_epoch_root(&self, e: EpochNumber) -> Result<Leaf2<T>, Self::Error> {
175 let block_height = root_block_in_epoch(*e, *self.epoch_height);
176
177 let (stake_table, fetcher) = {
178 let inner = self.inner.read();
179 let table = inner.table.stake_table(Some(*e));
180 let fetcher = inner
181 .fetcher
182 .clone()
183 .expect("get_epoch_root called before set_leaf_fetcher_network");
184 (table, fetcher)
185 };
186
187 for node in stake_table {
188 if let Ok(leaf) = fetcher
189 .read()
190 .await
191 .fetch_leaf(block_height, node.signature_key)
192 .await
193 {
194 return Ok(leaf);
195 }
196 }
197
198 Err(anyhow!("Failed to fetch epoch root from any peer").into())
199 }
200
201 async fn get_epoch_drb(&self, e: EpochNumber) -> Result<DrbResult, Self::Error> {
202 let epoch_height = self.epoch_height;
203
204 let (epoch_drb, fetcher) = {
205 let state = self.inner.read();
206 let drb = state.table.get_epoch_drb(*e);
207 let fetcher = state.fetcher.clone();
208 (drb, fetcher)
209 };
210
211 if let Ok(drb_result) = epoch_drb {
212 Ok(drb_result)
213 } else {
214 let previous_epoch = match e.checked_sub(1) {
215 Some(epoch) => epoch,
216 None => {
217 return Err(anyhow!("Missing initial DRB result for epoch {e:?}").into());
218 },
219 };
220
221 let drb_block_height = transition_block_for_epoch(previous_epoch, *epoch_height);
222 let stake_table = self.inner.read().table.stake_table(Some(previous_epoch));
223 let fetcher = fetcher.expect("get_epoch_drb called before set_leaf_fetcher_network");
224
225 let mut drb_leaf = None;
226
227 for node in stake_table {
228 if let Ok(leaf) = fetcher
229 .read()
230 .await
231 .fetch_leaf(drb_block_height, node.signature_key)
232 .await
233 {
234 drb_leaf = Some(leaf);
235 break;
236 }
237 }
238
239 match drb_leaf {
240 Some(leaf) => Ok(leaf.next_drb_result.expect(
241 "We fetched a leaf that is missing a DRB result. This should be impossible.",
242 )),
243 None => Err(anyhow!(
244 "Failed to fetch leaf from all nodes. Height: {drb_block_height}"
245 )
246 .into()),
247 }
248 }
249 }
250
251 fn add_da_committee(&self, first_epoch: EpochNumber, committee: Vec<PeerConfig<T>>) {
252 self.inner.write().table.add_da_committee(
253 *first_epoch,
254 committee.into_iter().map(Into::into).collect(),
255 );
256 }
257}
258
259#[derive(Debug, thiserror::Error)]
260#[error("strict membership error: {0}")]
261pub struct StrictMembershipError(#[from] anyhow::Error);
262
263pub struct StrictEpochSnapshot<T, S>
268where
269 T: NodeType,
270 S: TestStakeTable<T::SignatureKey, T::StateSignatureKey>,
271{
272 epoch: EpochNumber,
273 first_epoch: Option<EpochNumber>,
274 has_drb: bool,
275 stake_table: Vec<PeerConfig<T>>,
276 da_stake_table: Vec<PeerConfig<T>>,
277 committee_keys: Vec<T::SignatureKey>,
278 da_committee_keys: Vec<T::SignatureKey>,
279 table: S,
280 _phantom: std::marker::PhantomData<T>,
281}
282
283impl<T, S> StrictEpochSnapshot<T, S>
284where
285 T: NodeType,
286 S: TestStakeTable<T::SignatureKey, T::StateSignatureKey>,
287{
288 fn build(
289 epoch: EpochNumber,
290 first_epoch: Option<EpochNumber>,
291 has_drb: bool,
292 table: S,
293 ) -> Self {
294 let stake_entries = table.stake_table(Some(*epoch));
295 let da_entries = table.da_stake_table(Some(*epoch));
296 let committee_keys = stake_entries
297 .iter()
298 .map(|e| e.signature_key.clone())
299 .collect();
300 let da_committee_keys = da_entries.iter().map(|e| e.signature_key.clone()).collect();
301 let stake_table = stake_entries.into_iter().map(Into::into).collect();
302 let da_stake_table = da_entries.into_iter().map(Into::into).collect();
303 Self {
304 epoch,
305 first_epoch,
306 has_drb,
307 stake_table,
308 da_stake_table,
309 committee_keys,
310 da_committee_keys,
311 table,
312 _phantom: std::marker::PhantomData,
313 }
314 }
315}
316
317impl<T, S> Clone for StrictEpochSnapshot<T, S>
318where
319 T: NodeType,
320 S: TestStakeTable<T::SignatureKey, T::StateSignatureKey>,
321{
322 fn clone(&self) -> Self {
323 Self {
324 epoch: self.epoch,
325 first_epoch: self.first_epoch,
326 has_drb: self.has_drb,
327 stake_table: self.stake_table.clone(),
328 da_stake_table: self.da_stake_table.clone(),
329 committee_keys: self.committee_keys.clone(),
330 da_committee_keys: self.da_committee_keys.clone(),
331 table: self.table.clone(),
332 _phantom: std::marker::PhantomData,
333 }
334 }
335}
336
337impl<T, S> Debug for StrictEpochSnapshot<T, S>
338where
339 T: NodeType,
340 S: TestStakeTable<T::SignatureKey, T::StateSignatureKey>,
341{
342 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
343 f.debug_struct("StrictEpochSnapshot")
344 .field("epoch", &self.epoch)
345 .field("first_epoch", &self.first_epoch)
346 .field("has_drb", &self.has_drb)
347 .field("table", &self.table)
348 .finish()
349 }
350}
351
352impl<T, S> MembershipSnapshot<T> for StrictEpochSnapshot<T, S>
353where
354 T: NodeType,
355 S: TestStakeTable<T::SignatureKey, T::StateSignatureKey>,
356{
357 type Error = StrictMembershipError;
358 type StakeTableHash = NoStakeTableHash;
359
360 fn epoch(&self) -> EpochNumber {
361 self.epoch
362 }
363
364 fn first_epoch(&self) -> Option<EpochNumber> {
365 self.first_epoch
366 }
367
368 fn has_drb(&self) -> bool {
369 self.has_drb
370 }
371
372 fn stake_table(&self) -> impl ExactSizeIterator<Item = &PeerConfig<T>> + Send {
373 self.stake_table.iter()
374 }
375
376 fn da_stake_table(&self) -> impl ExactSizeIterator<Item = &PeerConfig<T>> + Send {
377 self.da_stake_table.iter()
378 }
379
380 fn committee_members(
381 &self,
382 _: ViewNumber,
383 ) -> impl ExactSizeIterator<Item = &T::SignatureKey> + Send {
384 self.committee_keys.iter()
385 }
386
387 fn da_committee_members(
388 &self,
389 _: ViewNumber,
390 ) -> impl ExactSizeIterator<Item = &T::SignatureKey> + Send {
391 self.da_committee_keys.iter()
392 }
393
394 fn stake(&self, key: &T::SignatureKey) -> Option<PeerConfig<T>> {
395 self.table
396 .stake(key.clone(), Some(*self.epoch))
397 .map(Into::into)
398 }
399
400 fn da_stake(&self, key: &T::SignatureKey) -> Option<PeerConfig<T>> {
401 self.table
402 .da_stake(key.clone(), Some(*self.epoch))
403 .map(Into::into)
404 }
405
406 fn has_stake(&self, key: &T::SignatureKey) -> bool {
407 self.stake(key)
408 .is_some_and(|x| x.stake_table_entry.stake() > U256::ZERO)
409 }
410
411 fn has_da_stake(&self, key: &T::SignatureKey) -> bool {
412 self.da_stake(key)
413 .is_some_and(|x| x.stake_table_entry.stake() > U256::ZERO)
414 }
415
416 fn lookup_leader(&self, view: ViewNumber) -> Result<T::SignatureKey, Self::Error> {
417 Ok(self.table.lookup_leader(*view, Some(*self.epoch))?)
418 }
419}
420
421pub struct StrictNonEpochSnapshot<T, S>
424where
425 T: NodeType,
426 S: TestStakeTable<T::SignatureKey, T::StateSignatureKey>,
427{
428 stake_table: Vec<PeerConfig<T>>,
429 da_stake_table: Vec<PeerConfig<T>>,
430 committee_keys: Vec<T::SignatureKey>,
431 da_committee_keys: Vec<T::SignatureKey>,
432 table: S,
433 _phantom: std::marker::PhantomData<T>,
434}
435
436impl<T, S> StrictNonEpochSnapshot<T, S>
437where
438 T: NodeType,
439 S: TestStakeTable<T::SignatureKey, T::StateSignatureKey>,
440{
441 fn build(table: S) -> Self {
442 let stake_entries = table.stake_table(None);
443 let da_entries = table.da_stake_table(None);
444 let committee_keys = stake_entries
445 .iter()
446 .map(|e| e.signature_key.clone())
447 .collect();
448 let da_committee_keys = da_entries.iter().map(|e| e.signature_key.clone()).collect();
449 let stake_table = stake_entries.into_iter().map(Into::into).collect();
450 let da_stake_table = da_entries.into_iter().map(Into::into).collect();
451 Self {
452 stake_table,
453 da_stake_table,
454 committee_keys,
455 da_committee_keys,
456 table,
457 _phantom: std::marker::PhantomData,
458 }
459 }
460}
461
462impl<T, S> Clone for StrictNonEpochSnapshot<T, S>
463where
464 T: NodeType,
465 S: TestStakeTable<T::SignatureKey, T::StateSignatureKey>,
466{
467 fn clone(&self) -> Self {
468 Self {
469 stake_table: self.stake_table.clone(),
470 da_stake_table: self.da_stake_table.clone(),
471 committee_keys: self.committee_keys.clone(),
472 da_committee_keys: self.da_committee_keys.clone(),
473 table: self.table.clone(),
474 _phantom: std::marker::PhantomData,
475 }
476 }
477}
478
479impl<T, S> Debug for StrictNonEpochSnapshot<T, S>
480where
481 T: NodeType,
482 S: TestStakeTable<T::SignatureKey, T::StateSignatureKey>,
483{
484 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
485 f.debug_struct("StrictNonEpochSnapshot")
486 .field("table", &self.table)
487 .finish()
488 }
489}
490
491impl<T, S> NonEpochMembershipSnapshot<T> for StrictNonEpochSnapshot<T, S>
492where
493 T: NodeType,
494 S: TestStakeTable<T::SignatureKey, T::StateSignatureKey>,
495{
496 type Error = StrictMembershipError;
497
498 fn stake_table(&self) -> impl ExactSizeIterator<Item = &PeerConfig<T>> + Send + '_ {
499 self.stake_table.iter()
500 }
501
502 fn da_stake_table(&self) -> impl ExactSizeIterator<Item = &PeerConfig<T>> + Send + '_ {
503 self.da_stake_table.iter()
504 }
505
506 fn committee_members(
507 &self,
508 _: ViewNumber,
509 ) -> impl ExactSizeIterator<Item = &T::SignatureKey> + Send + '_ {
510 self.committee_keys.iter()
511 }
512
513 fn da_committee_members(
514 &self,
515 _: ViewNumber,
516 ) -> impl ExactSizeIterator<Item = &T::SignatureKey> + Send + '_ {
517 self.da_committee_keys.iter()
518 }
519
520 fn stake(&self, key: &T::SignatureKey) -> Option<PeerConfig<T>> {
521 self.table.stake(key.clone(), None).map(Into::into)
522 }
523
524 fn da_stake(&self, key: &T::SignatureKey) -> Option<PeerConfig<T>> {
525 self.table.da_stake(key.clone(), None).map(Into::into)
526 }
527
528 fn has_stake(&self, key: &T::SignatureKey) -> bool {
529 self.stake(key)
530 .is_some_and(|x| x.stake_table_entry.stake() > U256::ZERO)
531 }
532
533 fn has_da_stake(&self, key: &T::SignatureKey) -> bool {
534 self.da_stake(key)
535 .is_some_and(|x| x.stake_table_entry.stake() > U256::ZERO)
536 }
537
538 fn lookup_leader(&self, view: ViewNumber) -> Result<T::SignatureKey, Self::Error> {
539 Ok(self.table.lookup_leader(*view, None)?)
540 }
541}