vid/
avidm_gf2.rs

1//! This module implements the AVID-M scheme over GF2
2
3use std::{ops::Range, vec};
4
5use anyhow::anyhow;
6use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
7use jf_merkle_tree::{MerkleTreeScheme, hasher::HasherNode};
8use jf_utils::canonical;
9use serde::{Deserialize, Serialize};
10use sha2::Digest;
11use tagged_base64::tagged;
12
13use crate::{VidError, VidResult, VidScheme};
14
15/// Namespaced AvidmGf2 scheme
16pub mod namespaced;
17/// Namespace proofs for AvidmGf2 scheme
18pub mod proofs;
19
/// Merkle tree scheme used in the VID: a Keccak256-hashed Merkle tree whose
/// leaves are the Keccak256 digests of the erasure-coded shards.
pub(crate) type MerkleTree =
    jf_merkle_tree::hasher::HasherMerkleTree<sha3::Keccak256, HasherNode<sha3::Keccak256>>;
/// Membership proof type of the [`MerkleTree`], carried in each share.
type MerkleProof = <MerkleTree as MerkleTreeScheme>::MembershipProof;
/// Commitment (root) type of the [`MerkleTree`], wrapped by [`AvidmGf2Commit`].
type MerkleCommit = <MerkleTree as MerkleTreeScheme>::Commitment;

/// Dummy struct for AVID-M scheme over GF2. Stateless: all operations are
/// associated functions parameterized by [`AvidmGf2Param`].
pub struct AvidmGf2Scheme;
28
/// VID Parameters
#[derive(Clone, Debug, Hash, Serialize, Deserialize, PartialEq, Eq)]
pub struct AvidmGf2Param {
    /// Total weights of all storage nodes.
    /// Equals the total number of erasure-coded shards (original + recovery).
    pub total_weights: usize,
    /// Minimum collective weights required to recover the original payload.
    /// Also the number of original (data) shards in the Reed-Solomon code.
    pub recovery_threshold: usize,
}
37
38impl AvidmGf2Param {
39    /// Construct a new [`AvidmGf2Param`].
40    pub fn new(recovery_threshold: usize, total_weights: usize) -> VidResult<Self> {
41        if recovery_threshold == 0 || total_weights < recovery_threshold {
42            return Err(VidError::InvalidParam);
43        }
44        Ok(Self {
45            total_weights,
46            recovery_threshold,
47        })
48    }
49}
50
/// VID Share type to be distributed among the parties.
#[derive(Clone, Debug, Hash, Serialize, Deserialize, PartialEq, Eq)]
pub struct AvidmGf2Share {
    /// Range of this share in the encoded payload: one shard index per unit of
    /// weight assigned to the node.
    range: Range<usize>,
    /// Actual share content: one shard per index in `range`.
    #[serde(with = "canonical")]
    payload: Vec<Vec<u8>>,
    /// Merkle proof of the content: one proof per index in `range`.
    #[serde(with = "canonical")]
    mt_proofs: Vec<MerkleProof>,
}
63
64impl AvidmGf2Share {
65    /// Get the weight of this share
66    pub fn weight(&self) -> usize {
67        self.range.len()
68    }
69
70    /// Validate the share structure.
71    pub fn validate(&self) -> bool {
72        self.payload.len() == self.range.len() && self.mt_proofs.len() == self.range.len()
73    }
74}
75
/// VID Commitment type: a thin wrapper around the shard Merkle tree root.
#[derive(
    Clone,
    Copy,
    Debug,
    Default,
    Hash,
    CanonicalSerialize,
    CanonicalDeserialize,
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
)]
#[tagged("AvidmGf2Commit")]
#[repr(C)]
pub struct AvidmGf2Commit {
    /// VID commitment is the Merkle tree root
    pub commit: MerkleCommit,
}
96
// Expose the commitment as raw bytes by delegating to the inner Merkle root.
impl AsRef<[u8]> for AvidmGf2Commit {
    fn as_ref(&self) -> &[u8] {
        self.commit.as_ref()
    }
}
102
103impl AsRef<[u8; 32]> for AvidmGf2Commit {
104    fn as_ref(&self) -> &[u8; 32] {
105        <Self as AsRef<[u8]>>::as_ref(self)
106            .try_into()
107            .expect("AvidmGf2Commit is always 32 bytes")
108    }
109}
110
111impl AvidmGf2Scheme {
112    /// Setup an instance for AVID-M scheme
113    pub fn setup(recovery_threshold: usize, total_weights: usize) -> VidResult<AvidmGf2Param> {
114        AvidmGf2Param::new(recovery_threshold, total_weights)
115    }
116
117    fn bit_padding(payload: &[u8], payload_len: usize) -> VidResult<Vec<u8>> {
118        if payload_len < payload.len() + 1 {
119            return Err(VidError::Argument(
120                "Payload length is too large to fit in the given payload length".to_string(),
121            ));
122        }
123        let mut padded = vec![0u8; payload_len];
124        padded[..payload.len()].copy_from_slice(payload);
125        padded[payload.len()] = 1u8;
126        Ok(padded)
127    }
128
129    fn raw_disperse(
130        param: &AvidmGf2Param,
131        payload: &[u8],
132    ) -> VidResult<(MerkleTree, Vec<Vec<u8>>)> {
133        let original_count = param.recovery_threshold;
134        let recovery_count = param.total_weights - param.recovery_threshold;
135        // Bit padding, we append an 1u8 to the end of the payload.
136        let mut shard_bytes = (payload.len() + 1).div_ceil(original_count);
137        if shard_bytes % 2 == 1 {
138            shard_bytes += 1;
139        }
140        let payload = Self::bit_padding(payload, shard_bytes * original_count)?;
141        let original = payload
142            .chunks(shard_bytes)
143            .map(|chunk| chunk.to_owned())
144            .collect::<Vec<_>>();
145        let recovery = if recovery_count == 0 {
146            vec![]
147        } else {
148            reed_solomon_simd::encode(original_count, recovery_count, &original)?
149        };
150
151        let shares = [original, recovery].concat();
152        let share_digests: Vec<_> = shares
153            .iter()
154            .map(|share| HasherNode::from(sha3::Keccak256::digest(share)))
155            .collect();
156        let mt = MerkleTree::from_elems(None, &share_digests)?;
157        Ok((mt, shares))
158    }
159}
160
impl VidScheme for AvidmGf2Scheme {
    type Param = AvidmGf2Param;
    type Share = AvidmGf2Share;
    type Commit = AvidmGf2Commit;

