1use std::collections::{BTreeMap, BTreeSet};
2
3use committable::{Commitment, Committable};
4use hotshot::types::SignatureKey;
5use hotshot_types::{
6 data::{EpochNumber, ViewNumber},
7 epoch_membership::{EpochMembership, EpochMembershipCoordinator},
8 message::UpgradeLock,
9 simple_vote::{HasEpoch, VersionedVoteData},
10 stake_table::StakeTableEntries,
11 traits::node_implementation::NodeType,
12 vote::{Certificate, Vote, VoteAccumulator},
13};
14use tokio::{
15 sync::mpsc::{self},
16 task::{AbortHandle, JoinSet},
17};
18use tracing::{debug, instrument, warn};
19
20pub struct VoteCollector<T: NodeType, V, C> {
21 tasks: JoinSet<C>,
26 accumulators: BTreeMap<ViewNumber, (mpsc::Sender<V>, AbortHandle)>,
27 completed_certificates: BTreeSet<ViewNumber>,
28 epoch_membership_coordinator: EpochMembershipCoordinator<T>,
29 membership_cache: BTreeMap<EpochNumber, EpochMembership<T>>,
30 upgrade_lock: UpgradeLock<T>,
31 pending_votes: Vec<V>,
33}
34
35impl<T, V, C> VoteCollector<T, V, C>
36where
37 T: NodeType,
38 V: Vote<T> + HasEpoch + Send + Sync + 'static,
39 C: Certificate<T, V::Commitment, Voteable = V::Commitment> + Send + Sync + 'static,
40{
41 #[instrument(level = "debug", skip_all)]
42 pub fn new(emc: EpochMembershipCoordinator<T>, lock: UpgradeLock<T>) -> Self {
43 Self {
44 accumulators: BTreeMap::new(),
45 completed_certificates: BTreeSet::new(),
46 epoch_membership_coordinator: emc,
47 membership_cache: BTreeMap::new(),
48 upgrade_lock: lock,
49 tasks: JoinSet::new(),
50 pending_votes: Vec::new(),
51 }
52 }
53
54 pub async fn next(&mut self) -> Option<C> {
55 loop {
56 match self.tasks.join_next().await {
57 Some(Ok(cert)) => {
58 if self.completed_certificates.contains(&cert.view_number()) {
59 continue;
60 }
61 self.completed_certificates.insert(cert.view_number());
62 return Some(cert);
63 },
64 Some(Err(e)) if e.is_cancelled() => {
65 debug!("Vote collection task cancelled: {e}");
66 },
67 Some(Err(e)) => {
68 warn!("Error in vote collection task: {e}");
69 },
70 None => return None,
71 }
72 }
73 }
74
75 pub async fn accumulate_vote(&mut self, vote: V) {
76 let view = vote.view_number();
77 if self.completed_certificates.contains(&view) {
78 return;
79 }
80 let Some(membership) = self.resolve_membership(&vote) else {
81 self.pending_votes.push(vote);
83 return;
84 };
85 let (tx, _abort_handle) = self.accumulators.entry(view).or_insert_with(|| {
86 let (tx, rx) = mpsc::channel(100);
87 let accumulator = VoteAccumulator::new(self.upgrade_lock.clone());
88 let upgrade_lock = self.upgrade_lock.clone();
89 let abort_handle = self.tasks.spawn(Self::run_per_view(
90 view,
91 rx,
92 accumulator,
93 membership,
94 upgrade_lock,
95 ));
96 (tx, abort_handle)
97 });
98 let _ = tx.send(vote).await;
99 }
100
101 pub async fn retry_pending_votes(&mut self) {
104 let pending = std::mem::take(&mut self.pending_votes);
105 for vote in pending {
106 self.accumulate_vote(vote).await;
107 }
108 }
109
110 fn resolve_membership(&mut self, vote: &V) -> Option<EpochMembership<T>> {
111 let epoch = vote.epoch()?;
112 if let Some(m) = self.membership_cache.get(&epoch) {
113 return Some(m.clone());
114 }
115 let m = self
116 .epoch_membership_coordinator
117 .membership_for_epoch(Some(epoch))
118 .ok()?;
119 self.membership_cache.insert(epoch, m.clone());
120 Some(m)
121 }
122
123 #[instrument(level = "debug", skip_all)]
124 async fn run_per_view(
125 _view: ViewNumber,
126 mut rx: mpsc::Receiver<V>,
127 mut accumulator: VoteAccumulator<T, V, C>,
128 membership: EpochMembership<T>,
129 upgrade_lock: UpgradeLock<T>,
130 ) -> C {
131 let mut votes = Vec::new();
132
133 while let Some(vote) = rx.recv().await {
134 if let Some(cert) = accumulator.accumulate(&vote, membership.clone()) {
135 let stake_table = C::stake_table(&membership);
136 let threshold = C::threshold(&membership);
137 match cert.is_valid_cert(
138 &StakeTableEntries::<T>::from(stake_table).0,
139 threshold,
140 &upgrade_lock,
141 ) {
142 Ok(()) => {
143 return cert;
144 },
145 Err(e) => {
146 warn!("Invalid certificate formed: {e}");
147 votes.push(vote);
148 votes.retain(|v: &V| {
151 let vote_commitment = generate_vote_commitment(v, &upgrade_lock);
152
153 vote_commitment.is_some_and(|commitment| {
154 v.signing_key()
155 .validate(&v.signature(), commitment.as_ref())
156 })
157 });
158 accumulator = VoteAccumulator::new(upgrade_lock.clone());
159 for vote in &votes {
160 if let Some(cert) = accumulator.accumulate(vote, membership.clone()) {
163 return cert;
164 }
165 }
166 },
167 }
168 } else {
169 votes.push(vote);
170 }
171 }
172 unreachable!()
173 }
174 pub fn gc(&mut self, view: ViewNumber, epoch: EpochNumber) {
175 let keep = self.accumulators.split_off(&view);
176 self.completed_certificates = self.completed_certificates.split_off(&view);
177 for (_, handle) in self.accumulators.values_mut() {
178 handle.abort();
179 }
180 self.accumulators = keep;
181 self.membership_cache = self.membership_cache.split_off(&epoch);
182 self.pending_votes.retain(|v| v.view_number() >= view);
183 }
184}
185
186fn generate_vote_commitment<T: NodeType, V: Vote<T>>(
187 vote: &V,
188 upgrade_lock: &UpgradeLock<T>,
189) -> Option<Commitment<VersionedVoteData<T, V::Commitment>>> {
190 match VersionedVoteData::new(vote.date().clone(), vote.view_number(), upgrade_lock) {
191 Ok(data) => Some(data.commit()),
192 Err(e) => {
193 tracing::warn!("Failed to generate versioned vote data: {e}");
194 None
195 },
196 }
197}
198
#[cfg(test)]
mod tests {
    use std::{fmt::Debug, time::Duration};

