1use std::{borrow::Borrow, collections::HashSet, iter::once, str::FromStr, sync::Arc};
2
3use alloy::primitives::{
4 Address, B256, U256,
5 utils::{ParseUnits, parse_units},
6};
7use anyhow::{Context, bail, ensure};
8use ark_serialize::{
9 CanonicalDeserialize, CanonicalSerialize, Compress, Read, SerializationError, Valid, Validate,
10};
11use espresso_utils::{
12 impl_serde_from_string_or_integer, impl_to_fixed_bytes, ser::FromStringOrInteger,
13};
14use hotshot::types::BLSPubKey;
15use hotshot_contract_adapter::reward::RewardProofSiblings;
16use hotshot_types::{
17 data::{EpochNumber, ViewNumber},
18 epoch_membership::EpochMembershipCoordinator,
19 stake_table::HSStakeTable,
20 traits::election::{Membership, MembershipSnapshot},
21 utils::epoch_from_block_number,
22};
23use jf_merkle_tree_compat::{
24 ForgetableMerkleTreeScheme, ForgetableUniversalMerkleTreeScheme, LookupResult,
25 MerkleTreeScheme, ToTraversalPath, UniversalMerkleTreeScheme, prelude::MerkleNode,
26};
27use num_traits::CheckedSub;
28use tokio::task::JoinHandle;
29use vbs::version::Version;
30use versions::{DRB_AND_HEADER_UPGRADE_VERSION, EPOCH_REWARD_VERSION, EPOCH_VERSION};
31
32use super::{
33 Leaf2, NodeState, ValidatedState,
34 v0_3::{AuthenticatedValidator, COMMISSION_BASIS_POINTS, RewardAmount},
35 v0_4::{
36 RewardAccountProofV2, RewardAccountQueryDataV2, RewardAccountV2, RewardMerkleCommitmentV2,
37 RewardMerkleProofV2, RewardMerkleTreeV2, forgotten_accounts_include,
38 },
39};
40use crate::{
41 EpochSnapshot, FeeAccount, SeqTypes,
42 eth_signature_key::EthKeyPair,
43 v0_3::{
44 RewardAccountProofV1, RewardAccountV1, RewardMerkleCommitmentV1, RewardMerkleProofV1,
45 RewardMerkleTreeV1,
46 },
47 v0_4::{Delta, REWARD_MERKLE_TREE_V2_ARITY, REWARD_MERKLE_TREE_V2_HEIGHT},
48 v0_5::LeaderCounts,
49};
50
// Serde support for `RewardAmount`: accepts either a string or an integer, using the
// `FromStringOrInteger` impl for `RewardAmount` in this file.
impl_serde_from_string_or_integer!(RewardAmount);
// Presumably generates `RewardAmount::to_fixed_bytes` (backed by `U256`), which the
// `CanonicalSerialize` impl calls — confirm byte order against `from_le_slice` in the
// matching `CanonicalDeserialize` impl.
impl_to_fixed_bytes!(RewardAmount, U256);
53
54impl From<u64> for RewardAmount {
55 fn from(amt: u64) -> Self {
56 Self(U256::from(amt))
57 }
58}
59
60impl CheckedSub for RewardAmount {
61 fn checked_sub(&self, v: &Self) -> Option<Self> {
62 self.0.checked_sub(v.0).map(RewardAmount)
63 }
64}
65
66impl FromStr for RewardAmount {
67 type Err = <U256 as FromStr>::Err;
68
69 fn from_str(s: &str) -> Result<Self, Self::Err> {
70 Ok(Self(s.parse()?))
71 }
72}
73
74impl FromStringOrInteger for RewardAmount {
75 type Binary = U256;
76 type Integer = u64;
77
78 fn from_binary(b: Self::Binary) -> anyhow::Result<Self> {
79 Ok(Self(b))
80 }
81
82 fn from_integer(i: Self::Integer) -> anyhow::Result<Self> {
83 Ok(i.into())
84 }
85
86 fn from_string(s: String) -> anyhow::Result<Self> {
87 if let Some(s) = s.strip_prefix("0x") {
90 return Ok(Self(s.parse()?));
91 }
92
93 let (base, unit) = s
95 .split_once(char::is_whitespace)
96 .unwrap_or((s.as_str(), "wei"));
97 match parse_units(base, unit)? {
98 ParseUnits::U256(n) => Ok(Self(n)),
99 ParseUnits::I256(_) => bail!("amount cannot be negative"),
100 }
101 }
102
103 fn to_binary(&self) -> anyhow::Result<Self::Binary> {
104 Ok(self.0)
105 }
106
107 fn to_string(&self) -> anyhow::Result<String> {
108 Ok(format!("{self}"))
109 }
110}
111
112impl RewardAmount {
113 pub fn as_u64(&self) -> Option<u64> {
114 if self.0 <= U256::from(u64::MAX) {
115 Some(self.0.to::<u64>())
116 } else {
117 None
118 }
119 }
120}
121
122impl From<[u8; 20]> for RewardAccountV1 {
123 fn from(bytes: [u8; 20]) -> Self {
124 Self(Address::from(bytes))
125 }
126}
127
128impl AsRef<[u8]> for RewardAccountV1 {
129 fn as_ref(&self) -> &[u8] {
130 self.0.as_slice()
131 }
132}
133
134impl<const ARITY: usize> ToTraversalPath<ARITY> for RewardAccountV1 {
135 fn to_traversal_path(&self, height: usize) -> Vec<usize> {
136 self.0
137 .as_slice()
138 .iter()
139 .take(height)
140 .map(|i| *i as usize)
141 .collect()
142 }
143}
144
145impl RewardAccountV2 {
146 pub fn address(&self) -> Address {
148 self.0
149 }
150 pub fn as_bytes(&self) -> &[u8] {
152 self.0.as_slice()
153 }
154 pub fn to_fixed_bytes(self) -> [u8; 20] {
156 self.0.into_array()
157 }
158 pub fn test_key_pair() -> EthKeyPair {
159 EthKeyPair::from_mnemonic(
160 "test test test test test test test test test test test junk",
161 0u32,
162 )
163 .unwrap()
164 }
165}
166
167impl RewardAccountV1 {
168 pub fn address(&self) -> Address {
170 self.0
171 }
172 pub fn as_bytes(&self) -> &[u8] {
174 self.0.as_slice()
175 }
176 pub fn to_fixed_bytes(self) -> [u8; 20] {
178 self.0.into_array()
179 }
180 pub fn test_key_pair() -> EthKeyPair {
181 EthKeyPair::from_mnemonic(
182 "test test test test test test test test test test test junk",
183 0u32,
184 )
185 .unwrap()
186 }
187}
188
189impl FromStr for RewardAccountV2 {
190 type Err = anyhow::Error;
191
192 fn from_str(s: &str) -> Result<Self, Self::Err> {
193 Ok(Self(s.parse()?))
194 }
195}
196
197impl FromStr for RewardAccountV1 {
198 type Err = anyhow::Error;
199
200 fn from_str(s: &str) -> Result<Self, Self::Err> {
201 Ok(Self(s.parse()?))
202 }
203}
204
205impl Valid for RewardAmount {
206 fn check(&self) -> Result<(), SerializationError> {
207 Ok(())
208 }
209}
210
211impl Valid for RewardAccountV2 {
212 fn check(&self) -> Result<(), SerializationError> {
213 Ok(())
214 }
215}
216
217impl Valid for RewardAccountV1 {
218 fn check(&self) -> Result<(), SerializationError> {
219 Ok(())
220 }
221}
222
223impl CanonicalSerialize for RewardAmount {
224 fn serialize_with_mode<W: std::io::prelude::Write>(
225 &self,
226 mut writer: W,
227 _compress: Compress,
228 ) -> Result<(), SerializationError> {
229 Ok(writer.write_all(&self.to_fixed_bytes())?)
230 }
231
232 fn serialized_size(&self, _compress: Compress) -> usize {
233 core::mem::size_of::<U256>()
234 }
235}
236impl CanonicalDeserialize for RewardAmount {
237 fn deserialize_with_mode<R: Read>(
238 mut reader: R,
239 _compress: Compress,
240 _validate: Validate,
241 ) -> Result<Self, SerializationError> {
242 let mut bytes = [0u8; core::mem::size_of::<U256>()];
243 reader.read_exact(&mut bytes)?;
244 let value = U256::from_le_slice(&bytes);
245 Ok(Self(value))
246 }
247}
248
249impl CanonicalSerialize for RewardAccountV2 {
250 fn serialize_with_mode<W: std::io::prelude::Write>(
251 &self,
252 mut writer: W,
253 _compress: Compress,
254 ) -> Result<(), SerializationError> {
255 Ok(writer.write_all(self.0.as_slice())?)
256 }
257
258 fn serialized_size(&self, _compress: Compress) -> usize {
259 core::mem::size_of::<Address>()
260 }
261}
262impl CanonicalDeserialize for RewardAccountV2 {
263 fn deserialize_with_mode<R: Read>(
264 mut reader: R,
265 _compress: Compress,
266 _validate: Validate,
267 ) -> Result<Self, SerializationError> {
268 let mut bytes = [0u8; core::mem::size_of::<Address>()];
269 reader.read_exact(&mut bytes)?;
270 let value = Address::from_slice(&bytes);
271 Ok(Self(value))
272 }
273}
274
275impl CanonicalSerialize for RewardAccountV1 {
276 fn serialize_with_mode<W: std::io::prelude::Write>(
277 &self,
278 mut writer: W,
279 _compress: Compress,
280 ) -> Result<(), SerializationError> {
281 Ok(writer.write_all(self.0.as_slice())?)
282 }
283
284 fn serialized_size(&self, _compress: Compress) -> usize {
285 core::mem::size_of::<Address>()
286 }
287}
288impl CanonicalDeserialize for RewardAccountV1 {
289 fn deserialize_with_mode<R: Read>(
290 mut reader: R,
291 _compress: Compress,
292 _validate: Validate,
293 ) -> Result<Self, SerializationError> {
294 let mut bytes = [0u8; core::mem::size_of::<Address>()];
295 reader.read_exact(&mut bytes)?;
296 let value = Address::from_slice(&bytes);
297 Ok(Self(value))
298 }
299}
300
301impl From<[u8; 20]> for RewardAccountV2 {
302 fn from(bytes: [u8; 20]) -> Self {
303 Self(Address::from(bytes))
304 }
305}
306
307impl AsRef<[u8]> for RewardAccountV2 {
308 fn as_ref(&self) -> &[u8] {
309 self.0.as_slice()
310 }
311}
312
313impl<const ARITY: usize> ToTraversalPath<ARITY> for RewardAccountV2 {
314 fn to_traversal_path(&self, height: usize) -> Vec<usize> {
315 let mut result = vec![0; height];
316
317 let mut value = U256::from_be_slice(self.0.as_slice());
319
320 for item in result.iter_mut().take(height) {
322 let digit = (value % U256::from(ARITY)).to::<usize>();
323 *item = digit;
324 value /= U256::from(ARITY);
325 }
326
327 result
328 }
329}
330
331impl RewardAccountProofV2 {
332 pub fn presence(
333 pos: FeeAccount,
334 proof: <RewardMerkleTreeV2 as MerkleTreeScheme>::MembershipProof,
335 ) -> Self {
336 Self {
337 account: pos.into(),
338 proof: RewardMerkleProofV2::Presence(proof),
339 }
340 }
341
342 pub fn absence(
343 pos: RewardAccountV2,
344 proof: <RewardMerkleTreeV2 as UniversalMerkleTreeScheme>::NonMembershipProof,
345 ) -> Self {
346 Self {
347 account: pos.into(),
348 proof: RewardMerkleProofV2::Absence(proof),
349 }
350 }
351
352 pub fn prove(tree: &RewardMerkleTreeV2, account: Address) -> Option<(Self, U256)> {
353 match tree.universal_lookup(RewardAccountV2(account)) {
354 LookupResult::Ok(balance, proof) => Some((
355 Self {
356 account,
357 proof: RewardMerkleProofV2::Presence(proof),
358 },
359 balance.0,
360 )),
361 LookupResult::NotFound(proof) => Some((
362 Self {
363 account,
364 proof: RewardMerkleProofV2::Absence(proof),
365 },
366 U256::ZERO,
367 )),
368 LookupResult::NotInMemory => None,
369 }
370 }
371
372 pub fn verify(&self, comm: &RewardMerkleCommitmentV2) -> anyhow::Result<U256> {
373 match &self.proof {
374 RewardMerkleProofV2::Presence(proof) => {
375 ensure!(
376 RewardMerkleTreeV2::verify(comm, RewardAccountV2(self.account), proof)?.is_ok(),
377 "invalid proof"
378 );
379 Ok(proof
380 .elem()
381 .context("presence proof is missing account balance")?
382 .0)
383 },
384 RewardMerkleProofV2::Absence(proof) => {
385 let tree = RewardMerkleTreeV2::from_commitment(comm);
386 ensure!(
387 RewardMerkleTreeV2::non_membership_verify(
388 tree.commitment(),
389 RewardAccountV2(self.account),
390 proof
391 )?,
392 "invalid proof"
393 );
394 Ok(U256::ZERO)
395 },
396 }
397 }
398
399 pub fn remember(&self, tree: &mut RewardMerkleTreeV2) -> anyhow::Result<()> {
400 match &self.proof {
401 RewardMerkleProofV2::Presence(proof) => {
402 tree.remember(
403 RewardAccountV2(self.account),
404 proof
405 .elem()
406 .context("presence proof is missing account balance")?,
407 proof,
408 )?;
409 Ok(())
410 },
411 RewardMerkleProofV2::Absence(proof) => {
412 tree.non_membership_remember(RewardAccountV2(self.account), proof)?;
413 Ok(())
414 },
415 }
416 }
417}
418
419impl TryInto<RewardProofSiblings> for RewardAccountProofV2 {
420 type Error = anyhow::Error;
421
422 fn try_into(self) -> anyhow::Result<RewardProofSiblings> {
427 let proof = if let RewardMerkleProofV2::Presence(proof) = &self.proof {
429 proof
430 } else {
431 bail!("only presence proofs supported")
432 };
433
434 let path = ToTraversalPath::<{ REWARD_MERKLE_TREE_V2_ARITY }>::to_traversal_path(
435 &RewardAccountV2(self.account),
436 REWARD_MERKLE_TREE_V2_HEIGHT,
437 );
438
439 if path.len() != REWARD_MERKLE_TREE_V2_HEIGHT {
440 bail!("Invalid proof: unexpected path length: {}", path.len());
441 };
442
443 let siblings: [B256; REWARD_MERKLE_TREE_V2_HEIGHT] = proof
444 .proof
445 .iter()
446 .enumerate()
447 .skip(1) .filter_map(|(level_idx, node)| match node {
449 MerkleNode::Branch { children, .. } => {
450 let path_direction = path
452 .get(level_idx - 1)
453 .copied()
454 .expect("exists");
455 let sibling_idx = if path_direction == 0 { 1 } else { 0 };
456 if sibling_idx >= children.len() {
457 panic!(
458 "Invalid proof: index={sibling_idx} length={}",
459 children.len()
460 );
461 };
462
463 match children[sibling_idx].as_ref() {
464 MerkleNode::Empty => Some(B256::ZERO),
465 MerkleNode::Leaf { value, .. } => {
466 let bytes = value.as_ref();
467 Some(B256::from_slice(bytes))
468 }
469 MerkleNode::Branch { value, .. } => {
470 let bytes = value.as_ref();
471 Some(B256::from_slice(bytes))
472 }
473 MerkleNode::ForgettenSubtree { value } => {
474 let bytes = value.as_ref();
475 Some(B256::from_slice(bytes))
476 }
477 }
478 }
479 _ => None,
480 })
481 .collect::<Vec<B256>>().try_into().map_err(|err: Vec<_>| {
482 panic!("Invalid proof length: {:?}, this should never happen", err.len())
483 })
484 .unwrap();
485
486 Ok(siblings.into())
487 }
488}
489
490impl RewardAccountProofV1 {
491 pub fn presence(
492 pos: FeeAccount,
493 proof: <RewardMerkleTreeV1 as MerkleTreeScheme>::MembershipProof,
494 ) -> Self {
495 Self {
496 account: pos.into(),
497 proof: RewardMerkleProofV1::Presence(proof),
498 }
499 }
500
501 pub fn absence(
502 pos: RewardAccountV1,
503 proof: <RewardMerkleTreeV1 as UniversalMerkleTreeScheme>::NonMembershipProof,
504 ) -> Self {
505 Self {
506 account: pos.into(),
507 proof: RewardMerkleProofV1::Absence(proof),
508 }
509 }
510
511 pub fn prove(tree: &RewardMerkleTreeV1, account: Address) -> Option<(Self, U256)> {
512 match tree.universal_lookup(RewardAccountV1(account)) {
513 LookupResult::Ok(balance, proof) => Some((
514 Self {
515 account,
516 proof: RewardMerkleProofV1::Presence(proof),
517 },
518 balance.0,
519 )),
520 LookupResult::NotFound(proof) => Some((
521 Self {
522 account,
523 proof: RewardMerkleProofV1::Absence(proof),
524 },
525 U256::ZERO,
526 )),
527 LookupResult::NotInMemory => None,
528 }
529 }
530
531 pub fn verify(&self, comm: &RewardMerkleCommitmentV1) -> anyhow::Result<U256> {
532 match &self.proof {
533 RewardMerkleProofV1::Presence(proof) => {
534 ensure!(
535 RewardMerkleTreeV1::verify(comm, RewardAccountV1(self.account), proof)?.is_ok(),
536 "invalid proof"
537 );
538 Ok(proof
539 .elem()
540 .context("presence proof is missing account balance")?
541 .0)
542 },
543 RewardMerkleProofV1::Absence(proof) => {
544 let tree = RewardMerkleTreeV1::from_commitment(comm);
545 ensure!(
546 RewardMerkleTreeV1::non_membership_verify(
547 tree.commitment(),
548 RewardAccountV1(self.account),
549 proof
550 )?,
551 "invalid proof"
552 );
553 Ok(U256::ZERO)
554 },
555 }
556 }
557
558 pub fn remember(&self, tree: &mut RewardMerkleTreeV1) -> anyhow::Result<()> {
559 match &self.proof {
560 RewardMerkleProofV1::Presence(proof) => {
561 tree.remember(
562 RewardAccountV1(self.account),
563 proof
564 .elem()
565 .context("presence proof is missing account balance")?,
566 proof,
567 )?;
568 Ok(())
569 },
570 RewardMerkleProofV1::Absence(proof) => {
571 tree.non_membership_remember(RewardAccountV1(self.account), proof)?;
572 Ok(())
573 },
574 }
575 }
576}
577
578impl From<(RewardAccountProofV2, U256)> for RewardAccountQueryDataV2 {
579 fn from((proof, balance): (RewardAccountProofV2, U256)) -> Self {
580 Self { balance, proof }
581 }
582}
583
584#[derive(Clone, Debug)]
585pub struct ComputedRewards {
586 leader_address: Address,
587 leader_commission: RewardAmount,
589 delegators: Vec<(Address, RewardAmount)>,
591}
592
593impl ComputedRewards {
594 pub fn new(
595 delegators: Vec<(Address, RewardAmount)>,
596 leader_address: Address,
597 leader_commission: RewardAmount,
598 ) -> Self {
599 Self {
600 delegators,
601 leader_address,
602 leader_commission,
603 }
604 }
605
606 pub fn leader_commission(&self) -> &RewardAmount {
607 &self.leader_commission
608 }
609
610 pub fn delegators(&self) -> &Vec<(Address, RewardAmount)> {
611 &self.delegators
612 }
613
614 pub fn all_rewards(self) -> Vec<(Address, RewardAmount)> {
616 self.delegators
617 .into_iter()
618 .chain(once((self.leader_address, self.leader_commission)))
619 .collect()
620 }
621}
622
623pub struct ValidatorLeaderCounts(Vec<(AuthenticatedValidator<BLSPubKey>, u16)>);
625
626impl ValidatorLeaderCounts {
627 pub fn new(snapshot: &EpochSnapshot, leader_counts: LeaderCounts) -> anyhow::Result<Self> {
634 let entries: Vec<_> = snapshot
635 .stake_table()
636 .zip(leader_counts.iter().copied())
637 .map(|(entry, count)| {
638 let validator = snapshot.validator_config(&entry.stake_table_entry.stake_key)?;
639 Ok((validator.clone(), count))
640 })
641 .collect::<anyhow::Result<_>>()?;
642
643 Ok(Self(entries))
644 }
645
646 pub fn active_leaders(
648 &self,
649 ) -> impl Iterator<Item = (&AuthenticatedValidator<BLSPubKey>, u16)> {
650 self.0
651 .iter()
652 .filter(|(_, count)| *count > 0)
653 .map(|(v, count)| (v, *count))
654 }
655
656 fn all_reward_accounts(&self) -> Vec<RewardAccountV2> {
658 self.active_leaders()
659 .flat_map(|(v, _)| {
660 std::iter::once(RewardAccountV2(v.account))
661 .chain(v.delegators.keys().map(|d| RewardAccountV2(*d)))
662 })
663 .collect()
664 }
665}
666
667pub struct RewardDistributor {
668 validator: AuthenticatedValidator<BLSPubKey>,
669 block_reward: RewardAmount,
670 total_distributed: RewardAmount,
671}
672
673impl RewardDistributor {
674 pub fn new(
675 validator: AuthenticatedValidator<BLSPubKey>,
676 block_reward: RewardAmount,
677 total_distributed: RewardAmount,
678 ) -> Self {
679 Self {
680 validator,
681 block_reward,
682 total_distributed,
683 }
684 }
685
686 pub fn validator(&self) -> AuthenticatedValidator<BLSPubKey> {
687 self.validator.clone()
688 }
689
690 pub fn block_reward(&self) -> RewardAmount {
691 self.block_reward
692 }
693
694 pub fn total_distributed(&self) -> RewardAmount {
695 self.total_distributed
696 }
697
698 pub fn update_rewards_delta(&self, delta: &mut Delta) -> anyhow::Result<()> {
699 delta
701 .rewards_delta
702 .insert(RewardAccountV2(self.validator().account));
703 delta.rewards_delta.extend(
704 self.validator()
705 .delegators
706 .keys()
707 .map(|d| RewardAccountV2(*d)),
708 );
709
710 Ok(())
711 }
712
713 pub fn update_reward_balance<P>(
714 tree: &mut P,
715 account: &P::Index,
716 amount: P::Element,
717 ) -> anyhow::Result<()>
718 where
719 P: UniversalMerkleTreeScheme<Element = RewardAmount>,
720 P::Index: Borrow<<P as MerkleTreeScheme>::Index> + std::fmt::Display,
721 {
722 let mut err = None;
723 tree.update_with(account.clone(), |balance| {
724 let balance = balance.copied();
725 match balance.unwrap_or_default().0.checked_add(amount.0) {
726 Some(updated) => Some(updated.into()),
727 None => {
728 err = Some(format!("overflowed reward balance for account {account}"));
729 balance
730 },
731 }
732 })?;
733
734 if let Some(error) = err {
735 tracing::warn!(error);
736 bail!(error)
737 }
738
739 Ok(())
740 }
741
742 pub fn apply_rewards(
743 &mut self,
744 version: Version,
745 state: &mut ValidatedState,
746 ) -> anyhow::Result<()> {
747 let computed_rewards = self.compute_rewards()?;
748
749 if version <= EPOCH_VERSION {
750 for (address, reward) in computed_rewards.all_rewards() {
751 Self::update_reward_balance(
752 &mut state.reward_merkle_tree_v1,
753 &RewardAccountV1(address),
754 reward,
755 )?;
756 tracing::debug!(%address, %reward, "applied v1 rewards");
757 }
758 } else {
759 for (address, reward) in computed_rewards.all_rewards() {
760 Self::update_reward_balance(
761 &mut state.reward_merkle_tree_v2,
762 &RewardAccountV2(address),
763 reward,
764 )?;
765 tracing::debug!(%address, %reward, "applied v2 rewards");
766 }
767 }
768
769 self.total_distributed += self.block_reward();
770
771 Ok(())
772 }
773
774 pub fn compute_rewards(&self) -> anyhow::Result<ComputedRewards> {
782 ensure!(
783 self.validator.commission <= COMMISSION_BASIS_POINTS,
784 "commission must not exceed {COMMISSION_BASIS_POINTS}"
785 );
786
787 let mut rewards = Vec::new();
788
789 let total_reward = self.block_reward.0;
790 let delegators_ratio_basis_points = U256::from(COMMISSION_BASIS_POINTS)
791 .checked_sub(U256::from(self.validator.commission))
792 .context("overflow")?;
793 let delegators_reward = delegators_ratio_basis_points
794 .checked_mul(total_reward)
795 .context("overflow")?;
796
797 let total_stake = self.validator.stake;
799 let mut delegators_total_reward_distributed = U256::from(0);
800 for (delegator_address, delegator_stake) in &self.validator.delegators {
801 let delegator_reward = RewardAmount::from(
802 (delegator_stake
803 .checked_mul(delegators_reward)
804 .context("overflow")?
805 .checked_div(total_stake)
806 .context("overflow")?)
807 .checked_div(U256::from(COMMISSION_BASIS_POINTS))
808 .context("overflow")?,
809 );
810
811 delegators_total_reward_distributed += delegator_reward.0;
812
813 rewards.push((*delegator_address, delegator_reward));
814 }
815
816 let leader_commission = total_reward
817 .checked_sub(delegators_total_reward_distributed)
818 .context("overflow")?;
819
820 Ok(ComputedRewards::new(
821 rewards,
822 self.validator.account,
823 leader_commission.into(),
824 ))
825 }
826}
827
828pub async fn distribute_block_reward(
835 instance_state: &NodeState,
836 validated_state: &mut ValidatedState,
837 parent_leaf: &Leaf2,
838 view_number: ViewNumber,
839 version: Version,
840) -> anyhow::Result<Option<RewardDistributor>> {
841 let height = parent_leaf.height() + 1;
842
843 let epoch_height = instance_state
844 .epoch_height
845 .context("epoch height not found")?;
846 let epoch = EpochNumber::new(epoch_from_block_number(height, epoch_height));
847 let coordinator = instance_state.coordinator.clone();
848 let first_epoch = {
849 coordinator
850 .membership()
851 .first_epoch()
852 .context("The first epoch was not set.")?
853 };
854
855 if epoch <= first_epoch + 1 {
858 return Ok(None);
859 }
860
861 let leader = get_leader_and_fetch_missing_rewards(
865 instance_state,
866 validated_state,
867 parent_leaf,
868 view_number,
869 )
870 .await?;
871
872 let parent_header = parent_leaf.block_header();
873
874 let mut previously_distributed = parent_header.total_reward_distributed().unwrap_or_default();
876
877 let block_reward = if version == DRB_AND_HEADER_UPGRADE_VERSION {
879 instance_state
880 .block_reward(EpochNumber::new(*epoch))
881 .await
882 .with_context(|| format!("block reward is None for epoch {epoch}"))?
883 } else {
884 instance_state.fixed_block_reward().await?
885 };
886
887 if version == DRB_AND_HEADER_UPGRADE_VERSION && parent_header.version() == EPOCH_VERSION {
892 ensure!(
893 instance_state.epoch_start_block != 0,
894 "epoch_start_block is zero"
895 );
896
897 let fixed_block_reward = instance_state.fixed_block_reward().await?;
898
899 let first_reward_block = (*first_epoch + 1) * epoch_height + 1;
905 if height > first_reward_block {
909 let blocks = height.checked_sub(first_reward_block).with_context(|| {
912 format!("height ({height}) - first_reward_block ({first_reward_block}) underflowed")
913 })?;
914 previously_distributed = U256::from(blocks)
915 .checked_mul(fixed_block_reward.0)
916 .with_context(|| {
917 format!(
918 "overflow during total_distributed calculation: blocks={blocks}, \
919 fixed_block_reward={}",
920 fixed_block_reward.0
921 )
922 })?
923 .into();
924 }
925 }
926
927 if block_reward.0.is_zero() {
928 tracing::info!("block reward is zero. height={height}. epoch={epoch}");
929 return Ok(None);
930 }
931
932 let mut reward_distributor =
933 RewardDistributor::new(leader, block_reward, previously_distributed);
934
935 reward_distributor.apply_rewards(version, validated_state)?;
936
937 Ok(Some(reward_distributor))
938}
939
940pub async fn get_leader_and_fetch_missing_rewards(
941 instance_state: &NodeState,
942 validated_state: &mut ValidatedState,
943 parent_leaf: &Leaf2,
944 view: ViewNumber,
945) -> anyhow::Result<AuthenticatedValidator<BLSPubKey>> {
946 let parent_height = parent_leaf.height();
947 let parent_view = parent_leaf.view_number();
948 let new_height = parent_height + 1;
949
950 let epoch_height = instance_state
951 .epoch_height
952 .context("epoch height not found")?;
953 if epoch_height == 0 {
954 bail!("epoch height is 0. can not catchup reward accounts");
955 }
956 let epoch = EpochNumber::new(epoch_from_block_number(new_height, epoch_height));
957
958 let coordinator = instance_state.coordinator.clone();
959
960 let membership = coordinator.membership_for_epoch(Some(epoch))?;
961
962 let snapshot = membership
963 .snapshot()
964 .context(format!("no committee for epoch {epoch:?}"))?;
965
966 let leader: BLSPubKey = snapshot
967 .leader(view)
968 .context(format!("leader for epoch {epoch:?} not found"))?;
969
970 tracing::debug!("Selected leader: {leader} for view {view} and epoch {epoch}");
971
972 let validator = snapshot
973 .validator_config(&leader)
974 .context("validator not found")?;
975
976 let parent_header = parent_leaf.block_header();
977
978 if parent_header.version() <= EPOCH_VERSION {
979 let mut reward_accounts = HashSet::new();
980 reward_accounts.insert(validator.account.into());
981 let delegators = validator
982 .delegators
983 .keys()
984 .cloned()
985 .map(|a| a.into())
986 .collect::<Vec<RewardAccountV2>>();
987
988 reward_accounts.extend(delegators.clone());
989
990 let accts: HashSet<_> = reward_accounts
991 .into_iter()
992 .map(RewardAccountV1::from)
993 .collect();
994 let missing_reward_accts = validated_state.forgotten_reward_accounts_v1(accts);
995
996 if !missing_reward_accts.is_empty() {
997 tracing::warn!(
998 parent_height,
999 ?parent_view,
1000 ?missing_reward_accts,
1001 "fetching missing v1 reward accounts from peers"
1002 );
1003
1004 let missing_account_proofs = instance_state
1005 .state_catchup
1006 .fetch_reward_accounts_v1(
1007 instance_state,
1008 parent_height,
1009 parent_view,
1010 validated_state.reward_merkle_tree_v1.commitment(),
1011 missing_reward_accts,
1012 )
1013 .await?;
1014
1015 for proof in missing_account_proofs.iter() {
1016 proof
1017 .remember(&mut validated_state.reward_merkle_tree_v1)
1018 .expect("proof previously verified");
1019 }
1020 }
1021 } else {
1022 let reward_accounts = Arc::new(
1023 std::iter::once(validator.account.into())
1024 .chain(validator.delegators.keys().cloned().map(Into::into))
1025 .collect::<Vec<_>>(),
1026 );
1027
1028 let reward_merkle_tree_root = validated_state.reward_merkle_tree_v2.commitment();
1029 if forgotten_accounts_include(&validated_state.reward_merkle_tree_v2, &reward_accounts) {
1030 tracing::warn!(
1031 parent_height,
1032 ?parent_view,
1033 %reward_merkle_tree_root,
1034 "fetching reward merkle tree from peers"
1035 );
1036
1037 validated_state.reward_merkle_tree_v2 = instance_state
1038 .state_catchup
1039 .fetch_reward_merkle_tree_v2(
1040 parent_height,
1041 parent_view,
1042 reward_merkle_tree_root,
1043 reward_accounts,
1044 )
1045 .await?
1046 .tree;
1047
1048 tracing::warn!(
1049 parent_height,
1050 ?parent_view,
1051 %reward_merkle_tree_root,
1052 "successfully fetched reward merkle tree from peers"
1053 );
1054 }
1055 }
1056
1057 Ok(validator.clone())
1058}
1059
/// Output of a background epoch-rewards calculation (see `EpochRewardsCalculator`).
#[derive(Debug, Clone)]
pub struct EpochRewardsResult {
    /// The epoch the rewards were calculated for.
    pub epoch: EpochNumber,
    /// The reward merkle tree after the epoch's rewards were applied.
    /// NOTE(review): presumably the updated tree — confirm in `calculate_all_rewards`.
    pub reward_tree: RewardMerkleTreeV2,
    /// Total rewards distributed for the epoch.
    /// NOTE(review): presumably the sum of all per-validator rewards — confirm.
    pub total_distributed: RewardAmount,
    /// Reward accounts whose balances may have changed: each validator with a
    /// non-zero reward, plus all of its delegators.
    pub changed_accounts: HashSet<RewardAccountV2>,
}
1072
/// Runs epoch-reward calculations on background tokio tasks.
///
/// At most one calculation is tracked at a time.
#[derive(Debug, Default)]
pub struct EpochRewardsCalculator {
    /// The epoch currently being calculated, together with the handle of its task.
    /// Spawning a calculation for a different epoch aborts and replaces this entry.
    pending: Option<(EpochNumber, JoinHandle<anyhow::Result<EpochRewardsResult>>)>,
}
1080
1081impl EpochRewardsCalculator {
1082 pub fn new() -> Self {
1083 Self { pending: None }
1084 }
1085
1086 pub fn is_calculating(&self, epoch: EpochNumber) -> bool {
1088 self.pending.as_ref().is_some_and(|(e, _)| *e == epoch)
1089 }
1090
1091 pub async fn get_result(&mut self, epoch: EpochNumber) -> Option<EpochRewardsResult> {
1098 let (pending_epoch, handle) = self.pending.take()?;
1099 if pending_epoch != epoch {
1100 self.pending = Some((pending_epoch, handle));
1102 return None;
1103 }
1104
1105 match handle.await {
1106 Ok(Ok(result)) => {
1107 tracing::info!(%epoch, total = %result.total_distributed.0, "epoch rewards calculation completed");
1108 Some(result)
1109 },
1110 Ok(Err(e)) => {
1111 tracing::error!(%epoch, error = %e, "epoch rewards calculation failed");
1112 None
1113 },
1114 Err(e) => {
1115 tracing::error!(%epoch, error = %e, "epoch rewards task panicked");
1116 None
1117 },
1118 }
1119 }
1120
1121 pub fn spawn_background_task(
1125 &mut self,
1126 epoch: EpochNumber,
1127 epoch_height: u64,
1128 reward_tree: RewardMerkleTreeV2,
1129 instance_state: NodeState,
1130 coordinator: EpochMembershipCoordinator<SeqTypes>,
1131 leader_counts: Option<LeaderCounts>,
1132 ) {
1133 if self.is_calculating(epoch) {
1134 tracing::debug!(%epoch, "calculation already in progress, skipping");
1135 return;
1136 }
1137
1138 if let Some((stale_epoch, handle)) = self.pending.take() {
1140 tracing::info!(%stale_epoch, %epoch, "aborting stale epoch rewards task");
1141 handle.abort();
1142 }
1143
1144 tracing::info!(
1145 %epoch,
1146 has_leader_counts = leader_counts.is_some(),
1147 "starting background epoch rewards task"
1148 );
1149
1150 let handle = tokio::spawn(async move {
1151 Self::fetch_and_calculate(
1152 epoch,
1153 epoch_height,
1154 reward_tree,
1155 instance_state,
1156 coordinator,
1157 leader_counts,
1158 )
1159 .await
1160 });
1161 self.pending = Some((epoch, handle));
1162 }
1163
1164 async fn fetch_and_calculate(
1165 epoch: EpochNumber,
1166 epoch_height: u64,
1167 mut reward_tree: RewardMerkleTreeV2,
1168 instance_state: NodeState,
1169 coordinator: EpochMembershipCoordinator<SeqTypes>,
1170 leader_counts: Option<LeaderCounts>,
1171 ) -> anyhow::Result<EpochRewardsResult> {
1172 let epoch_last_block_height = (*epoch) * epoch_height;
1173
1174 tracing::info!(
1175 %epoch,
1176 epoch_last_block_height,
1177 has_leader_counts = leader_counts.is_some(),
1178 "fetch_and_calculate: starting"
1179 );
1180
1181 if let Err(err) = coordinator.membership_for_epoch(Some(epoch)) {
1183 tracing::info!(%epoch, "stake table missing for epoch, triggering catchup: {err:#}");
1184 coordinator
1185 .wait_for_catchup(epoch)
1186 .await
1187 .context(format!("failed to catch up for epoch={epoch}"))?;
1188 }
1189
1190 let leader_counts = if let Some(lc) = leader_counts {
1192 lc
1193 } else {
1194 let snapshot = coordinator
1197 .membership()
1198 .snapshot(epoch)
1199 .context(format!("no committee for epoch={epoch}"))?;
1200 let stake_table = HSStakeTable::from_iter(snapshot.stake_table());
1201 let success_threshold = snapshot.success_threshold();
1202
1203 let leaf = instance_state
1204 .state_catchup
1205 .as_ref()
1206 .fetch_leaf(epoch_last_block_height, stake_table, success_threshold)
1207 .await
1208 .with_context(|| {
1209 format!(
1210 "failed to fetch leaf at height {epoch_last_block_height} for epoch \
1211 {epoch}"
1212 )
1213 })?;
1214 let header = leaf.block_header();
1215
1216 tracing::info!(
1217 %epoch,
1218 header_height = header.height(),
1219 header_version = %header.version(),
1220 header_reward_merkle_tree_root = %header.reward_merkle_tree_root(),
1221 "fetch_and_calculate: fetched leaf"
1222 );
1223
1224 anyhow::ensure!(
1225 header.version() >= EPOCH_REWARD_VERSION,
1226 "header version {} is pre-V6, cannot calculate rewards",
1227 header.version()
1228 );
1229
1230 let expected_root = header.reward_merkle_tree_root().right();
1236 let actual_root = reward_tree.commitment();
1237 if expected_root != Some(actual_root) {
1238 tracing::warn!(
1239 %epoch,
1240 ?expected_root,
1241 ?actual_root,
1242 "reward merkle tree root mismatch, using empty tree"
1243 );
1244 reward_tree = RewardMerkleTreeV2::new(REWARD_MERKLE_TREE_V2_HEIGHT);
1245 }
1246
1247 *header
1248 .leader_counts()
1249 .expect("V6+ header must have leader_counts")
1250 };
1251
1252 let snapshot = coordinator
1253 .membership()
1254 .snapshot(epoch)
1255 .with_context(|| format!("no committee for epoch={epoch}"))?;
1256 let validator_leader_counts = ValidatorLeaderCounts::new(&snapshot, leader_counts)?;
1257 let block_reward = snapshot
1258 .epoch_block_reward()
1259 .context("block reward not found for epoch")?;
1260
1261 tracing::info!(
1262 %epoch,
1263 %block_reward,
1264 "fetch_and_calculate: got block_reward"
1265 );
1266
1267 let accounts_to_update = validator_leader_counts.all_reward_accounts();
1269
1270 let missing_accounts: Vec<_> = accounts_to_update
1271 .iter()
1272 .filter(|account| reward_tree.lookup(**account).expect_not_in_memory().is_ok())
1273 .cloned()
1274 .collect();
1275
1276 if !missing_accounts.is_empty() {
1279 tracing::info!(
1280 %epoch,
1281 num_missing = missing_accounts.len(),
1282 "missing accounts detected, fetching reward merkle tree from peers"
1283 );
1284
1285 let reward_merkle_tree_root = reward_tree.commitment();
1286 reward_tree = instance_state
1287 .state_catchup
1288 .as_ref()
1289 .fetch_reward_merkle_tree_v2(
1290 epoch_last_block_height,
1291 ViewNumber::new(0),
1292 reward_merkle_tree_root,
1293 Arc::new(missing_accounts),
1294 )
1295 .await
1296 .with_context(|| {
1297 format!(
1298 "failed to fetch reward merkle tree at height {epoch_last_block_height} \
1299 for epoch {epoch}"
1300 )
1301 })?
1302 .tree;
1303
1304 tracing::info!(
1305 %epoch,
1306 reward_tree_commitment = %reward_tree.commitment(),
1307 "reward tree fetched successfully"
1308 );
1309 }
1310
1311 tracing::info!(
1312 %epoch,
1313 reward_tree_commitment = %reward_tree.commitment(),
1314 "starting final epoch calculation"
1315 );
1316
1317 Self::calculate_all_rewards(epoch, validator_leader_counts, reward_tree, block_reward).await
1318 }
1319
1320 async fn calculate_all_rewards(
1322 epoch: EpochNumber,
1323 validator_leader_counts: ValidatorLeaderCounts,
1324 mut reward_tree: RewardMerkleTreeV2,
1325 block_reward: RewardAmount,
1326 ) -> anyhow::Result<EpochRewardsResult> {
1327 let mut total_distributed = U256::ZERO;
1328 let mut changed_accounts = HashSet::new();
1329
1330 for (validator, count) in validator_leader_counts.active_leaders() {
1331 let validator_reward = block_reward
1333 .0
1334 .checked_mul(U256::from(count))
1335 .context("overflow in validator reward calculation")?;
1336
1337 if validator_reward.is_zero() {
1338 continue;
1339 }
1340
1341 changed_accounts.insert(RewardAccountV2(validator.account));
1342 changed_accounts.extend(validator.delegators.keys().map(|d| RewardAccountV2(*d)));
1343
1344 let distributor = RewardDistributor::new(
1345 validator.clone(),
1346 RewardAmount(validator_reward),
1347 Default::default(),
1348 );
1349
1350 let computed_rewards = distributor.compute_rewards()?;
1351
1352 for (address, reward) in computed_rewards.all_rewards() {
1353 RewardDistributor::update_reward_balance(
1354 &mut reward_tree,
1355 &RewardAccountV2(address),
1356 reward,
1357 )?;
1358 tracing::debug!(%epoch, %address, %reward, "applied epoch reward");
1359 }
1360
1361 total_distributed += validator_reward;
1362 }
1363
1364 tracing::info!(
1365 %epoch,
1366 total_distributed = %total_distributed,
1367 num_changed_accounts = changed_accounts.len(),
1368 "epoch rewards calculation complete"
1369 );
1370
1371 Ok(EpochRewardsResult {
1372 epoch,
1373 reward_tree,
1374 total_distributed: RewardAmount(total_distributed),
1375 changed_accounts,
1376 })
1377 }
1378}
1379
#[cfg(test)]
pub mod tests {