    fn commit(param: &Self::Param, payload: &[u8]) -> VidResult<Self::Commit> {
        // The commitment is just the root of the shard Merkle tree; the shards
        // themselves are discarded here.
        let (mt, _) = Self::raw_disperse(param, payload)?;
        Ok(Self::Commit {
            commit: mt.commitment(),
        })
    }

    fn disperse(
        param: &Self::Param,
        distribution: &[u32],
        payload: &[u8],
    ) -> VidResult<(Self::Commit, Vec<Self::Share>)> {
        // The weight distribution must exactly partition `param.total_weights`
        // and contain no zero-weight nodes.
        let total_weights = distribution.iter().map(|&w| w as usize).sum::<usize>();
        if total_weights != param.total_weights {
            return Err(VidError::Argument(
                "Weight distribution is inconsistent with the given param".to_string(),
            ));
        }
        if distribution.contains(&0u32) {
            return Err(VidError::Argument("Weight cannot be zero".to_string()));
        }
        let (mt, shares) = Self::raw_disperse(param, payload)?;
        let commit = AvidmGf2Commit {
            commit: mt.commitment(),
        };
        // Prefix-sum the weights so node `i` owns the contiguous shard range
        // [sum(w[..i]), sum(w[..=i])).
        let ranges: Vec<_> = distribution
            .iter()
            .scan(0usize, |sum, w| {
                let prefix_sum = *sum;
                *sum += *w as usize;
                Some(prefix_sum..*sum)
            })
            .collect();
        let shares: Vec<_> = ranges
            .into_iter()
            .map(|range| AvidmGf2Share {
                range: range.clone(),
                payload: shares[range.clone()].to_vec(),
                // TODO(Chengyu): switch to batch proof generation
                mt_proofs: range
                    .map(|k| {
                        mt.lookup(k as u64)
                            .expect_ok()
                            .expect("MT lookup shouldn't fail")
                            .1
                    })
                    .collect::<Vec<_>>(),
            })
            .collect();
        Ok((commit, shares))
    }

    fn verify_share(
        param: &Self::Param,
        commit: &Self::Commit,
        share: &Self::Share,
    ) -> VidResult<crate::VerificationResult> {
        // Structural checks (consistent lengths, non-empty, within this
        // parameter set). Failures here are hard errors, not `Ok(Err(()))`.
        if !share.validate() || share.range.is_empty() || share.range.end > param.total_weights {
            return Err(VidError::InvalidShare);
        }
        for (i, index) in share.range.clone().enumerate() {
            let payload_digest = HasherNode::from(sha3::Keccak256::digest(&share.payload[i]));
            // TODO(Chengyu): switch to batch verification
            if MerkleTree::verify(
                commit.commit,
                index as u64,
                payload_digest,
                &share.mt_proofs[i],
            )?
            .is_err()
            {
                // A failed Merkle proof means the share content is wrong, not
                // that the verification procedure itself failed.
                return Ok(Err(()));
            }
        }
        Ok(Ok(()))
    }

