1use std::{borrow::Borrow, collections::HashSet, iter::once, str::FromStr, sync::Arc};
2
3use alloy::primitives::{
4 Address, B256, U256,
5 utils::{ParseUnits, parse_units},
6};
7use anyhow::{Context, bail, ensure};
8use ark_serialize::{
9 CanonicalDeserialize, CanonicalSerialize, Compress, Read, SerializationError, Valid, Validate,
10};
11use espresso_utils::{
12 impl_serde_from_string_or_integer, impl_to_fixed_bytes, ser::FromStringOrInteger,
13};
14use hotshot::types::BLSPubKey;
15use hotshot_contract_adapter::reward::RewardProofSiblings;
16use hotshot_types::{
17 data::{EpochNumber, ViewNumber},
18 epoch_membership::EpochMembershipCoordinator,
19 traits::election::Membership,
20 utils::epoch_from_block_number,
21};
22use jf_merkle_tree_compat::{
23 ForgetableMerkleTreeScheme, ForgetableUniversalMerkleTreeScheme, LookupResult,
24 MerkleTreeScheme, ToTraversalPath, UniversalMerkleTreeScheme, prelude::MerkleNode,
25};
26use num_traits::CheckedSub;
27use tokio::task::JoinHandle;
28use vbs::version::Version;
29use versions::{DRB_AND_HEADER_UPGRADE_VERSION, EPOCH_REWARD_VERSION, EPOCH_VERSION};
30
31use super::{
32 Leaf2, NodeState, ValidatedState,
33 v0_3::{AuthenticatedValidator, COMMISSION_BASIS_POINTS, RewardAmount},
34 v0_4::{
35 RewardAccountProofV2, RewardAccountQueryDataV2, RewardAccountV2, RewardMerkleCommitmentV2,
36 RewardMerkleProofV2, RewardMerkleTreeV2, forgotten_accounts_include,
37 },
38};
39use crate::{
40 EpochCommittees, FeeAccount, SeqTypes,
41 eth_signature_key::EthKeyPair,
42 v0_3::{
43 RewardAccountProofV1, RewardAccountV1, RewardMerkleCommitmentV1, RewardMerkleProofV1,
44 RewardMerkleTreeV1,
45 },
46 v0_4::{Delta, REWARD_MERKLE_TREE_V2_ARITY, REWARD_MERKLE_TREE_V2_HEIGHT},
47 v0_5::LeaderCounts,
48};
49
50impl_serde_from_string_or_integer!(RewardAmount);
51impl_to_fixed_bytes!(RewardAmount, U256);
52
53impl From<u64> for RewardAmount {
54 fn from(amt: u64) -> Self {
55 Self(U256::from(amt))
56 }
57}
58
59impl CheckedSub for RewardAmount {
60 fn checked_sub(&self, v: &Self) -> Option<Self> {
61 self.0.checked_sub(v.0).map(RewardAmount)
62 }
63}
64
65impl FromStr for RewardAmount {
66 type Err = <U256 as FromStr>::Err;
67
68 fn from_str(s: &str) -> Result<Self, Self::Err> {
69 Ok(Self(s.parse()?))
70 }
71}
72
73impl FromStringOrInteger for RewardAmount {
74 type Binary = U256;
75 type Integer = u64;
76
77 fn from_binary(b: Self::Binary) -> anyhow::Result<Self> {
78 Ok(Self(b))
79 }
80
81 fn from_integer(i: Self::Integer) -> anyhow::Result<Self> {
82 Ok(i.into())
83 }
84
85 fn from_string(s: String) -> anyhow::Result<Self> {
86 if let Some(s) = s.strip_prefix("0x") {
89 return Ok(Self(s.parse()?));
90 }
91
92 let (base, unit) = s
94 .split_once(char::is_whitespace)
95 .unwrap_or((s.as_str(), "wei"));
96 match parse_units(base, unit)? {
97 ParseUnits::U256(n) => Ok(Self(n)),
98 ParseUnits::I256(_) => bail!("amount cannot be negative"),
99 }
100 }
101
102 fn to_binary(&self) -> anyhow::Result<Self::Binary> {
103 Ok(self.0)
104 }
105
106 fn to_string(&self) -> anyhow::Result<String> {
107 Ok(format!("{self}"))
108 }
109}
110
111impl RewardAmount {
112 pub fn as_u64(&self) -> Option<u64> {
113 if self.0 <= U256::from(u64::MAX) {
114 Some(self.0.to::<u64>())
115 } else {
116 None
117 }
118 }
119}
120
121impl From<[u8; 20]> for RewardAccountV1 {
122 fn from(bytes: [u8; 20]) -> Self {
123 Self(Address::from(bytes))
124 }
125}
126
127impl AsRef<[u8]> for RewardAccountV1 {
128 fn as_ref(&self) -> &[u8] {
129 self.0.as_slice()
130 }
131}
132
133impl<const ARITY: usize> ToTraversalPath<ARITY> for RewardAccountV1 {
134 fn to_traversal_path(&self, height: usize) -> Vec<usize> {
135 self.0
136 .as_slice()
137 .iter()
138 .take(height)
139 .map(|i| *i as usize)
140 .collect()
141 }
142}
143
144impl RewardAccountV2 {
145 pub fn address(&self) -> Address {
147 self.0
148 }
149 pub fn as_bytes(&self) -> &[u8] {
151 self.0.as_slice()
152 }
153 pub fn to_fixed_bytes(self) -> [u8; 20] {
155 self.0.into_array()
156 }
157 pub fn test_key_pair() -> EthKeyPair {
158 EthKeyPair::from_mnemonic(
159 "test test test test test test test test test test test junk",
160 0u32,
161 )
162 .unwrap()
163 }
164}
165
166impl RewardAccountV1 {
167 pub fn address(&self) -> Address {
169 self.0
170 }
171 pub fn as_bytes(&self) -> &[u8] {
173 self.0.as_slice()
174 }
175 pub fn to_fixed_bytes(self) -> [u8; 20] {
177 self.0.into_array()
178 }
179 pub fn test_key_pair() -> EthKeyPair {
180 EthKeyPair::from_mnemonic(
181 "test test test test test test test test test test test junk",
182 0u32,
183 )
184 .unwrap()
185 }
186}
187
188impl FromStr for RewardAccountV2 {
189 type Err = anyhow::Error;
190
191 fn from_str(s: &str) -> Result<Self, Self::Err> {
192 Ok(Self(s.parse()?))
193 }
194}
195
196impl FromStr for RewardAccountV1 {
197 type Err = anyhow::Error;
198
199 fn from_str(s: &str) -> Result<Self, Self::Err> {
200 Ok(Self(s.parse()?))
201 }
202}
203
204impl Valid for RewardAmount {
205 fn check(&self) -> Result<(), SerializationError> {
206 Ok(())
207 }
208}
209
210impl Valid for RewardAccountV2 {
211 fn check(&self) -> Result<(), SerializationError> {
212 Ok(())
213 }
214}
215
216impl Valid for RewardAccountV1 {
217 fn check(&self) -> Result<(), SerializationError> {
218 Ok(())
219 }
220}
221
222impl CanonicalSerialize for RewardAmount {
223 fn serialize_with_mode<W: std::io::prelude::Write>(
224 &self,
225 mut writer: W,
226 _compress: Compress,
227 ) -> Result<(), SerializationError> {
228 Ok(writer.write_all(&self.to_fixed_bytes())?)
229 }
230
231 fn serialized_size(&self, _compress: Compress) -> usize {
232 core::mem::size_of::<U256>()
233 }
234}
235impl CanonicalDeserialize for RewardAmount {
236 fn deserialize_with_mode<R: Read>(
237 mut reader: R,
238 _compress: Compress,
239 _validate: Validate,
240 ) -> Result<Self, SerializationError> {
241 let mut bytes = [0u8; core::mem::size_of::<U256>()];
242 reader.read_exact(&mut bytes)?;
243 let value = U256::from_le_slice(&bytes);
244 Ok(Self(value))
245 }
246}
247
248impl CanonicalSerialize for RewardAccountV2 {
249 fn serialize_with_mode<W: std::io::prelude::Write>(
250 &self,
251 mut writer: W,
252 _compress: Compress,
253 ) -> Result<(), SerializationError> {
254 Ok(writer.write_all(self.0.as_slice())?)
255 }
256
257 fn serialized_size(&self, _compress: Compress) -> usize {
258 core::mem::size_of::<Address>()
259 }
260}
261impl CanonicalDeserialize for RewardAccountV2 {
262 fn deserialize_with_mode<R: Read>(
263 mut reader: R,
264 _compress: Compress,
265 _validate: Validate,
266 ) -> Result<Self, SerializationError> {
267 let mut bytes = [0u8; core::mem::size_of::<Address>()];
268 reader.read_exact(&mut bytes)?;
269 let value = Address::from_slice(&bytes);
270 Ok(Self(value))
271 }
272}
273
274impl CanonicalSerialize for RewardAccountV1 {
275 fn serialize_with_mode<W: std::io::prelude::Write>(
276 &self,
277 mut writer: W,
278 _compress: Compress,
279 ) -> Result<(), SerializationError> {
280 Ok(writer.write_all(self.0.as_slice())?)
281 }
282
283 fn serialized_size(&self, _compress: Compress) -> usize {
284 core::mem::size_of::<Address>()
285 }
286}
287impl CanonicalDeserialize for RewardAccountV1 {
288 fn deserialize_with_mode<R: Read>(
289 mut reader: R,
290 _compress: Compress,
291 _validate: Validate,
292 ) -> Result<Self, SerializationError> {
293 let mut bytes = [0u8; core::mem::size_of::<Address>()];
294 reader.read_exact(&mut bytes)?;
295 let value = Address::from_slice(&bytes);
296 Ok(Self(value))
297 }
298}
299
300impl From<[u8; 20]> for RewardAccountV2 {
301 fn from(bytes: [u8; 20]) -> Self {
302 Self(Address::from(bytes))
303 }
304}
305
306impl AsRef<[u8]> for RewardAccountV2 {
307 fn as_ref(&self) -> &[u8] {
308 self.0.as_slice()
309 }
310}
311
312impl<const ARITY: usize> ToTraversalPath<ARITY> for RewardAccountV2 {
313 fn to_traversal_path(&self, height: usize) -> Vec<usize> {
314 let mut result = vec![0; height];
315
316 let mut value = U256::from_be_slice(self.0.as_slice());
318
319 for item in result.iter_mut().take(height) {
321 let digit = (value % U256::from(ARITY)).to::<usize>();
322 *item = digit;
323 value /= U256::from(ARITY);
324 }
325
326 result
327 }
328}
329
330impl RewardAccountProofV2 {
331 pub fn presence(
332 pos: FeeAccount,
333 proof: <RewardMerkleTreeV2 as MerkleTreeScheme>::MembershipProof,
334 ) -> Self {
335 Self {
336 account: pos.into(),
337 proof: RewardMerkleProofV2::Presence(proof),
338 }
339 }
340
341 pub fn absence(
342 pos: RewardAccountV2,
343 proof: <RewardMerkleTreeV2 as UniversalMerkleTreeScheme>::NonMembershipProof,
344 ) -> Self {
345 Self {
346 account: pos.into(),
347 proof: RewardMerkleProofV2::Absence(proof),
348 }
349 }
350
351 pub fn prove(tree: &RewardMerkleTreeV2, account: Address) -> Option<(Self, U256)> {
352 match tree.universal_lookup(RewardAccountV2(account)) {
353 LookupResult::Ok(balance, proof) => Some((
354 Self {
355 account,
356 proof: RewardMerkleProofV2::Presence(proof),
357 },
358 balance.0,
359 )),
360 LookupResult::NotFound(proof) => Some((
361 Self {
362 account,
363 proof: RewardMerkleProofV2::Absence(proof),
364 },
365 U256::ZERO,
366 )),
367 LookupResult::NotInMemory => None,
368 }
369 }
370
371 pub fn verify(&self, comm: &RewardMerkleCommitmentV2) -> anyhow::Result<U256> {
372 match &self.proof {
373 RewardMerkleProofV2::Presence(proof) => {
374 ensure!(
375 RewardMerkleTreeV2::verify(comm, RewardAccountV2(self.account), proof)?.is_ok(),
376 "invalid proof"
377 );
378 Ok(proof
379 .elem()
380 .context("presence proof is missing account balance")?
381 .0)
382 },
383 RewardMerkleProofV2::Absence(proof) => {
384 let tree = RewardMerkleTreeV2::from_commitment(comm);
385 ensure!(
386 RewardMerkleTreeV2::non_membership_verify(
387 tree.commitment(),
388 RewardAccountV2(self.account),
389 proof
390 )?,
391 "invalid proof"
392 );
393 Ok(U256::ZERO)
394 },
395 }
396 }
397
398 pub fn remember(&self, tree: &mut RewardMerkleTreeV2) -> anyhow::Result<()> {
399 match &self.proof {
400 RewardMerkleProofV2::Presence(proof) => {
401 tree.remember(
402 RewardAccountV2(self.account),
403 proof
404 .elem()
405 .context("presence proof is missing account balance")?,
406 proof,
407 )?;
408 Ok(())
409 },
410 RewardMerkleProofV2::Absence(proof) => {
411 tree.non_membership_remember(RewardAccountV2(self.account), proof)?;
412 Ok(())
413 },
414 }
415 }
416}
417
418impl TryInto<RewardProofSiblings> for RewardAccountProofV2 {
419 type Error = anyhow::Error;
420
421 fn try_into(self) -> anyhow::Result<RewardProofSiblings> {
426 let proof = if let RewardMerkleProofV2::Presence(proof) = &self.proof {
428 proof
429 } else {
430 bail!("only presence proofs supported")
431 };
432
433 let path = ToTraversalPath::<{ REWARD_MERKLE_TREE_V2_ARITY }>::to_traversal_path(
434 &RewardAccountV2(self.account),
435 REWARD_MERKLE_TREE_V2_HEIGHT,
436 );
437
438 if path.len() != REWARD_MERKLE_TREE_V2_HEIGHT {
439 bail!("Invalid proof: unexpected path length: {}", path.len());
440 };
441
442 let siblings: [B256; REWARD_MERKLE_TREE_V2_HEIGHT] = proof
443 .proof
444 .iter()
445 .enumerate()
446 .skip(1) .filter_map(|(level_idx, node)| match node {
448 MerkleNode::Branch { children, .. } => {
449 let path_direction = path
451 .get(level_idx - 1)
452 .copied()
453 .expect("exists");
454 let sibling_idx = if path_direction == 0 { 1 } else { 0 };
455 if sibling_idx >= children.len() {
456 panic!(
457 "Invalid proof: index={sibling_idx} length={}",
458 children.len()
459 );
460 };
461
462 match children[sibling_idx].as_ref() {
463 MerkleNode::Empty => Some(B256::ZERO),
464 MerkleNode::Leaf { value, .. } => {
465 let bytes = value.as_ref();
466 Some(B256::from_slice(bytes))
467 }
468 MerkleNode::Branch { value, .. } => {
469 let bytes = value.as_ref();
470 Some(B256::from_slice(bytes))
471 }
472 MerkleNode::ForgettenSubtree { value } => {
473 let bytes = value.as_ref();
474 Some(B256::from_slice(bytes))
475 }
476 }
477 }
478 _ => None,
479 })
480 .collect::<Vec<B256>>().try_into().map_err(|err: Vec<_>| {
481 panic!("Invalid proof length: {:?}, this should never happen", err.len())
482 })
483 .unwrap();
484
485 Ok(siblings.into())
486 }
487}
488
489impl RewardAccountProofV1 {
490 pub fn presence(
491 pos: FeeAccount,
492 proof: <RewardMerkleTreeV1 as MerkleTreeScheme>::MembershipProof,
493 ) -> Self {
494 Self {
495 account: pos.into(),
496 proof: RewardMerkleProofV1::Presence(proof),
497 }
498 }
499
500 pub fn absence(
501 pos: RewardAccountV1,
502 proof: <RewardMerkleTreeV1 as UniversalMerkleTreeScheme>::NonMembershipProof,
503 ) -> Self {
504 Self {
505 account: pos.into(),
506 proof: RewardMerkleProofV1::Absence(proof),
507 }
508 }
509
510 pub fn prove(tree: &RewardMerkleTreeV1, account: Address) -> Option<(Self, U256)> {
511 match tree.universal_lookup(RewardAccountV1(account)) {
512 LookupResult::Ok(balance, proof) => Some((
513 Self {
514 account,
515 proof: RewardMerkleProofV1::Presence(proof),
516 },
517 balance.0,
518 )),
519 LookupResult::NotFound(proof) => Some((
520 Self {
521 account,
522 proof: RewardMerkleProofV1::Absence(proof),
523 },
524 U256::ZERO,
525 )),
526 LookupResult::NotInMemory => None,
527 }
528 }
529
530 pub fn verify(&self, comm: &RewardMerkleCommitmentV1) -> anyhow::Result<U256> {
531 match &self.proof {
532 RewardMerkleProofV1::Presence(proof) => {
533 ensure!(
534 RewardMerkleTreeV1::verify(comm, RewardAccountV1(self.account), proof)?.is_ok(),
535 "invalid proof"
536 );
537 Ok(proof
538 .elem()
539 .context("presence proof is missing account balance")?
540 .0)
541 },
542 RewardMerkleProofV1::Absence(proof) => {
543 let tree = RewardMerkleTreeV1::from_commitment(comm);
544 ensure!(
545 RewardMerkleTreeV1::non_membership_verify(
546 tree.commitment(),
547 RewardAccountV1(self.account),
548 proof
549 )?,
550 "invalid proof"
551 );
552 Ok(U256::ZERO)
553 },
554 }
555 }
556
557 pub fn remember(&self, tree: &mut RewardMerkleTreeV1) -> anyhow::Result<()> {
558 match &self.proof {
559 RewardMerkleProofV1::Presence(proof) => {
560 tree.remember(
561 RewardAccountV1(self.account),
562 proof
563 .elem()
564 .context("presence proof is missing account balance")?,
565 proof,
566 )?;
567 Ok(())
568 },
569 RewardMerkleProofV1::Absence(proof) => {
570 tree.non_membership_remember(RewardAccountV1(self.account), proof)?;
571 Ok(())
572 },
573 }
574 }
575}
576
577impl From<(RewardAccountProofV2, U256)> for RewardAccountQueryDataV2 {
578 fn from((proof, balance): (RewardAccountProofV2, U256)) -> Self {
579 Self { balance, proof }
580 }
581}
582
583#[derive(Clone, Debug)]
584pub struct ComputedRewards {
585 leader_address: Address,
586 leader_commission: RewardAmount,
588 delegators: Vec<(Address, RewardAmount)>,
590}
591
592impl ComputedRewards {
593 pub fn new(
594 delegators: Vec<(Address, RewardAmount)>,
595 leader_address: Address,
596 leader_commission: RewardAmount,
597 ) -> Self {
598 Self {
599 delegators,
600 leader_address,
601 leader_commission,
602 }
603 }
604
605 pub fn leader_commission(&self) -> &RewardAmount {
606 &self.leader_commission
607 }
608
609 pub fn delegators(&self) -> &Vec<(Address, RewardAmount)> {
610 &self.delegators
611 }
612
613 pub fn all_rewards(self) -> Vec<(Address, RewardAmount)> {
615 self.delegators
616 .into_iter()
617 .chain(once((self.leader_address, self.leader_commission)))
618 .collect()
619 }
620}
621
622pub struct ValidatorLeaderCounts(Vec<(AuthenticatedValidator<BLSPubKey>, u16)>);
624
625impl ValidatorLeaderCounts {
626 pub fn new(
631 membership: &EpochCommittees,
632 epoch: &EpochNumber,
633 leader_counts: LeaderCounts,
634 ) -> anyhow::Result<Self> {
635 let entries: Vec<_> = membership
636 .stake_table(Some(*epoch))
637 .iter()
638 .zip(leader_counts.iter().copied())
639 .map(|(entry, count)| {
640 let validator =
641 membership.get_validator_config(epoch, entry.stake_table_entry.stake_key)?;
642 Ok((validator, count))
643 })
644 .collect::<anyhow::Result<_>>()?;
645
646 Ok(Self(entries))
647 }
648
649 pub fn active_leaders(
651 &self,
652 ) -> impl Iterator<Item = (&AuthenticatedValidator<BLSPubKey>, u16)> {
653 self.0
654 .iter()
655 .filter(|(_, count)| *count > 0)
656 .map(|(v, count)| (v, *count))
657 }
658
659 fn all_reward_accounts(&self) -> Vec<RewardAccountV2> {
661 self.active_leaders()
662 .flat_map(|(v, _)| {
663 std::iter::once(RewardAccountV2(v.account))
664 .chain(v.delegators.keys().map(|d| RewardAccountV2(*d)))
665 })
666 .collect()
667 }
668}
669
670pub struct RewardDistributor {
671 validator: AuthenticatedValidator<BLSPubKey>,
672 block_reward: RewardAmount,
673 total_distributed: RewardAmount,
674}
675
676impl RewardDistributor {
677 pub fn new(
678 validator: AuthenticatedValidator<BLSPubKey>,
679 block_reward: RewardAmount,
680 total_distributed: RewardAmount,
681 ) -> Self {
682 Self {
683 validator,
684 block_reward,
685 total_distributed,
686 }
687 }
688
689 pub fn validator(&self) -> AuthenticatedValidator<BLSPubKey> {
690 self.validator.clone()
691 }
692
693 pub fn block_reward(&self) -> RewardAmount {
694 self.block_reward
695 }
696
697 pub fn total_distributed(&self) -> RewardAmount {
698 self.total_distributed
699 }
700
701 pub fn update_rewards_delta(&self, delta: &mut Delta) -> anyhow::Result<()> {
702 delta
704 .rewards_delta
705 .insert(RewardAccountV2(self.validator().account));
706 delta.rewards_delta.extend(
707 self.validator()
708 .delegators
709 .keys()
710 .map(|d| RewardAccountV2(*d)),
711 );
712
713 Ok(())
714 }
715
716 pub fn update_reward_balance<P>(
717 tree: &mut P,
718 account: &P::Index,
719 amount: P::Element,
720 ) -> anyhow::Result<()>
721 where
722 P: UniversalMerkleTreeScheme<Element = RewardAmount>,
723 P::Index: Borrow<<P as MerkleTreeScheme>::Index> + std::fmt::Display,
724 {
725 let mut err = None;
726 tree.update_with(account.clone(), |balance| {
727 let balance = balance.copied();
728 match balance.unwrap_or_default().0.checked_add(amount.0) {
729 Some(updated) => Some(updated.into()),
730 None => {
731 err = Some(format!("overflowed reward balance for account {account}"));
732 balance
733 },
734 }
735 })?;
736
737 if let Some(error) = err {
738 tracing::warn!(error);
739 bail!(error)
740 }
741
742 Ok(())
743 }
744
745 pub fn apply_rewards(
746 &mut self,
747 version: Version,
748 state: &mut ValidatedState,
749 ) -> anyhow::Result<()> {
750 let computed_rewards = self.compute_rewards()?;
751
752 if version <= EPOCH_VERSION {
753 for (address, reward) in computed_rewards.all_rewards() {
754 Self::update_reward_balance(
755 &mut state.reward_merkle_tree_v1,
756 &RewardAccountV1(address),
757 reward,
758 )?;
759 tracing::debug!(%address, %reward, "applied v1 rewards");
760 }
761 } else {
762 for (address, reward) in computed_rewards.all_rewards() {
763 Self::update_reward_balance(
764 &mut state.reward_merkle_tree_v2,
765 &RewardAccountV2(address),
766 reward,
767 )?;
768 tracing::debug!(%address, %reward, "applied v2 rewards");
769 }
770 }
771
772 self.total_distributed += self.block_reward();
773
774 Ok(())
775 }
776
777 pub fn compute_rewards(&self) -> anyhow::Result<ComputedRewards> {
785 ensure!(
786 self.validator.commission <= COMMISSION_BASIS_POINTS,
787 "commission must not exceed {COMMISSION_BASIS_POINTS}"
788 );
789
790 let mut rewards = Vec::new();
791
792 let total_reward = self.block_reward.0;
793 let delegators_ratio_basis_points = U256::from(COMMISSION_BASIS_POINTS)
794 .checked_sub(U256::from(self.validator.commission))
795 .context("overflow")?;
796 let delegators_reward = delegators_ratio_basis_points
797 .checked_mul(total_reward)
798 .context("overflow")?;
799
800 let total_stake = self.validator.stake;
802 let mut delegators_total_reward_distributed = U256::from(0);
803 for (delegator_address, delegator_stake) in &self.validator.delegators {
804 let delegator_reward = RewardAmount::from(
805 (delegator_stake
806 .checked_mul(delegators_reward)
807 .context("overflow")?
808 .checked_div(total_stake)
809 .context("overflow")?)
810 .checked_div(U256::from(COMMISSION_BASIS_POINTS))
811 .context("overflow")?,
812 );
813
814 delegators_total_reward_distributed += delegator_reward.0;
815
816 rewards.push((*delegator_address, delegator_reward));
817 }
818
819 let leader_commission = total_reward
820 .checked_sub(delegators_total_reward_distributed)
821 .context("overflow")?;
822
823 Ok(ComputedRewards::new(
824 rewards,
825 self.validator.account,
826 leader_commission.into(),
827 ))
828 }
829}
830
831pub async fn distribute_block_reward(
838 instance_state: &NodeState,
839 validated_state: &mut ValidatedState,
840 parent_leaf: &Leaf2,
841 view_number: ViewNumber,
842 version: Version,
843) -> anyhow::Result<Option<RewardDistributor>> {
844 let height = parent_leaf.height() + 1;
845
846 let epoch_height = instance_state
847 .epoch_height
848 .context("epoch height not found")?;
849 let epoch = EpochNumber::new(epoch_from_block_number(height, epoch_height));
850 let coordinator = instance_state.coordinator.clone();
851 let first_epoch = {
852 coordinator
853 .membership()
854 .read()
855 .await
856 .first_epoch()
857 .context("The first epoch was not set.")?
858 };
859
860 if epoch <= first_epoch + 1 {
863 return Ok(None);
864 }
865
866 let leader = get_leader_and_fetch_missing_rewards(
870 instance_state,
871 validated_state,
872 parent_leaf,
873 view_number,
874 )
875 .await?;
876
877 let parent_header = parent_leaf.block_header();
878
879 let mut previously_distributed = parent_header.total_reward_distributed().unwrap_or_default();
881
882 let block_reward = if version == DRB_AND_HEADER_UPGRADE_VERSION {
884 instance_state
885 .block_reward(EpochNumber::new(*epoch))
886 .await
887 .with_context(|| format!("block reward is None for epoch {epoch}"))?
888 } else {
889 instance_state.fixed_block_reward().await?
890 };
891
892 if version == DRB_AND_HEADER_UPGRADE_VERSION && parent_header.version() == EPOCH_VERSION {
897 ensure!(
898 instance_state.epoch_start_block != 0,
899 "epoch_start_block is zero"
900 );
901
902 let fixed_block_reward = instance_state.fixed_block_reward().await?;
903
904 let first_reward_block = (*first_epoch + 1) * epoch_height + 1;
910 if height > first_reward_block {
914 let blocks = height.checked_sub(first_reward_block).with_context(|| {
917 format!("height ({height}) - first_reward_block ({first_reward_block}) underflowed")
918 })?;
919 previously_distributed = U256::from(blocks)
920 .checked_mul(fixed_block_reward.0)
921 .with_context(|| {
922 format!(
923 "overflow during total_distributed calculation: blocks={blocks}, \
924 fixed_block_reward={}",
925 fixed_block_reward.0
926 )
927 })?
928 .into();
929 }
930 }
931
932 if block_reward.0.is_zero() {
933 tracing::info!("block reward is zero. height={height}. epoch={epoch}");
934 return Ok(None);
935 }
936
937 let mut reward_distributor =
938 RewardDistributor::new(leader, block_reward, previously_distributed);
939
940 reward_distributor.apply_rewards(version, validated_state)?;
941
942 Ok(Some(reward_distributor))
943}
944
945pub async fn get_leader_and_fetch_missing_rewards(
946 instance_state: &NodeState,
947 validated_state: &mut ValidatedState,
948 parent_leaf: &Leaf2,
949 view: ViewNumber,
950) -> anyhow::Result<AuthenticatedValidator<BLSPubKey>> {
951 let parent_height = parent_leaf.height();
952 let parent_view = parent_leaf.view_number();
953 let new_height = parent_height + 1;
954
955 let epoch_height = instance_state
956 .epoch_height
957 .context("epoch height not found")?;
958 if epoch_height == 0 {
959 bail!("epoch height is 0. can not catchup reward accounts");
960 }
961 let epoch = EpochNumber::new(epoch_from_block_number(new_height, epoch_height));
962
963 let coordinator = instance_state.coordinator.clone();
964
965 let epoch_membership = coordinator.membership_for_epoch(Some(epoch)).await?;
966 let membership = epoch_membership.coordinator.membership().read().await;
967
968 let leader: BLSPubKey = membership
969 .leader(view, Some(epoch))
970 .context(format!("leader for epoch {epoch:?} not found"))?;
971
972 tracing::debug!("Selected leader: {leader} for view {view} and epoch {epoch}");
973
974 let validator = membership
975 .get_validator_config(&epoch, leader)
976 .context("validator not found")?;
977 drop(membership);
978
979 let parent_header = parent_leaf.block_header();
980
981 if parent_header.version() <= EPOCH_VERSION {
982 let mut reward_accounts = HashSet::new();
983 reward_accounts.insert(validator.account.into());
984 let delegators = validator
985 .delegators
986 .keys()
987 .cloned()
988 .map(|a| a.into())
989 .collect::<Vec<RewardAccountV2>>();
990
991 reward_accounts.extend(delegators.clone());
992
993 let accts: HashSet<_> = reward_accounts
994 .into_iter()
995 .map(RewardAccountV1::from)
996 .collect();
997 let missing_reward_accts = validated_state.forgotten_reward_accounts_v1(accts);
998
999 if !missing_reward_accts.is_empty() {
1000 tracing::warn!(
1001 parent_height,
1002 ?parent_view,
1003 ?missing_reward_accts,
1004 "fetching missing v1 reward accounts from peers"
1005 );
1006
1007 let missing_account_proofs = instance_state
1008 .state_catchup
1009 .fetch_reward_accounts_v1(
1010 instance_state,
1011 parent_height,
1012 parent_view,
1013 validated_state.reward_merkle_tree_v1.commitment(),
1014 missing_reward_accts,
1015 )
1016 .await?;
1017
1018 for proof in missing_account_proofs.iter() {
1019 proof
1020 .remember(&mut validated_state.reward_merkle_tree_v1)
1021 .expect("proof previously verified");
1022 }
1023 }
1024 } else {
1025 let reward_accounts = Arc::new(
1026 std::iter::once(validator.account.into())
1027 .chain(validator.delegators.keys().cloned().map(Into::into))
1028 .collect::<Vec<_>>(),
1029 );
1030
1031 let reward_merkle_tree_root = validated_state.reward_merkle_tree_v2.commitment();
1032 if forgotten_accounts_include(&validated_state.reward_merkle_tree_v2, &reward_accounts) {
1033 tracing::warn!(
1034 parent_height,
1035 ?parent_view,
1036 %reward_merkle_tree_root,
1037 "fetching reward merkle tree from peers"
1038 );
1039
1040 validated_state.reward_merkle_tree_v2 = instance_state
1041 .state_catchup
1042 .fetch_reward_merkle_tree_v2(
1043 parent_height,
1044 parent_view,
1045 reward_merkle_tree_root,
1046 reward_accounts,
1047 )
1048 .await?
1049 .tree;
1050
1051 tracing::warn!(
1052 parent_height,
1053 ?parent_view,
1054 %reward_merkle_tree_root,
1055 "successfully fetched reward merkle tree from peers"
1056 );
1057 }
1058 }
1059
1060 Ok(validator)
1061}
1062
1063#[derive(Debug, Clone)]
1065pub struct EpochRewardsResult {
1066 pub epoch: EpochNumber,
1068 pub reward_tree: RewardMerkleTreeV2,
1070 pub total_distributed: RewardAmount,
1072 pub changed_accounts: HashSet<RewardAccountV2>,
1074}
1075
1076#[derive(Debug, Default)]
1079pub struct EpochRewardsCalculator {
1080 pending: Option<(EpochNumber, JoinHandle<anyhow::Result<EpochRewardsResult>>)>,
1082}
1083
1084impl EpochRewardsCalculator {
1085 pub fn new() -> Self {
1086 Self { pending: None }
1087 }
1088
1089 pub fn is_calculating(&self, epoch: EpochNumber) -> bool {
1091 self.pending.as_ref().is_some_and(|(e, _)| *e == epoch)
1092 }
1093
1094 pub async fn get_result(&mut self, epoch: EpochNumber) -> Option<EpochRewardsResult> {
1101 let (pending_epoch, handle) = self.pending.take()?;
1102 if pending_epoch != epoch {
1103 self.pending = Some((pending_epoch, handle));
1105 return None;
1106 }
1107
1108 match handle.await {
1109 Ok(Ok(result)) => {
1110 tracing::info!(%epoch, total = %result.total_distributed.0, "epoch rewards calculation completed");
1111 Some(result)
1112 },
1113 Ok(Err(e)) => {
1114 tracing::error!(%epoch, error = %e, "epoch rewards calculation failed");
1115 None
1116 },
1117 Err(e) => {
1118 tracing::error!(%epoch, error = %e, "epoch rewards task panicked");
1119 None
1120 },
1121 }
1122 }
1123
1124 pub fn spawn_background_task(
1128 &mut self,
1129 epoch: EpochNumber,
1130 epoch_height: u64,
1131 reward_tree: RewardMerkleTreeV2,
1132 instance_state: NodeState,
1133 coordinator: EpochMembershipCoordinator<SeqTypes>,
1134 leader_counts: Option<LeaderCounts>,
1135 ) {
1136 if self.is_calculating(epoch) {
1137 tracing::debug!(%epoch, "calculation already in progress, skipping");
1138 return;
1139 }
1140
1141 if let Some((stale_epoch, handle)) = self.pending.take() {
1143 tracing::info!(%stale_epoch, %epoch, "aborting stale epoch rewards task");
1144 handle.abort();
1145 }
1146
1147 tracing::info!(
1148 %epoch,
1149 has_leader_counts = leader_counts.is_some(),
1150 "starting background epoch rewards task"
1151 );
1152
1153 let handle = tokio::spawn(async move {
1154 Self::fetch_and_calculate(
1155 epoch,
1156 epoch_height,
1157 reward_tree,
1158 instance_state,
1159 coordinator,
1160 leader_counts,
1161 )
1162 .await
1163 });
1164 self.pending = Some((epoch, handle));
1165 }
1166
1167 async fn fetch_and_calculate(
1168 epoch: EpochNumber,
1169 epoch_height: u64,
1170 mut reward_tree: RewardMerkleTreeV2,
1171 instance_state: NodeState,
1172 coordinator: EpochMembershipCoordinator<SeqTypes>,
1173 leader_counts: Option<LeaderCounts>,
1174 ) -> anyhow::Result<EpochRewardsResult> {
1175 let epoch_last_block_height = (*epoch) * epoch_height;
1176
1177 tracing::info!(
1178 %epoch,
1179 epoch_last_block_height,
1180 has_leader_counts = leader_counts.is_some(),
1181 "fetch_and_calculate: starting"
1182 );
1183
1184 if let Err(err) = coordinator.membership_for_epoch(Some(epoch)).await {
1186 tracing::info!(%epoch, "stake table missing for epoch, triggering catchup: {err:#}");
1187 coordinator
1188 .wait_for_catchup(epoch)
1189 .await
1190 .context(format!("failed to catch up for epoch={epoch}"))?;
1191 }
1192
1193 let leader_counts = if let Some(lc) = leader_counts {
1195 lc
1196 } else {
1197 let membership = coordinator.membership().read().await;
1200 let stake_table = membership.stake_table(Some(epoch));
1201 let success_threshold = membership.success_threshold(Some(epoch));
1202 drop(membership);
1203
1204 let leaf = instance_state
1205 .state_catchup
1206 .as_ref()
1207 .fetch_leaf(epoch_last_block_height, stake_table, success_threshold)
1208 .await
1209 .with_context(|| {
1210 format!(
1211 "failed to fetch leaf at height {epoch_last_block_height} for epoch \
1212 {epoch}"
1213 )
1214 })?;
1215 let header = leaf.block_header();
1216
1217 tracing::info!(
1218 %epoch,
1219 header_height = header.height(),
1220 header_version = %header.version(),
1221 header_reward_merkle_tree_root = %header.reward_merkle_tree_root(),
1222 "fetch_and_calculate: fetched leaf"
1223 );
1224
1225 anyhow::ensure!(
1226 header.version() >= EPOCH_REWARD_VERSION,
1227 "header version {} is pre-V6, cannot calculate rewards",
1228 header.version()
1229 );
1230
1231 let expected_root = header.reward_merkle_tree_root().right();
1237 let actual_root = reward_tree.commitment();
1238 if expected_root != Some(actual_root) {
1239 tracing::warn!(
1240 %epoch,
1241 ?expected_root,
1242 ?actual_root,
1243 "reward merkle tree root mismatch, using empty tree"
1244 );
1245 reward_tree = RewardMerkleTreeV2::new(REWARD_MERKLE_TREE_V2_HEIGHT);
1246 }
1247
1248 *header
1249 .leader_counts()
1250 .expect("V6+ header must have leader_counts")
1251 };
1252
1253 let membership = coordinator.membership().read().await;
1254 let validator_leader_counts =
1255 ValidatorLeaderCounts::new(&membership, &epoch, leader_counts)?;
1256 let block_reward = membership
1257 .epoch_block_reward(epoch)
1258 .context("block reward not found for epoch")?;
1259 drop(membership);
1260
1261 tracing::info!(
1262 %epoch,
1263 %block_reward,
1264 "fetch_and_calculate: got block_reward"
1265 );
1266
1267 let accounts_to_update = validator_leader_counts.all_reward_accounts();
1269
1270 let missing_accounts: Vec<_> = accounts_to_update
1271 .iter()
1272 .filter(|account| reward_tree.lookup(**account).expect_not_in_memory().is_ok())
1273 .cloned()
1274 .collect();
1275
1276 if !missing_accounts.is_empty() {
1279 tracing::info!(
1280 %epoch,
1281 num_missing = missing_accounts.len(),
1282 "missing accounts detected, fetching reward merkle tree from peers"
1283 );
1284
1285 let reward_merkle_tree_root = reward_tree.commitment();
1286 reward_tree = instance_state
1287 .state_catchup
1288 .as_ref()
1289 .fetch_reward_merkle_tree_v2(
1290 epoch_last_block_height,
1291 ViewNumber::new(0),
1292 reward_merkle_tree_root,
1293 Arc::new(missing_accounts),
1294 )
1295 .await
1296 .with_context(|| {
1297 format!(
1298 "failed to fetch reward merkle tree at height {epoch_last_block_height} \
1299 for epoch {epoch}"
1300 )
1301 })?
1302 .tree;
1303
1304 tracing::info!(
1305 %epoch,
1306 reward_tree_commitment = %reward_tree.commitment(),
1307 "reward tree fetched successfully"
1308 );
1309 }
1310
1311 tracing::info!(
1312 %epoch,
1313 reward_tree_commitment = %reward_tree.commitment(),
1314 "starting final epoch calculation"
1315 );
1316
1317 Self::calculate_all_rewards(epoch, validator_leader_counts, reward_tree, block_reward).await
1318 }
1319
1320 async fn calculate_all_rewards(
1322 epoch: EpochNumber,
1323 validator_leader_counts: ValidatorLeaderCounts,
1324 mut reward_tree: RewardMerkleTreeV2,
1325 block_reward: RewardAmount,
1326 ) -> anyhow::Result<EpochRewardsResult> {
1327 let mut total_distributed = U256::ZERO;
1328 let mut changed_accounts = HashSet::new();
1329
1330 for (validator, count) in validator_leader_counts.active_leaders() {
1331 let validator_reward = block_reward
1333 .0
1334 .checked_mul(U256::from(count))
1335 .context("overflow in validator reward calculation")?;
1336
1337 if validator_reward.is_zero() {
1338 continue;
1339 }
1340
1341 changed_accounts.insert(RewardAccountV2(validator.account));
1342 changed_accounts.extend(validator.delegators.keys().map(|d| RewardAccountV2(*d)));
1343
1344 let distributor = RewardDistributor::new(
1345 validator.clone(),
1346 RewardAmount(validator_reward),
1347 Default::default(),
1348 );
1349
1350 let computed_rewards = distributor.compute_rewards()?;
1351
1352 for (address, reward) in computed_rewards.all_rewards() {
1353 RewardDistributor::update_reward_balance(
1354 &mut reward_tree,
1355 &RewardAccountV2(address),
1356 reward,
1357 )?;
1358 tracing::debug!(%epoch, %address, %reward, "applied epoch reward");
1359 }
1360
1361 total_distributed += validator_reward;
1362 }
1363
1364 tracing::info!(
1365 %epoch,
1366 total_distributed = %total_distributed,
1367 num_changed_accounts = changed_accounts.len(),
1368 "epoch rewards calculation complete"
1369 );
1370
1371 Ok(EpochRewardsResult {
1372 epoch,
1373 reward_tree,
1374 total_distributed: RewardAmount(total_distributed),
1375 changed_accounts,
1376 })
1377 }
1378}
1379
#[cfg(test)]
pub mod tests {