    use committable::Committable;
    use hotshot::types::BLSPubKey;
    use hotshot_example_types::node_types::TestTypes;
    use hotshot_types::{
        data::{EpochNumber, ViewNumber},
        epoch_membership::EpochMembership,
        simple_vote::{
            HasEpoch, QuorumData2, QuorumVote2, SimpleVote, VersionedVoteData, Vote2Data,
        },
        stake_table::StakeTableEntries,
        traits::signature_key::SignatureKey,
        vote::{Certificate, HasViewNumber, Vote},
    };
    use tokio::time::timeout;

    use super::VoteCollector;
    use crate::{
        helpers::test_upgrade_lock,
        message::{Certificate1, Certificate2, Vote2},
        tests::common::utils::mock_membership,
    };

    /// Number of nodes these tests assume `mock_membership()` contains.
    const NUM_NODES: u64 = 10;
    /// Distinct votes needed for a certificate, assuming equal stake across `NUM_NODES`.
    const THRESHOLD: u64 = 7;

    /// How long to wait for a certificate that is expected to form.
    const CERT_TIMEOUT: Duration = Duration::from_millis(100);
    /// How long to wait before concluding that no certificate will form.
    const NO_CERT_TIMEOUT: Duration = Duration::from_millis(500);

    /// The quorum data every honest `QuorumVote2` in these tests votes for.
    fn quorum_data(epoch: EpochNumber) -> QuorumData2<TestTypes> {
        QuorumData2 {
            leaf_commit: committable::RawCommitmentBuilder::new("FakeLeaf")
                .u64(42)
                .finalize(),
            epoch: Some(epoch),
            block_number: Some(1),
        }
    }