    fn recover(
        param: &Self::Param,
        _commit: &Self::Commit,
        shares: &[Self::Share],
    ) -> VidResult<Vec<u8>> {
        // NOTE: `_commit` is unused — recovery does not re-check shares
        // against the commitment; callers are expected to `verify_share` first.
        let original_count = param.recovery_threshold;
        let recovery_count = param.total_weights - param.recovery_threshold;
        // Find the first non-empty share
        let Some(first_share) = shares.iter().find(|s| !s.payload.is_empty()) else {
            return Err(VidError::InsufficientShares);
        };
        // All shards have one common size; take it from the first available shard.
        let shard_bytes = first_share.payload[0].len();

        let mut original_shares: Vec<Option<Vec<u8>>> = vec![None; original_count];
        if recovery_count == 0 {
            // Edge case where there are no recovery shares
            for share in shares {
                if !share.validate() || share.payload.iter().any(|p| p.len() != shard_bytes) {
                    return Err(VidError::InvalidShare);
                }
                for (i, index) in share.range.clone().enumerate() {
                    if index < original_count {
                        original_shares[index] = Some(share.payload[i].clone());
                    }
                }
            }
        } else {
            let mut decoder = reed_solomon_simd::ReedSolomonDecoder::new(
                original_count,
                recovery_count,
                shard_bytes,
            )?;
            for share in shares {
                if !share.validate() || share.payload.iter().any(|p| p.len() != shard_bytes) {
                    return Err(VidError::InvalidShare);
                }
                for (i, index) in share.range.clone().enumerate() {
                    if index < original_count {
                        // Keep a copy: the decoder only reports shards it
                        // *restored*, not the originals we feed it.
                        original_shares[index] = Some(share.payload[i].clone());
                        decoder.add_original_shard(index, &share.payload[i])?;
                    } else {
                        decoder.add_recovery_shard(index - original_count, &share.payload[i])?;
                    }
                }
            }

            let result = decoder.decode()?;
            // Fill the gaps with the shards the decoder restored.
            original_shares
                .iter_mut()
                .enumerate()
                .for_each(|(i, share)| {
                    if share.is_none() {
                        *share = result.restored_original(i).map(|s| s.to_vec());
                    }
                });
        }
        if original_shares.iter().any(|share| share.is_none()) {
            return Err(VidError::Internal(anyhow!(
                "Failed to recover the payload."
            )));
        }
        let mut recovered: Vec<_> = original_shares
            .into_iter()
            .flat_map(|share| share.unwrap())
            .collect();
        // Strip the bit padding: the last non-zero byte must be the `1u8`
        // marker appended by `bit_padding`; everything before it is payload.
        match recovered.iter().rposition(|&b| b != 0) {
            Some(pad_index) if recovered[pad_index] == 1u8 => {
                recovered.truncate(pad_index);
                Ok(recovered)
            },
            _ => Err(VidError::Argument(
                "Malformed payload, cannot find the padding position".to_string(),
            )),
        }
    }
}
319
/// Unit tests
#[cfg(test)]
pub mod tests {
    use rand::{RngCore, seq::SliceRandom};

    use super::AvidmGf2Scheme;
    use crate::VidScheme;

    // End-to-end: disperse with random weights, verify every share, then
    // recover from a random subset that just meets the recovery threshold.
    #[test]
    fn round_trip() {
        // play with these items
        let num_storage_nodes_list = [4, 9, 16];
        let payload_byte_lens = [1, 31, 32, 500];

        // more items as a function of the above

        let mut rng = jf_utils::test_rng();

        for num_storage_nodes in num_storage_nodes_list {
            // Random per-node weights in 1..=5.
            let weights: Vec<u32> = (0..num_storage_nodes)
                .map(|_| rng.next_u32() % 5 + 1)
                .collect();
            let total_weights: u32 = weights.iter().sum();
            let recovery_threshold = total_weights.div_ceil(3) as usize;
            let params = AvidmGf2Scheme::setup(recovery_threshold, total_weights as usize).unwrap();

            for payload_byte_len in payload_byte_lens {
                let payload = {
                    let mut bytes_random = vec![0u8; payload_byte_len];
                    rng.fill_bytes(&mut bytes_random);
                    bytes_random
                };

                let (commit, mut shares) =
                    AvidmGf2Scheme::disperse(&params, &weights, &payload).unwrap();

                assert_eq!(shares.len(), num_storage_nodes);

                // verify shares
                shares.iter().for_each(|share| {
                    assert!(
                        AvidmGf2Scheme::verify_share(&params, &commit, share)
                            .is_ok_and(|r| r.is_ok())
                    )
                });

                // test payload recovery on a random subset of shares
                shares.shuffle(&mut rng);
                // Take just enough shuffled shares to reach the threshold.
                let mut cumulated_weights = 0;
                let mut cut_index = 0;
                while cumulated_weights < recovery_threshold {
                    cumulated_weights += shares[cut_index].weight();
                    cut_index += 1;
                }
                let payload_recovered =
                    AvidmGf2Scheme::recover(&params, &commit, &shares[..cut_index]).unwrap();
                assert_eq!(payload_recovered, payload);
            }
        }
    }

