1use std::{collections::HashMap, fmt, iter, ops::Range};
15
16use ark_ff::PrimeField;
17use ark_poly::{EvaluationDomain, Radix2EvaluationDomain};
18use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
19use ark_std::{end_timer, start_timer};
20use config::AvidMConfig;
21use jf_merkle_tree::MerkleTreeScheme;
22use jf_utils::canonical;
23use p3_maybe_rayon::prelude::{
24 IntoParallelIterator, IntoParallelRefIterator, ParallelIterator, ParallelSlice,
25};
26use serde::{Deserialize, Serialize};
27use tagged_base64::tagged;
28
29use crate::{
30 VidError, VidResult, VidScheme,
31 utils::bytes_to_field::{self, bytes_to_field, field_to_bytes},
32};
33
34mod config;
35
36pub mod namespaced;
37pub mod proofs;
38
39#[cfg(all(not(feature = "sha256"), not(feature = "keccak256")))]
40type Config = config::Poseidon2Config;
41#[cfg(feature = "sha256")]
42type Config = config::Sha256Config;
43#[cfg(feature = "keccak256")]
44type Config = config::Keccak256Config;
45
46type F = <Config as AvidMConfig>::BaseField;
48type MerkleTree = <Config as AvidMConfig>::MerkleTree;
49type MerkleProof = <MerkleTree as MerkleTreeScheme>::MembershipProof;
50type MerkleCommit = <MerkleTree as MerkleTreeScheme>::Commitment;
51
/// Commitment to a dispersed payload: the root of the Merkle tree built over the digests of
/// all raw shares.
#[derive(
    Clone,
    Copy,
    Debug,
    Default,
    Hash,
    CanonicalSerialize,
    CanonicalDeserialize,
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
)]
#[tagged("AvidMCommit")]
// `repr(C)` pins the layout so the raw-byte `AsRef` impls view exactly the bytes of `commit`
// (the only field, located at offset 0).
#[repr(C)]
pub struct AvidMCommit {
    /// Root commitment of the Merkle tree over the raw-share digests.
    pub commit: MerkleCommit,
}
72
impl AsRef<[u8]> for AvidMCommit {
    /// Views the commitment's in-memory representation as raw bytes.
    ///
    /// NOTE(review): these are the in-memory bytes, which may differ from the canonical
    /// serialization (e.g. for field-element commitments) — confirm callers expect this.
    fn as_ref(&self) -> &[u8] {
        // SAFETY: `self` is a valid, initialized value spanning `size_of::<Self>()` bytes, `u8`
        // has alignment 1, and the returned slice borrows `self`, so it cannot outlive the data.
        // NOTE(review): this also assumes `MerkleCommit` contains no padding bytes (reading
        // uninitialized padding as `u8` would be UB) — confirm for every `Config`.
        unsafe {
            ::core::slice::from_raw_parts(
                (self as *const Self) as *const u8,
                ::core::mem::size_of::<Self>(),
            )
        }
    }
}
83
84impl AsRef<[u8; 32]> for AvidMCommit {
85 fn as_ref(&self) -> &[u8; 32] {
86 unsafe { ::core::slice::from_raw_parts((self as *const Self) as *const u8, 32) }
87 .try_into()
88 .unwrap()
89 }
90}
91
/// The part of a share that is verified and decoded: a contiguous run of raw shares together
/// with one Merkle proof per raw share.
#[derive(Clone, Hash, Serialize, Deserialize, Eq, PartialEq)]
pub struct RawAvidMShare {
    /// Indices of the raw shares carried here; each index is also the raw share's leaf
    /// position in the Merkle tree.
    range: Range<usize>,
    /// One raw share per index in `range`; each raw share holds one evaluation per encoded
    /// polynomial.
    #[serde(with = "canonical")]
    payload: Vec<Vec<F>>,
    /// One Merkle membership proof per index in `range`, in the same order as `payload`.
    #[serde(with = "canonical")]
    mt_proofs: Vec<MerkleProof>,
}
104
105impl fmt::Debug for RawAvidMShare {
106 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
107 f.debug_struct("RawAvidMShare")
108 .field("range", &self.range)
109 .field("payload", &format_args!("..."))
110 .field("mt_proofs", &format_args!("..."))
111 .finish()
112 }
113}
114
/// A storage node's share of a dispersed payload.
#[derive(Clone, Debug, Hash, Serialize, Deserialize, Eq, PartialEq)]
pub struct AvidMShare {
    /// Position of the receiving node in the weight distribution used for dispersal.
    index: u32,
    /// Byte length of the original (unpadded) payload.
    payload_byte_len: usize,
    /// Raw shares and Merkle proofs assigned to this node.
    content: RawAvidMShare,
}
125
/// Public parameters of the AVID-M scheme.
#[derive(Clone, Debug, Hash, Serialize, Deserialize, PartialEq, Eq)]
pub struct AvidMParam {
    /// Total number of raw shares produced by encoding. The weights passed to `disperse`
    /// must sum to exactly this value.
    pub total_weights: usize,
    /// Number of distinct raw shares needed to recover the payload.
    pub recovery_threshold: usize,
}
134
135impl AvidMParam {
136 pub fn new(recovery_threshold: usize, total_weights: usize) -> VidResult<Self> {
138 if recovery_threshold == 0 || total_weights < recovery_threshold {
139 return Err(VidError::InvalidParam);
140 }
141 Ok(Self {
142 total_weights,
143 recovery_threshold,
144 })
145 }
146}
147
148#[inline]
150fn radix2_domain<F: PrimeField>(domain_size: usize) -> VidResult<Radix2EvaluationDomain<F>> {
151 Radix2EvaluationDomain::<F>::new(domain_size).ok_or_else(|| VidError::InvalidParam)
152}
153
/// The AVID-M verifiable information dispersal scheme.
///
/// A stateless unit type: all data lives in [`AvidMParam`], [`AvidMCommit`] and
/// [`AvidMShare`].
pub struct AvidMScheme;
156
157impl AvidMScheme {
158 pub fn setup(recovery_threshold: usize, total_weights: usize) -> VidResult<AvidMParam> {
160 AvidMParam::new(recovery_threshold, total_weights)
161 }
162}
163
164impl AvidMScheme {
165 fn pad_to_fields(param: &AvidMParam, payload: &[u8]) -> Vec<F> {
170 let elem_bytes_len = bytes_to_field::elem_byte_capacity::<F>();
172
173 let num_bytes_per_chunk = param.recovery_threshold * elem_bytes_len;
176
177 let remainder = (payload.len() + 1) % num_bytes_per_chunk;
178 let pad_num_zeros = (num_bytes_per_chunk - remainder) % num_bytes_per_chunk;
179
180 bytes_to_field::<_, F>(
182 payload
183 .iter()
184 .chain(iter::once(&1u8))
185 .chain(iter::repeat_n(&0u8, pad_num_zeros)),
186 )
187 .collect()
188 }
189
190 #[allow(clippy::type_complexity)]
198 #[inline]
199 fn raw_encode(param: &AvidMParam, payload: &[F]) -> VidResult<(MerkleTree, Vec<Vec<F>>)> {
200 let domain = radix2_domain::<F>(param.total_weights)?; let encoding_timer = start_timer!(|| "Encoding payload");
203
204 let codewords: Vec<_> = payload
206 .par_chunks(param.recovery_threshold)
207 .map(|chunk| {
208 let mut fft_vec = domain.fft(chunk); fft_vec.truncate(param.total_weights); fft_vec
211 })
212 .collect();
213 let raw_shares: Vec<_> = (0..param.total_weights)
216 .into_par_iter()
217 .map(|i| codewords.iter().map(|v| v[i]).collect::<Vec<F>>())
218 .collect();
219 end_timer!(encoding_timer);
220
221 let hash_timer = start_timer!(|| "Compressing each raw share");
222 let compressed_raw_shares = raw_shares
223 .par_iter()
224 .map(|v| Config::raw_share_digest(v))
225 .collect::<Result<Vec<_>, _>>()?;
226 end_timer!(hash_timer);
227
228 let mt_timer = start_timer!(|| "Constructing Merkle tree");
229 let mt = MerkleTree::from_elems(None, &compressed_raw_shares)?;
230 end_timer!(mt_timer);
231
232 Ok((mt, raw_shares))
233 }
234
235 fn pad_and_encode(param: &AvidMParam, payload: &[u8]) -> VidResult<(MerkleTree, Vec<Vec<F>>)> {
237 let payload = Self::pad_to_fields(param, payload);
238 Self::raw_encode(param, &payload)
239 }
240
241 fn distribute_shares(
243 param: &AvidMParam,
244 distribution: &[u32],
245 mt: MerkleTree,
246 raw_shares: Vec<Vec<F>>,
247 payload_byte_len: usize,
248 ) -> VidResult<(AvidMCommit, Vec<AvidMShare>)> {
249 let total_weights = distribution.iter().map(|&w| w as usize).sum::<usize>();
251 if total_weights != param.total_weights {
252 return Err(VidError::Argument(
253 "Weight distribution is inconsistent with the given param".to_string(),
254 ));
255 }
256 if distribution.contains(&0u32) {
257 return Err(VidError::Argument("Weight cannot be zero".to_string()));
258 }
259
260 let distribute_timer = start_timer!(|| "Distribute codewords to the storage nodes");
261 let ranges: Vec<_> = distribution
265 .iter()
266 .scan(0usize, |sum, w| {
267 let prefix_sum = *sum;
268 *sum += *w as usize;
269 Some(prefix_sum..*sum)
270 })
271 .collect();
272 let shares: Vec<_> = ranges
273 .par_iter()
274 .map(|range| {
275 range
276 .clone()
277 .map(|k| raw_shares[k].to_owned())
278 .collect::<Vec<_>>()
279 })
280 .collect();
281 end_timer!(distribute_timer);
282
283 let mt_proof_timer = start_timer!(|| "Generate Merkle tree proofs");
284 let shares = shares
285 .into_iter()
286 .enumerate()
287 .map(|(i, payload)| AvidMShare {
288 index: i as u32,
289 payload_byte_len,
290 content: RawAvidMShare {
291 range: ranges[i].clone(),
292 payload,
293 mt_proofs: ranges[i]
294 .clone()
295 .map(|k| {
296 mt.lookup(k as u64)
297 .expect_ok()
298 .expect("MT lookup shouldn't fail")
299 .1
300 })
301 .collect::<Vec<_>>(),
302 },
303 })
304 .collect::<Vec<_>>();
305 end_timer!(mt_proof_timer);
306
307 let commit = AvidMCommit {
308 commit: mt.commitment(),
309 };
310
311 Ok((commit, shares))
312 }
313
314 pub(crate) fn verify_internal(
315 param: &AvidMParam,
316 commit: &AvidMCommit,
317 share: &RawAvidMShare,
318 ) -> VidResult<crate::VerificationResult> {
319 if share.range.is_empty()
320 || share.range.end > param.total_weights
321 || share.range.len() != share.payload.len()
322 || share.range.len() != share.mt_proofs.len()
323 {
324 return Err(VidError::InvalidShare);
325 }
326 for (i, index) in share.range.clone().enumerate() {
327 let compressed_payload = Config::raw_share_digest(&share.payload[i])?;
328 if MerkleTree::verify(
329 commit.commit,
330 index as u64,
331 compressed_payload,
332 &share.mt_proofs[i],
333 )?
334 .is_err()
335 {
336 return Ok(Err(()));
337 }
338 }
339 Ok(Ok(()))
340 }
341
342 pub(crate) fn recover_fields(param: &AvidMParam, shares: &[AvidMShare]) -> VidResult<Vec<F>> {
343 let recovery_threshold: usize = param.recovery_threshold;
344
345 let num_polys = shares
348 .iter()
349 .find(|s| !s.content.payload.is_empty())
350 .ok_or(VidError::Argument("All shares are empty".to_string()))?
351 .content
352 .payload[0]
353 .len();
354
355 let mut raw_shares = HashMap::new();
356 for share in shares {
357 if share.content.range.len() != share.content.payload.len()
358 || share.content.range.end > param.total_weights
359 {
360 return Err(VidError::InvalidShare);
361 }
362 for (i, p) in share.content.range.clone().zip(&share.content.payload) {
363 if p.len() != num_polys {
364 return Err(VidError::InvalidShare);
365 }
366 if raw_shares.contains_key(&i) {
367 return Err(VidError::InvalidShare);
368 }
369 raw_shares.insert(i, p);
370 if raw_shares.len() >= recovery_threshold {
371 break;
372 }
373 }
374 if raw_shares.len() >= recovery_threshold {
375 break;
376 }
377 }
378
379 if raw_shares.len() < recovery_threshold {
380 return Err(VidError::InsufficientShares);
381 }
382
383 let domain = radix2_domain::<F>(param.total_weights)?;
384
385 let (x, raw_shares): (Vec<_>, Vec<_>) = raw_shares
388 .into_iter()
389 .map(|(i, p)| (domain.element(i), p))
390 .unzip();
391 Ok((0..num_polys)
393 .into_par_iter()
394 .map(|poly_index| {
395 jf_utils::reed_solomon_code::reed_solomon_erasure_decode(
396 x.iter().zip(raw_shares.iter().map(|p| p[poly_index])),
397 recovery_threshold,
398 )
399 .map_err(|err| VidError::Internal(err.into()))
400 })
401 .collect::<Result<Vec<_>, _>>()?
402 .into_iter()
403 .flatten()
404 .collect())
405 }
406}
407
408impl VidScheme for AvidMScheme {
409 type Param = AvidMParam;
410
411 type Share = AvidMShare;
412
413 type Commit = AvidMCommit;
414
415 fn commit(param: &Self::Param, payload: &[u8]) -> VidResult<Self::Commit> {
416 let (mt, _) = Self::pad_and_encode(param, payload)?;
417 Ok(AvidMCommit {
418 commit: mt.commitment(),
419 })
420 }
421
422 fn disperse(
423 param: &Self::Param,
424 distribution: &[u32],
425 payload: &[u8],
426 ) -> VidResult<(Self::Commit, Vec<Self::Share>)> {
427 let (mt, raw_shares) = Self::pad_and_encode(param, payload)?;
428 Self::distribute_shares(param, distribution, mt, raw_shares, payload.len())
429 }
430
431 fn verify_share(
432 param: &Self::Param,
433 commit: &Self::Commit,
434 share: &Self::Share,
435 ) -> VidResult<crate::VerificationResult> {
436 Self::verify_internal(param, commit, &share.content)
437 }
438
439 fn recover(
448 param: &Self::Param,
449 _commit: &Self::Commit,
450 shares: &[Self::Share],
451 ) -> VidResult<Vec<u8>> {
452 let mut bytes: Vec<u8> = field_to_bytes(Self::recover_fields(param, shares)?).collect();
453 if let Some(pad_index) = bytes.iter().rposition(|&b| b != 0)
456 && bytes[pad_index] == 1u8
457 {
458 bytes.truncate(pad_index);
459 return Ok(bytes);
460 }
461 Err(VidError::Argument(
462 "Malformed payload, cannot find the padding position".to_string(),
463 ))
464 }
465}
466
#[cfg(test)]
pub mod tests {
    use rand::{RngCore, seq::SliceRandom};