    /// Builds a correctly signed `QuorumVote2` for `quorum_data(epoch)` from node
    /// `node_index`.
    fn make_quorum_vote(
        node_index: u64,
        view: ViewNumber,
        epoch: EpochNumber,
    ) -> QuorumVote2<TestTypes> {
        let (pub_key, priv_key) = BLSPubKey::generated_from_seed_indexed([0u8; 32], node_index);
        SimpleVote::create_signed_vote(
            quorum_data(epoch),
            view,
            &pub_key,
            &priv_key,
            &test_upgrade_lock(),
        )
        .expect("Failed to sign vote")
    }

    /// The `Vote2Data` every honest `Vote2` in these tests votes for.
    fn vote_2_data() -> Vote2Data<TestTypes> {
        Vote2Data {
            leaf_commit: committable::RawCommitmentBuilder::new("FakeLeaf")
                .u64(42)
                .finalize(),
            epoch: EpochNumber::genesis(),
            block_number: 1,
        }
    }

    /// Builds a correctly signed `Vote2` for `vote_2_data()` from node `node_index`.
    fn make_vote2(node_index: u64, view: ViewNumber) -> Vote2<TestTypes> {
        let (pub_key, priv_key) = BLSPubKey::generated_from_seed_indexed([0u8; 32], node_index);
        SimpleVote::create_signed_vote(
            vote_2_data(),
            view,
            &pub_key,
            &priv_key,
            &test_upgrade_lock(),
        )
        .expect("Failed to sign vote")
    }

    /// Builds a `Vote2` that claims node `node_index`'s public key but is signed with a key
    /// derived from a different seed, so its signature does not verify.
    fn make_invalid_vote2(node_index: u64, view: ViewNumber) -> Vote2<TestTypes> {
        let (pub_key, _) = BLSPubKey::generated_from_seed_indexed([0u8; 32], node_index);
        let (_, wrong_priv_key) = BLSPubKey::generated_from_seed_indexed([1u8; 32], node_index);
        let data = vote_2_data();
        let commit =
            VersionedVoteData::<TestTypes, _>::new(data.clone(), view, &test_upgrade_lock())
                .unwrap()
                .commit();
        let bad_sig = BLSPubKey::sign(&wrong_priv_key, commit.as_ref()).unwrap();
        SimpleVote {
            signature: (pub_key, bad_sig),
            data,
            view_number: view,
        }
    }

    fn setup_cert1_task()
    -> VoteCollector<TestTypes, QuorumVote2<TestTypes>, Certificate1<TestTypes>> {
        setup_task::<QuorumVote2<TestTypes>, Certificate1<TestTypes>>()
    }

    fn setup_cert2_task() -> VoteCollector<TestTypes, Vote2<TestTypes>, Certificate2<TestTypes>> {
        setup_task::<Vote2<TestTypes>, Certificate2<TestTypes>>()
    }