    use super::*;

    /// Builds a distributor for a mock validator with the given commission (in basis points)
    /// and a fixed block reward of 1.902e18.
    fn make_distributor(commission: u16) -> RewardDistributor {
        RewardDistributor::new(
            AuthenticatedValidator::mock_with_commission(commission),
            RewardAmount(U256::from(1902000000000000000_u128)),
            U256::ZERO.into(),
        )
    }

    /// Sums every entry of `rewards.all_rewards()`.
    ///
    /// Panics on `U256` overflow so a wrapped sum can never accidentally match an expected
    /// total.
    fn total_rewards(rewards: ComputedRewards) -> U256 {
        rewards
            .all_rewards()
            .iter()
            .fold(U256::ZERO, |acc, (_, r)| {
                acc.checked_add(r.0).expect("sum of rewards overflowed U256")
            })
    }

    /// The leader commission produced by `distributor`, expressed in basis points of its
    /// block reward.
    fn leader_commission_bps(distributor: &RewardDistributor) -> U256 {
        let rewards = distributor.compute_rewards().unwrap();
        rewards.leader_commission().0 * U256::from(COMMISSION_BASIS_POINTS)
            / distributor.block_reward.0
    }

    #[test]
    fn test_reward_calculation_sanity_checks() {
        // The whole block reward is distributed regardless of the commission rate.
        for commission in [500, 0] {
            let distributor = make_distributor(commission);
            let rewards = distributor.compute_rewards().unwrap();
            assert_eq!(
                total_rewards(rewards),
                distributor.block_reward.0,
                "distributed total must equal the block reward (commission {commission})"
            );
        }

        // At 10000 basis points the total still matches and the leader keeps everything.
        let distributor = make_distributor(10000);
        let rewards = distributor.compute_rewards().unwrap();
        assert_eq!(total_rewards(rewards.clone()), distributor.block_reward.0);
        assert_eq!(*rewards.leader_commission(), distributor.block_reward);

        // A commission above 10000 basis points is rejected.
        let distributor = make_distributor(10001);
        let err = distributor
            .compute_rewards()
            .err()
            .expect("commission above 10000 basis points must be rejected");
        assert!(
            err.to_string().contains("must not exceed"),
            "unexpected error message: {err}"
        );
    }

    #[test]
    fn test_compute_rewards_validator_commission() {
        assert_eq!(leader_commission_bps(&make_distributor(0)), U256::ZERO);
        assert_eq!(leader_commission_bps(&make_distributor(300)), U256::from(300));

        // At 10000 basis points the leader keeps the entire block reward.
        let distributor = make_distributor(10000);
        let rewards = distributor.compute_rewards().unwrap();
        assert_eq!(*rewards.leader_commission(), distributor.block_reward);
    }
}