    use super::*;

    /// Builds a distributor for a mock validator with the given commission (in
    /// basis points) and a fixed block reward of 1.902 * 10^18.
    fn make_distributor(commission: u16) -> RewardDistributor {
        RewardDistributor::new(
            AuthenticatedValidator::mock_with_commission(commission),
            RewardAmount(U256::from(1902000000000000000_u128)),
            U256::ZERO.into(),
        )
    }

    /// Sums every reward amount contained in `rewards`.
    fn total_rewards(rewards: ComputedRewards) -> U256 {
        rewards
            .all_rewards()
            .iter()
            .fold(U256::ZERO, |acc, (_, r)| acc + r.0)
    }

    /// Returns the leader's commission expressed in basis points of the
    /// distributor's block reward.
    fn commission_bps(distributor: &RewardDistributor, rewards: &ComputedRewards) -> U256 {
        rewards.leader_commission().0 * U256::from(COMMISSION_BASIS_POINTS)
            / distributor.block_reward.0
    }

    #[test]
    fn test_reward_calculation_sanity_checks() {
        // For every valid commission the whole block reward must be paid out.
        for commission in [0, 500, 10000] {
            let distributor = make_distributor(commission);
            let rewards = distributor.compute_rewards().unwrap();
            assert_eq!(
                total_rewards(rewards),
                distributor.block_reward.0,
                "total distributed mismatch for commission {commission}"
            );
        }

        // A 100% commission sends the entire block reward to the leader.
        let distributor = make_distributor(10000);
        let rewards = distributor.compute_rewards().unwrap();
        assert_eq!(*rewards.leader_commission(), distributor.block_reward);

        // Commissions above 100% must be rejected.
        let err = make_distributor(10001)
            .compute_rewards()
            .err()
            .expect("commission above 100% must be rejected")
            .to_string();
        assert!(
            err.contains("must not exceed"),
            "unexpected error message: {err}"
        );
    }

    #[test]
    fn test_compute_rewards_validator_commission() {
        // The leader's share, in basis points, must equal the configured commission.
        for commission in [0_u16, 300] {
            let distributor = make_distributor(commission);
            let rewards = distributor.compute_rewards().unwrap();
            assert_eq!(
                commission_bps(&distributor, &rewards),
                U256::from(commission),
                "commission mismatch for {commission} bps"
            );
        }

        let distributor = make_distributor(10000);
        let rewards = distributor.compute_rewards().unwrap();
        assert_eq!(*rewards.leader_commission(), distributor.block_reward);
    }
}