    /// Creates a collector backed by `mock_membership()` and the test upgrade lock.
    fn setup_task<
        V: Vote<TestTypes> + HasEpoch + Send + Sync + 'static,
        C: Certificate<TestTypes, V::Commitment, Voteable = V::Commitment> + Send + Sync + 'static,
    >() -> VoteCollector<TestTypes, V, C> {
        VoteCollector::<TestTypes, V, C>::new(mock_membership(), test_upgrade_lock())
    }

    /// Membership of `epoch` taken from a fresh `mock_membership()` coordinator.
    fn epoch_membership(epoch: EpochNumber) -> EpochMembership<TestTypes> {
        mock_membership()
            .membership_for_epoch(Some(epoch))
            .unwrap()
    }

    /// Asserts that `task` produces no certificate within `NO_CERT_TIMEOUT`.
    ///
    /// Both a timeout and `next()` returning `None` (no tasks left) count as success.
    async fn assert_no_certs<
        V: Vote<TestTypes> + HasEpoch + Send + Sync + 'static,
        C: Certificate<TestTypes, V::Commitment, Voteable = V::Commitment>
            + Debug
            + Send
            + Sync
            + 'static,
    >(
        task: &mut VoteCollector<TestTypes, V, C>,
    ) {
        if let Ok(Some(cert)) = timeout(NO_CERT_TIMEOUT, task.next()).await {
            panic!("Expected no certificate but got one: {cert:?}");
        }
    }

    /// Asserts that `cert` certifies `expected_data` and validates against `membership`.
    fn verify_cert<C, D>(cert: &C, expected_data: &D, membership: &EpochMembership<TestTypes>)
    where
        D: Committable,
        C: Certificate<TestTypes, D, Voteable = D>,
    {
        assert_eq!(
            cert.data().commit(),
            expected_data.commit(),
            "Certificate data commitment does not match expected vote data"
        );

        let stake_table_entries = StakeTableEntries::<TestTypes>::from(C::stake_table(membership)).0;
        let threshold = C::threshold(membership);
        cert.is_valid_cert(&stake_table_entries, threshold, &test_upgrade_lock())
            .expect("Certificate signature validation failed");
    }

    /// `THRESHOLD` valid quorum votes for one view produce a valid `Certificate1`.
    #[tokio::test]
    async fn test_cert1_single_view_happy_path() {
        let mut task = setup_cert1_task();
        let view = ViewNumber::new(1);
        let epoch = EpochNumber::genesis();

        for i in 0..THRESHOLD {
            task.accumulate_vote(make_quorum_vote(i, view, epoch)).await;
        }

        let cert = timeout(CERT_TIMEOUT, task.next()).await.unwrap().unwrap();
        assert_eq!(cert.view_number(), view);
        verify_cert(&cert, &quorum_data(epoch), &epoch_membership(epoch));
    }

    /// Quorum votes interleaved across three views produce one `Certificate1` per view.
    #[tokio::test]
    async fn test_cert1_multiple_views_parallel() {
        let mut task = setup_cert1_task();
        let epoch = EpochNumber::genesis();
        let views = [ViewNumber::new(1), ViewNumber::new(2), ViewNumber::new(3)];

        for i in 0..THRESHOLD {
            for &view in &views {
                task.accumulate_vote(make_quorum_vote(i, view, epoch)).await;
            }
        }

        let mut certs = Vec::new();
        for _ in 0..views.len() {
            certs.push(timeout(CERT_TIMEOUT, task.next()).await.unwrap().unwrap());
        }
        assert_eq!(
            certs.len(),
            views.len(),
            "Expected one Certificate1 per view"
        );
        let mut cert_views: Vec<_> = certs.iter().map(|c| c.view_number()).collect();
        cert_views.sort();
        assert_eq!(cert_views, views.to_vec());

        let expected_data = quorum_data(epoch);
        let membership = epoch_membership(epoch);
        for cert in &certs {
            verify_cert(cert, &expected_data, &membership);
        }
    }

