Skip to content

Commit

Permalink
Add Field::sum_of_products method
Browse files Browse the repository at this point in the history
Closes #79.
  • Loading branch information
str4d committed Apr 27, 2022
1 parent dcd0219 commit be49a7c
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 0 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ and this library adheres to Rust's notion of
[Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]
### Added
- `ff::Field::sum_of_products`

## [0.11.0] - 2021-09-02
### Added
Expand Down
15 changes: 15 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,21 @@ pub trait Field:

res
}

/// Returns `pairs.into_iter().fold(Self::zero(), |acc, (a_i, b_i)| acc + a_i * b_i)`.
///
/// This computes the "dot product" or "inner product" `a ⋅ b` of two equal-length
/// sequences of elements `a` and `b`, such that `pairs = a.zip(b)`.
///
/// The provided implementation of this trait method uses the direct calculation given
/// above. Implementations of `Field` should override this to use more efficient
/// methods that take advantage of their internal representation, such as interleaving
/// or sharing modular reductions.
fn sum_of_products<'a, I: IntoIterator<Item = (&'a Self, &'a Self)> + Clone>(pairs: I) -> Self {
pairs
.into_iter()
.fold(Self::zero(), |acc, (a_i, b_i)| acc + (*a_i * b_i))
}
}

/// This represents an element of a prime field.
Expand Down
41 changes: 41 additions & 0 deletions tests/derive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,47 @@ mod fermat {
struct Fermat65537Field([u64; 1]);
}

#[test]
fn sum_of_products() {
use ff::{Field, PrimeField};

let one = Bls381K12Scalar::one();

// [1, 2, 3, 4]
let values: Vec<_> = (0..4)
.scan(one, |acc, _| {
let ret = *acc;
*acc += &one;
Some(ret)
})
.collect();

// We'll pair each value with itself.
let expected = Bls381K12Scalar::from_str_vartime("30").unwrap();

// Check that we can produce the necessary input from two iterators.
assert_eq!(
// Directly produces (&v, &v)
Bls381K12Scalar::sum_of_products(values.iter().zip(values.iter())),
expected,
);

// Check that we can produce the necessary input from an iterator of values.
assert_eq!(
// Maps &v to (&v, &v)
Bls381K12Scalar::sum_of_products(values.iter().map(|v| (v, v))),
expected,
);

// Check that we can produce the necessary input from an iterator of tuples.
let tuples: Vec<_> = values.into_iter().map(|v| (v, v)).collect();
assert_eq!(
// Maps &(a, b) to (&a, &b)
Bls381K12Scalar::sum_of_products(tuples.iter().map(|(a, b)| (a, b))),
expected,
);
}

#[test]
fn batch_inversion() {
use ff::{BatchInverter, Field};
Expand Down

0 comments on commit be49a7c

Please sign in to comment.