    use super::F;
    use crate::{VidScheme, avidm::AvidMScheme, utils::bytes_to_field};

    #[test]
    fn test_padding() {
        let elem_bytes_len = bytes_to_field::elem_byte_capacity::<F>();
        let param = AvidMScheme::setup(2usize, 5usize).unwrap();

        // One payload byte (0x02) followed by the 0x01 end marker packs little-endian into the
        // first element as 2 + 1 * 256. The remaining element is zero padding up to
        // `recovery_threshold` elements.
        let bytes = vec![2u8; 1];
        let padded = AvidMScheme::pad_to_fields(&param, &bytes);
        assert_eq!(padded.len(), 2usize);
        assert_eq!(padded, [F::from(2u32 + u8::MAX as u32 + 1), F::from(0)]);

        // Two full elements of payload plus the marker spill into a third element. That is
        // then padded up to a multiple of `recovery_threshold` (= 2) elements.
        let bytes = vec![2u8; elem_bytes_len * 2];
        let padded = AvidMScheme::pad_to_fields(&param, &bytes);
        assert_eq!(padded.len(), 4usize);
    }

    #[test]
    fn round_trip() {
        // (recovery_threshold, num_storage_nodes)
        let params_list = [(2, 4), (3, 9), (5, 6), (15, 16)];
        let payload_byte_lens = [1, 31, 32, 500];

        let mut rng = jf_utils::test_rng();

        for (recovery_threshold, num_storage_nodes) in params_list {
            let weights: Vec<u32> = (0..num_storage_nodes)
                .map(|_| rng.next_u32() % 5 + 1)
                .collect();
            let total_weights: u32 = weights.iter().sum();
            let params = AvidMScheme::setup(recovery_threshold, total_weights as usize).unwrap();

            for payload_byte_len in payload_byte_lens {
                println!(
                    "recovery_threshold:: {recovery_threshold} num_storage_nodes: \
                     {num_storage_nodes} payload_byte_len: {payload_byte_len}"
                );
                println!("weights: {weights:?}");

                let payload = {
                    let mut bytes_random = vec![0u8; payload_byte_len];
                    rng.fill_bytes(&mut bytes_random);
                    bytes_random
                };

                let (commit, mut shares) =
                    AvidMScheme::disperse(&params, &weights, &payload).unwrap();

                assert_eq!(shares.len(), num_storage_nodes);

                // Every honestly generated share must verify against the commitment.
                shares.iter().for_each(|share| {
                    assert!(
                        AvidMScheme::verify_share(&params, &commit, share).is_ok_and(|r| r.is_ok())
                    )
                });

                // Take random shares until their combined weight first reaches the recovery
                // threshold, so recovery is exercised with (close to) the minimum amount of
                // data.
                shares.shuffle(&mut rng);
                let mut cumulated_weights = 0;
                let mut cut_index = 0;
                while cumulated_weights < recovery_threshold {
                    cumulated_weights += shares[cut_index].content.range.len();
                    cut_index += 1;
                }
                let payload_recovered =
                    AvidMScheme::recover(&params, &commit, &shares[..cut_index]).unwrap();
                assert_eq!(payload_recovered, payload);
            }
        }
    }

    #[test]
    #[cfg(feature = "print-trace")]
    fn round_trip_breakdown() {
        use ark_std::{end_timer, start_timer};

        let mut rng = jf_utils::test_rng();

        let params = AvidMScheme::setup(50usize, 200usize).unwrap();
        let weights = vec![2u32; 100usize];
        // 32 MiB payload.
        let payload_byte_len = 1024 * 1024 * 32;
        let payload = {
            let mut bytes_random = vec![0u8; payload_byte_len];
            rng.fill_bytes(&mut bytes_random);
            bytes_random
        };

        let disperse_timer = start_timer!(|| format!("Disperse {} bytes", payload_byte_len));
        let (commit, shares) = AvidMScheme::disperse(&params, &weights, &payload).unwrap();
        end_timer!(disperse_timer);

        let recover_timer = start_timer!(|| "Recovery");
        AvidMScheme::recover(&params, &commit, &shares).unwrap();
        end_timer!(recover_timer);
    }
}