    /// `THRESHOLD` valid `Vote2`s for one view produce a valid `Certificate2`.
    #[tokio::test]
    async fn test_cert2_single_view_happy_path() {
        let mut task = setup_cert2_task();
        let view = ViewNumber::new(1);

        for i in 0..THRESHOLD {
            task.accumulate_vote(make_vote2(i, view)).await;
        }

        let cert = timeout(CERT_TIMEOUT, task.next()).await.unwrap().unwrap();
        assert_eq!(cert.view_number(), view);
        verify_cert(
            &cert,
            &vote_2_data(),
            &epoch_membership(EpochNumber::genesis()),
        );
    }

    /// `Vote2`s interleaved across three views produce one `Certificate2` per view.
    #[tokio::test]
    async fn test_cert2_multiple_views_parallel() {
        let mut task = setup_cert2_task();
        let views = [ViewNumber::new(5), ViewNumber::new(6), ViewNumber::new(7)];

        for i in 0..THRESHOLD {
            for &view in &views {
                task.accumulate_vote(make_vote2(i, view)).await;
            }
        }

        let mut certs = Vec::new();
        for _ in 0..views.len() {
            certs.push(timeout(CERT_TIMEOUT, task.next()).await.unwrap().unwrap());
        }
        assert_eq!(
            certs.len(),
            views.len(),
            "Expected one Certificate2 per view"
        );
        let mut cert_views: Vec<_> = certs.iter().map(|c| c.view_number()).collect();
        cert_views.sort();
        assert_eq!(cert_views, views.to_vec());

        let expected_data = vote_2_data();
        let membership = epoch_membership(EpochNumber::genesis());
        for cert in &certs {
            verify_cert(cert, &expected_data, &membership);
        }
    }

    /// One quorum vote short of `THRESHOLD` produces no certificate.
    #[tokio::test]
    async fn test_cert1_below_threshold_no_certificate() {
        let mut task = setup_cert1_task();
        let view = ViewNumber::new(1);
        let epoch = EpochNumber::genesis();

        for i in 0..(THRESHOLD - 1) {
            task.accumulate_vote(make_quorum_vote(i, view, epoch)).await;
        }

        assert_no_certs(&mut task).await;
    }

    /// Repeated quorum votes from the same node do not count towards the threshold.
    #[tokio::test]
    async fn test_cert1_duplicate_votes_ignored() {
        let mut task = setup_cert1_task();
        let view = ViewNumber::new(1);
        let epoch = EpochNumber::genesis();

        for i in 0..6 {
            task.accumulate_vote(make_quorum_vote(i, view, epoch)).await;
        }
        for _ in 0..5 {
            task.accumulate_vote(make_quorum_vote(0, view, epoch)).await;
        }

        assert_no_certs(&mut task).await;
    }

    /// One `Vote2` short of `THRESHOLD` produces no certificate.
    #[tokio::test]
    async fn test_cert2_below_threshold_no_certificate() {
        let mut task = setup_cert2_task();
        let view = ViewNumber::new(1);

        for i in 0..(THRESHOLD - 1) {
            task.accumulate_vote(make_vote2(i, view)).await;
        }

        assert_no_certs(&mut task).await;
    }

    /// Repeated `Vote2`s from the same node do not count towards the threshold.
    #[tokio::test]
    async fn test_cert2_duplicate_votes_ignored() {
        let mut task = setup_cert2_task();
        let view = ViewNumber::new(1);

        for i in 0..6 {
            task.accumulate_vote(make_vote2(i, view)).await;
        }
        for _ in 0..5 {
            task.accumulate_vote(make_vote2(0, view)).await;
        }

        assert_no_certs(&mut task).await;
    }

    /// Votes with bad signatures cannot be used to reach the threshold.
    #[tokio::test]
    async fn test_cert2_invalid_signature_rejected() {
        let mut task = setup_cert2_task();
        let view = ViewNumber::new(1);

        for i in 0..6 {
            task.accumulate_vote(make_vote2(i, view)).await;
        }
        for i in 6..NUM_NODES {
            task.accumulate_vote(make_invalid_vote2(i, view)).await;
        }

        assert_no_certs(&mut task).await;
    }