    // Same round trip but with recovery_threshold == total_weights, i.e. the
    // no-recovery-shard edge case: every share is needed to recover.
    #[test]
    fn round_trip_edge_case() {
        // play with these items
        let num_storage_nodes_list = [4, 9, 16];
        let payload_byte_lens = [1, 31, 32, 500];

        // more items as a function of the above

        let mut rng = jf_utils::test_rng();

        for num_storage_nodes in num_storage_nodes_list {
            let weights: Vec<u32> = (0..num_storage_nodes)
                .map(|_| rng.next_u32() % 5 + 1)
                .collect();
            let total_weights: u32 = weights.iter().sum();
            let recovery_threshold = total_weights as usize;
            let params = AvidmGf2Scheme::setup(recovery_threshold, total_weights as usize).unwrap();

            for payload_byte_len in payload_byte_lens {
                let payload = {
                    let mut bytes_random = vec![0u8; payload_byte_len];
                    rng.fill_bytes(&mut bytes_random);
                    bytes_random
                };

                let (commit, mut shares) =
                    AvidmGf2Scheme::disperse(&params, &weights, &payload).unwrap();

                assert_eq!(shares.len(), num_storage_nodes);

                // verify shares
                shares.iter().for_each(|share| {
                    assert!(
                        AvidmGf2Scheme::verify_share(&params, &commit, share)
                            .is_ok_and(|r| r.is_ok())
                    )
                });

                // test payload recovery on a random subset of shares
                shares.shuffle(&mut rng);
                let payload_recovered =
                    AvidmGf2Scheme::recover(&params, &commit, &shares[..]).unwrap();
                assert_eq!(payload_recovered, payload);
            }
        }
    }

    // `disperse` must reject weight distributions that don't match the param.
    #[test]
    fn disperse_rejects_inconsistent_distribution() {
        let total_weights = 10usize;
        let recovery_threshold = 4;
        let params = AvidmGf2Scheme::setup(recovery_threshold, total_weights).unwrap();
        let payload = vec![1u8; 100];

        // distribution sums to 12, but param says total_weights=10
        let bad_weights = vec![3u32; 4];
        assert!(
            AvidmGf2Scheme::disperse(&params, &bad_weights, &payload).is_err(),
            "disperse should reject distribution that doesn't sum to total_weights"
        );

        // distribution contains a zero weight
        let zero_weight = vec![0u32, 5, 5];
        assert!(
            AvidmGf2Scheme::disperse(&params, &zero_weight, &payload).is_err(),
            "disperse should reject zero-weight entries"
        );

        // correct distribution should succeed
        let good_weights = vec![2u32; 5];
        assert!(AvidmGf2Scheme::disperse(&params, &good_weights, &payload).is_ok());
    }

    // `verify_share` must reject shares whose range exceeds the param bounds.
    #[test]
    fn verify_share_rejects_out_of_range() {
        let total_weights = 10usize;
        let recovery_threshold = 4;
        let params = AvidmGf2Scheme::setup(recovery_threshold, total_weights).unwrap();
        let payload = vec![1u8; 100];
        let weights = vec![2u32; 5];

        let (commit, shares) = AvidmGf2Scheme::disperse(&params, &weights, &payload).unwrap();

        // valid shares pass
        for share in &shares {
            assert!(AvidmGf2Scheme::verify_share(&params, &commit, share).is_ok_and(|r| r.is_ok()));
        }

        // a share verified against a smaller param should be rejected
        let smaller_params = AvidmGf2Scheme::setup(2, 5).unwrap();
        let last_share = shares.last().unwrap();
        assert!(
            AvidmGf2Scheme::verify_share(&smaller_params, &commit, last_share).is_err(),
            "verify_share should reject share with range.end > param.total_weights"
        );
    }
}
477}