    /// After bad-signature votes are discarded, later valid votes still form a certificate.
    #[tokio::test]
    async fn test_cert2_invalid_signature_recovery() {
        let mut task = setup_cert2_task();
        let view = ViewNumber::new(1);

        for i in 0..6 {
            task.accumulate_vote(make_vote2(i, view)).await;
        }
        for i in 6..8 {
            task.accumulate_vote(make_invalid_vote2(i, view)).await;
        }
        assert_no_certs(&mut task).await;

        task.accumulate_vote(make_vote2(9, view)).await;

        let cert = timeout(CERT_TIMEOUT, task.next()).await.unwrap().unwrap();
        assert_no_certs(&mut task).await;
        verify_cert(
            &cert,
            &vote_2_data(),
            &epoch_membership(EpochNumber::genesis()),
        );
    }

    /// Closing a view's vote channel (via `gc`) before the threshold is reached produces no
    /// certificate.
    #[tokio::test]
    async fn test_cert2_channel_closed_early() {
        let mut task = setup_cert2_task();
        let view = ViewNumber::new(1);

        for i in 0..3 {
            task.accumulate_vote(make_vote2(i, view)).await;
        }
        // Collecting past `view` aborts its task and drops its vote sender.
        task.gc(ViewNumber::new(2), EpochNumber::genesis());
        assert_no_certs(&mut task).await;
    }

    /// Only the view that reached `THRESHOLD` is certified; the partial view is not.
    #[tokio::test]
    async fn test_cert2_partial_views_only_complete_one_certifies() {
        let mut task = setup_cert2_task();
        let complete_view = ViewNumber::new(1);
        let partial_view = ViewNumber::new(2);

        for i in 0..THRESHOLD {
            task.accumulate_vote(make_vote2(i, complete_view)).await;
        }
        for i in 0..3 {
            task.accumulate_vote(make_vote2(i, partial_view)).await;
        }

        let cert = timeout(CERT_TIMEOUT, task.next()).await.unwrap().unwrap();
        assert_no_certs(&mut task).await;
        assert_eq!(cert.view_number(), complete_view);
    }

    /// Votes beyond the threshold do not produce a second certificate for the same view.
    #[tokio::test]
    async fn test_cert2_extra_votes_after_threshold_no_duplicate_cert() {
        let mut task = setup_cert2_task();
        let view = ViewNumber::new(1);

        for i in 0..NUM_NODES {
            task.accumulate_vote(make_vote2(i, view)).await;
        }

        let cert = timeout(CERT_TIMEOUT, task.next()).await.unwrap().unwrap();
        assert_eq!(cert.view_number(), view);
        assert_no_certs(&mut task).await;
    }

    /// Votes split between two different payloads for the same view, neither reaching
    /// `THRESHOLD`, produce no certificate.
    #[tokio::test]
    async fn test_cert2_conflicting_data_same_view_no_certificate() {
        let mut task = setup_cert2_task();
        let view = ViewNumber::new(1);

        for i in 0..6 {
            task.accumulate_vote(make_vote2(i, view)).await;
        }

        for i in 6..NUM_NODES {
            let (pub_key, priv_key) = BLSPubKey::generated_from_seed_indexed([0u8; 32], i);
            let data = Vote2Data {
                leaf_commit: committable::RawCommitmentBuilder::new("FakeLeaf")
                    .u64(1000)
                    .finalize(),
                epoch: EpochNumber::genesis(),
                block_number: 1,
            };
            let vote = SimpleVote::create_signed_vote(
                data,
                view,
                &pub_key,
                &priv_key,
                &test_upgrade_lock(),
            )
            .expect("Failed to sign vote");
            task.accumulate_vote(vote).await;
        }
        assert_no_certs(&mut task).await;
    }
}